├── PlaneSAM
│   ├── utils
│   │   ├── __init__.py
│   │   ├── postprocess.py
│   │   ├── visual_tools.py
│   │   ├── loss.py
│   │   ├── box_ops.py
│   │   ├── train_tools.py
│   │   ├── eval_tools.py
│   │   └── make_prompt.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── build_efficient_sam.py
│   │   ├── mlp.py
│   │   ├── two_way_transformer.py
│   │   ├── efficient_sam_encoder.py
│   │   ├── efficient_sam_decoder.py
│   │   └── efficient_sam.py
│   ├── FasterRCNN
│   │   ├── __init__.py
│   │   ├── image_list.py
│   │   ├── build_FasterRCNN.py
│   │   ├── boxes.py
│   │   ├── transform.py
│   │   ├── det_utils.py
│   │   └── faster_rcnn_framework.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── vgg_model.py
│   │   ├── mobilenetv2_model.py
│   │   ├── resnet101_fpn_model.py
│   │   └── feature_pyramid_network.py
│   ├── requirements.txt
│   ├── README.md
│   ├── S2D3DSDataset.py
│   ├── Nyuv2Dataset.py
│   ├── eval.py
│   ├── PlaneDataset.py
│   ├── train.py
│   └── visual.py
└── README.md

/PlaneSAM/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/__init__.py:
--------------------------------------------------------------------------------
1 | from .faster_rcnn_framework import FasterRCNN, FastRCNNPredictor
2 | from .rpn_function import AnchorsGenerator
3 | 
--------------------------------------------------------------------------------
/PlaneSAM/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | from .build_efficient_sam import (
5 |     build_efficient_sam_vitt,
6 |     build_efficient_sam_vits,
7 | )
8 | 
--------------------------------------------------------------------------------
/PlaneSAM/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet101_fpn_model import resnet101_fpn_backbone
2 | from .mobilenetv2_model import MobileNetV2
3 | from .vgg_model import vgg
4 | from .feature_pyramid_network import LastLevelMaxPool, BackboneWithFPN
5 | 
--------------------------------------------------------------------------------
/PlaneSAM/model/build_efficient_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .efficient_sam import build_efficient_sam
8 | 
9 | def build_efficient_sam_vitt():
10 |     return build_efficient_sam(
11 |         encoder_patch_embed_dim=192,
12 |         encoder_num_heads=3,
13 |         checkpoint=None,
14 |     )
15 | 
16 | 
17 | def build_efficient_sam_vits():
18 |     return build_efficient_sam(
19 |         encoder_patch_embed_dim=384,
20 |         encoder_num_heads=6,
21 |         checkpoint=None,
22 |     )
23 | 
--------------------------------------------------------------------------------
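The two builders above construct the ViT-Tiny and ViT-Small variants of EfficientSAM with `checkpoint=None`, so weights have to be loaded separately. A minimal usage sketch (not part of the repository; the checkpoint path and state-dict layout are assumptions):

```python
# Hypothetical usage of the builders exported by PlaneSAM/model/__init__.py.
# Run from the PlaneSAM directory so that `model` is importable as a package.
from model import build_efficient_sam_vitt, build_efficient_sam_vits

net = build_efficient_sam_vitt()    # ViT-Tiny encoder: patch embed dim 192, 3 heads
# net = build_efficient_sam_vits()  # ViT-Small encoder: patch embed dim 384, 6 heads

# Both builders pass checkpoint=None, so pretrained weights (e.g. the files linked
# in the README) would have to be loaded afterwards, for example with:
#   state = torch.load("weights/<checkpoint>.pth", map_location="cpu")   # placeholder path
#   net.load_state_dict(state.get("state_dict", state), strict=False)
net.eval()
```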
/PlaneSAM/FasterRCNN/image_list.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from torch import Tensor
3 | 
4 | 
5 | class ImageList(object):
6 |     """
7 |     Structure that holds a list of images (of possibly
8 |     varying sizes) as a single tensor.
9 |     This works by padding the images to the same size,
10 |     and storing in a field the original sizes of each image
11 |     """
12 | 
13 |     def __init__(self, tensors, image_sizes):
14 |         # type: (Tensor, List[Tuple[int, int]]) -> None
15 |         """
16 |         Arguments:
17 |             tensors (tensor): the image data after padding
18 |             image_sizes (list[tuple[int, int]]): the image sizes before padding
19 |         """
20 |         self.tensors = tensors
21 |         self.image_sizes = image_sizes
22 | 
23 |     def to(self, device):
24 |         # type: (Device) -> ImageList  # noqa
25 |         cast_tensor = self.tensors.to(device)
26 |         return ImageList(cast_tensor, self.image_sizes)
27 | 
28 | 
--------------------------------------------------------------------------------
/PlaneSAM/model/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 | 
3 | from torch import nn
4 | 
5 | 
6 | # Lightly adapted from
7 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
8 | class MLPBlock(nn.Module):
9 |     def __init__(
10 |         self,
11 |         input_dim: int,
12 |         hidden_dim: int,
13 |         output_dim: int,
14 |         num_layers: int,
15 |         act: Type[nn.Module],
16 |     ) -> None:
17 |         super().__init__()
18 |         self.num_layers = num_layers
19 |         h = [hidden_dim] * (num_layers - 1)
20 |         self.layers = nn.ModuleList(
21 |             nn.Sequential(nn.Linear(n, k), act())
22 |             for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
23 |         )
24 |         self.fc = nn.Linear(hidden_dim, output_dim)
25 | 
26 |     def forward(self, x):
27 |         for layer in self.layers:
28 |             x = layer(x)
29 |         return self.fc(x)
30 | 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/build_FasterRCNN.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from FasterRCNN import FasterRCNN, FastRCNNPredictor
4 | from backbone import resnet101_fpn_backbone
5 | 
6 | def create_model(num_classes, load_pretrain_weights=False):
7 |     backbone = resnet101_fpn_backbone(pretrain_path="",
8 |                                       norm_layer=torch.nn.BatchNorm2d,
9 |                                       trainable_layers=5)
10 |     # When training on your own dataset, do not modify the 91 here; change the num_classes argument instead
11 |     model = FasterRCNN(backbone=backbone, num_classes=91)
12 | 
13 |     if load_pretrain_weights:
14 |         weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth", map_location='cpu')
15 |         missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
16 |         if len(missing_keys) != 0 or len(unexpected_keys) != 0:
17 |             print("missing_keys: ", missing_keys)
18 |             print("unexpected_keys: ", unexpected_keys)
19 | 
20 |     # get number of input features for the classifier
21 |     in_features = model.roi_heads.box_predictor.cls_score.in_features
22 |     # replace the pre-trained head with a new one
23 |     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
24 | 
25 |     return model
--------------------------------------------------------------------------------
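As a point of reference, a hedged sketch of how `create_model` above might be used as the plane detector mentioned in the README. `num_classes=2` (background + plane), the checkpoint path, and the torchvision-style inference interface are assumptions, not facts taken from the repository:

```python
# Hypothetical usage sketch for FasterRCNN/build_FasterRCNN.create_model.
# Run from the PlaneSAM directory so that FasterRCNN and backbone are importable.
import torch
from FasterRCNN.build_FasterRCNN import create_model

detector = create_model(num_classes=2)   # assumption: background + plane
# weights = torch.load("weights/<faster_rcnn_checkpoint>.pth", map_location="cpu")  # placeholder
# detector.load_state_dict(weights.get("model", weights), strict=False)
detector.eval()

with torch.no_grad():
    images = [torch.rand(3, 480, 640)]   # list of CHW tensors in [0, 1]
    outputs = detector(images)           # assumed: list of dicts with "boxes", "labels", "scores"
```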
/PlaneSAM/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | addict==2.4.0
3 | certifi==2024.7.4
4 | charset-normalizer==3.3.2
5 | cloudpickle==3.0.0
6 | contourpy==1.2.1
7 | cycler==0.12.1
8 | Cython==3.0.11
9 | filelock==3.15.4
10 | fonttools==4.53.1
11 | fsspec==2024.6.1
12 | fvcore==0.1.5.post20221221
13 | grpcio==1.65.1
14 | huggingface-hub==0.24.5
15 | idna==3.7
16 | imageio==2.34.2
17 | importlib_metadata==8.2.0
18 | importlib_resources==6.4.0
19 | iopath==0.1.10
20 | joblib==1.4.2
21 | kiwisolver==1.4.5
22 | lazy_loader==0.4
23 | Markdown==3.6
24 | MarkupSafe==2.1.5
25 | matplotlib==3.9.1
26 | MultiScaleDeformableAttention==1.0
27 | networkx==3.2.1
28 | numpy==1.26.4
29 | opencv-python==4.10.0.84
30 | packaging==24.1
31 | pillow==10.4.0
32 | platformdirs==4.2.2
33 | portalocker==3.1.1
34 | protobuf==4.25.4
35 | pycocotools==2.0.8
36 | pyparsing==3.1.2
37 | python-dateutil==2.9.0.post0
38 | PyYAML==6.0.2
39 | requests==2.32.3
40 | safetensors==0.4.4
41 | scikit-image==0.24.0
42 | scikit-learn==1.5.1
43 | scipy==1.13.1
44 | six==1.16.0
45 | submitit==1.5.1
46 | tabulate==0.9.0
47 | tensorboard==2.17.0
48 | tensorboard-data-server==0.7.2
49 | termcolor==2.4.0
50 | threadpoolctl==3.5.0
51 | tifffile==2024.7.24
52 | timm==1.0.8
53 | tomli==2.0.1
54 | torch==1.12.1+cu116
55 | torchvision==0.13.1+cu116
56 | tqdm==4.66.4
57 | typing_extensions==4.12.2
58 | urllib3==2.2.2
59 | Werkzeug==3.0.3
60 | yacs==0.1.8
61 | yapf==0.40.2
62 | zipp==3.19.2
63 | 
--------------------------------------------------------------------------------
/PlaneSAM/utils/postprocess.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import utils.box_ops as box_ops
5 | 
6 | class PostProcess(nn.Module):
7 |     """ This module converts the model's output into the format expected by the coco api"""
8 |     @torch.no_grad()
9 |     def forward(self, outputs, target_sizes):
10 |         """ Perform the computation
11 |         Parameters:
12 |             outputs: raw outputs of the model
13 |             target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
14 |                           For evaluation, this must be the original image size (before any data augmentation)
15 |                           For visualization, this should be the image size after data augment, but before padding
16 |         """
17 |         out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
18 | 
19 |         assert len(out_logits) == len(target_sizes)
20 |         assert target_sizes.shape[1] == 2
21 | 
22 |         prob = F.softmax(out_logits, -1)
23 |         scores, labels = prob[..., :-1].max(-1)
24 | 
25 |         # convert to [x, y, w, h] format
26 |         boxes = box_ops.box_cxcywh_to_xywh(out_bbox)
27 |         # and from relative [0, 1] to absolute [0, height] coordinates
28 |         img_h, img_w = target_sizes.unbind(1)
29 |         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
30 |         boxes = boxes * scale_fct[:, None, :]
31 | 
32 |         results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
33 | 
34 |         return results
--------------------------------------------------------------------------------
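To make the expected tensor layout concrete, here is an illustrative snippet (with made-up shapes) showing what `PostProcess` consumes and returns, based on the code above:

```python
# Illustrative input/output shapes for utils.postprocess.PostProcess.
# Run from the PlaneSAM directory so that the utils package is importable.
import torch
from utils.postprocess import PostProcess

postprocess = PostProcess()
B, Q, C = 2, 20, 2                                     # batch size, queries, foreground classes
outputs = {
    "pred_logits": torch.randn(B, Q, C + 1),           # last column is treated as "no object"
    "pred_boxes": torch.rand(B, Q, 4),                 # normalized (cx, cy, w, h) in [0, 1]
}
target_sizes = torch.tensor([[480, 640], [480, 640]])  # (height, width) of each original image

results = postprocess(outputs, target_sizes)
# results[i] == {"scores": [Q], "labels": [Q], "boxes": [Q, 4]},
# with boxes converted to absolute (x, y, w, h) coordinates.
```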
/PlaneSAM/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal plane instance segmentation with the Segment Anything Model
2 | 
3 | 
4 | 
5 | ## Getting Started
6 | Build the PyTorch environment:
7 | ```bash
8 | conda create -n PlaneSAM python=3.9.16
9 | conda activate PlaneSAM
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | ## Data Preparation
14 | 
15 | We train and test our network using the same plane dataset as [PlaneTR](https://github.com/IceTTTb/PlaneTR3D).
16 | You can access the dataset from [here](https://pan.baidu.com/s/1pyx-Ou3SLq7XG5NIMqC2cQ?pwd=in3b).
17 | 
18 | ## Training
19 | 
20 | Our training process consists of two steps:
21 | - First, we pretrain on a large-scale RGB-D dataset. The pretrained weights can be obtained from [here](https://pan.baidu.com/s/1NarX09MpkDDsBr7WWI0mzw?pwd=pvkj) and placed in the weights directory.
22 | - Second, load the pretrained weights into the network and run the train.py script. The trained weights can be obtained from [here](https://pan.baidu.com/s/1O4zygzKL13obNMAB2kuxkg?pwd=nrwj).
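A minimal sketch of how the downloaded checkpoint might be loaded before training, assuming it follows the `state_dict` layout handled by `utils/train_tools.load_ckpt`; the file name and the choice of builder are placeholders, not necessarily the exact configuration used by train.py:

```python
import torch
from model import build_efficient_sam_vitt
from utils.train_tools import load_ckpt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = build_efficient_sam_vitt().to(device)  # placeholder: train.py may use a different builder

# "weights/pretrained.pth" is a placeholder for the checkpoint downloaded above.
# Passing optimizer=None restores only the model weights.
load_ckpt(net, None, "weights/pretrained.pth", device, NUM_GPUS=1)
# ...then continue with the training loop in train.py
```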
23 | 
24 | ## Evaluation
25 | 
26 | During the evaluation, we use Faster R-CNN as the plane detector. The trained weights can be obtained from [here](https://pan.baidu.com/s/1uO1pqs2B4R5IPKQgU0fPTg?pwd=26jr) and placed in the weights directory. The unseen test dataset can be obtained from [here](https://pan.baidu.com/s/1ywNjTRCzXfuxb2VHPGTGzg?pwd=9qcm).
27 | To evaluate the plane segmentation capabilities of PlaneSAM, please run the eval.py script.
28 | 
29 | ## Acknowledgements
30 | This code is based on the [EfficientSAM](https://github.com/yformer/EfficientSAM) repository. We would like to acknowledge the authors for their work.
31 | 
32 | ## Additional Note
33 | 
34 | Due to the author's current busy schedule, we apologize for the possibly poor code quality. Optimizations will be made in the future. If you have any questions or encounter any bugs in the code, feel free to ask.
35 | 
--------------------------------------------------------------------------------
/PlaneSAM/utils/visual_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | 
4 | colors = np.array([
5 |     [255, 0, 0],
6 |     [0, 255, 0],
7 |     [0, 0, 255],
8 |     [255, 255, 0],
9 |     [255, 0, 255],
10 |     [0, 255, 255],
11 |     [255, 128, 0],
12 |     [128, 0, 255],
13 |     [0, 128, 255],
14 |     [255, 0, 128],
15 |     [128, 255, 0],
16 |     [0, 255, 128],
17 |     [255, 128, 128],
18 |     [128, 255, 128],
19 |     [128, 128, 255],
20 |     [255, 255, 128],
21 |     [255, 128, 255],
22 |     [128, 255, 255],
23 |     [192, 192, 192],
24 |     [128, 0, 0],
25 |     [0, 128, 0],
26 |     [0, 0, 128],
27 |     [128, 128, 0],
28 |     [128, 0, 128],
29 |     [0, 128, 128],
30 |     [64, 0, 0],
31 |     [0, 64, 0],
32 |     [0, 0, 64],
33 |     [64, 64, 0],
34 |     [64, 0, 64],
35 |     [0, 64, 64],
36 |     [0, 0, 0]
37 | ], dtype=np.uint8)
38 | 
39 | def map_masks_to_colors(masks):
40 |     """
41 |     :param masks: [num_planes, H, W]
42 |     :return: rgb_image[H, W, 3]
43 |     """
44 |     num_masks = len(masks)
45 |     # colors[i] is the color assigned to mask i (fixed palette defined above)
46 | 
47 |     # create a blank RGB image
48 |     height, width = masks[0].shape
49 |     rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
50 | 
51 |     for i, mask in enumerate(masks):
52 |         color = colors[i]
53 |         # mask_rgb = np.zeros((height, width, 3), dtype=np.uint8)
54 |         # mask_rgb[mask > 0] = color
55 |         # rgb_image += mask_rgb
56 |         rgb_image[mask > 0] = color
57 | 
58 |     return rgb_image
59 | 
60 | 
61 | def draw_bounding_boxes(image, boxes):
62 |     """
63 |     Draw bounding boxes on an image (using OpenCV).
64 | 
65 |     :param image: RGB image (ndarray, shape [H, W, 3]).
66 |     :param boxes: array of bounding boxes (ndarray, shape [B, 4], each row [xmin, ymin, xmax, ymax]).
67 |     :return: the image with the bounding boxes drawn on it.
68 |     """
69 |     image_copy = image.copy()
70 |     for box in boxes:
71 |         xmin, ymin, xmax, ymax = map(int, box)
72 |         cv2.rectangle(image_copy, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=2)  # green box
73 |     return image_copy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal plane instance segmentation with the Segment Anything Model
2 | 
3 | This is the official PyTorch implementation for our paper "Multimodal plane instance segmentation with the Segment Anything Model". This paper has been accepted for publication in Automation in Construction.
4 | 
5 | You may also learn about our algorithm from the preprint version of our paper on arXiv, titled “PlaneSAM: Multimodal Plane Instance Segmentation Using the Segment Anything Model” (https://arxiv.org/abs/2410.16545).
6 | 
7 | However, the version we published in the journal Automation in Construction is more formal and provides a more complete description of the algorithm.
8 | 
9 | 
10 | ## 🔭 Introduction
11 | Abstract: Plane instance segmentation from RGB-D data is critical for BIM-related tasks. However, existing deep-learning methods rely on only RGB bands, overlooking depth information. To address this, PlaneSAM, a Segment-Anything-Model-based network, is proposed. It fully integrates RGB-D bands using a dual-complexity backbone: a simple branch primarily for the D band and a high-capacity branch mainly for RGB bands. This structure facilitates effective D-band learning with limited data, preserves EfficientSAM's RGB feature representations, and enables task-specific fine-tuning. To improve adaptability to RGB-D domains, a self-supervised pretraining strategy is introduced. EfficientSAM’s loss is also optimized for large-plane segmentation. Additionally, plane detection is performed using Faster R-CNN, enabling fully automatic segmentation. State-of-the-art performance is achieved on multiple datasets, with <10% additional overhead compared to EfficientSAM. The proposed dual-complexity backbone shows strong potential for transferring RGB-based foundation models to RGB+X domains in other scenarios, while the pretraining strategy is promising for other data-scarce tasks.
12 | 
13 | ## 🔭 Citation
14 | If you find our work useful for your research, please consider citing our paper.
15 | Deng, Z., Yang, Z., Chen, C., Zeng, C., Meng, Y., Yang, B., 2025. Multimodal plane instance segmentation with the Segment Anything Model. Automation in Construction 180, 106541.
16 | 
17 | 
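For readers who want a concrete picture of the dual-complexity idea described in the abstract, the snippet below is a purely conceptual PyTorch sketch: a small trainable branch for the D band next to a frozen high-capacity RGB encoder, with the two feature maps fused. It is not the actual PlaneSAM implementation; all module names, channel sizes, and the fusion rule are invented for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualComplexityBackbone(nn.Module):
    """Conceptual illustration only (not PlaneSAM): a lightweight depth branch
    alongside a frozen high-capacity RGB encoder, fused by addition."""

    def __init__(self, rgb_encoder: nn.Module, feat_dim: int = 256):
        super().__init__()
        self.rgb_encoder = rgb_encoder               # e.g. a pretrained image encoder (kept frozen)
        for p in self.rgb_encoder.parameters():
            p.requires_grad_(False)
        self.depth_branch = nn.Sequential(           # simple, trainable branch for the D band
            nn.Conv2d(1, 64, 3, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, feat_dim, 3, stride=2, padding=1),
        )

    def forward(self, rgb: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
        # Assumes the RGB encoder returns a [B, feat_dim, H', W'] feature map.
        rgb_feat = self.rgb_encoder(rgb)
        d_feat = self.depth_branch(depth)
        d_feat = F.interpolate(d_feat, size=rgb_feat.shape[-2:], mode="bilinear", align_corners=False)
        return rgb_feat + d_feat                     # fused RGB-D features
```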

18 | -------------------------------------------------------------------------------- /PlaneSAM/backbone/vgg_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class VGG(nn.Module): 6 | def __init__(self, features, class_num=1000, init_weights=False, weights_path=None): 7 | super(VGG, self).__init__() 8 | self.features = features 9 | self.classifier = nn.Sequential( 10 | nn.Linear(512*7*7, 4096), 11 | nn.ReLU(True), 12 | nn.Dropout(p=0.5), 13 | nn.Linear(4096, 4096), 14 | nn.ReLU(True), 15 | nn.Dropout(p=0.5), 16 | nn.Linear(4096, class_num) 17 | ) 18 | if init_weights and weights_path is None: 19 | self._initialize_weights() 20 | 21 | if weights_path is not None: 22 | self.load_state_dict(torch.load(weights_path)) 23 | 24 | def forward(self, x): 25 | # N x 3 x 224 x 224 26 | x = self.features(x) 27 | # N x 512 x 7 x 7 28 | x = torch.flatten(x, start_dim=1) 29 | # N x 512*7*7 30 | x = self.classifier(x) 31 | return x 32 | 33 | def _initialize_weights(self): 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | nn.init.xavier_uniform_(m.weight) 38 | if m.bias is not None: 39 | nn.init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.Linear): 41 | nn.init.xavier_uniform_(m.weight) 42 | # nn.init.normal_(m.weight, 0, 0.01) 43 | nn.init.constant_(m.bias, 0) 44 | 45 | 46 | def make_features(cfg: list): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == "M": 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | layers += [conv2d, nn.ReLU(True)] 55 | in_channels = v 56 | return nn.Sequential(*layers) 57 | 58 | 59 | cfgs = { 60 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 61 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 62 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 63 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 64 | } 65 | 66 | 67 | def vgg(model_name="vgg16", weights_path=None): 68 | assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name) 69 | cfg = cfgs[model_name] 70 | 71 | model = VGG(make_features(cfg), weights_path=weights_path) 72 | return model 73 | -------------------------------------------------------------------------------- /PlaneSAM/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import utils.box_ops as box_ops 5 | 6 | 7 | # 输入[B, H, W]每一个像素是一个整数表示一个类 8 | # 输出[B, C, H, W]共C个类,每一个类对应一个二值掩码 9 | def build_target(target, num_classes=2, ignore_index=-100): 10 | dice_target = target.clone() 11 | if ignore_index >= 0: 12 | ignore_mask = torch.eq(target, ignore_index) 13 | dice_target[ignore_mask] = 0 14 | # [B, H, W] -> [B, H, W, C] 15 | dice_target = nn.functional.one_hot(dice_target, num_classes).float() 16 | dice_target[ignore_mask] = ignore_index 17 | else: 18 | dice_target = nn.functional.one_hot(dice_target, num_classes).float() 19 | 20 | return dice_target.permute(0, 3, 1, 2) 21 | 22 | 23 | # 输入[B, H, W] 24 | def dice_coeff(x, target, ignore_index=-100, epsilon=1e-6): 25 | d = 0. 
26 | batch_size = x.shape[0] 27 | for i in range(batch_size): 28 | x_i = x[i].reshape(-1) 29 | t_i = target[i].reshape(-1) 30 | if ignore_index >= 0: 31 | # 找出像素值不为ignore_index的位置 32 | roi_mask = torch.ne(t_i, ignore_index) 33 | x_i = x_i[roi_mask] 34 | t_i = t_i[roi_mask] 35 | inter = torch.dot(x_i, t_i) 36 | sets_sum = torch.sum(x_i) + torch.sum(t_i) 37 | # 如果sets_sum为0,说明预测和实际值都为0,预测百分百正确,dice系数为1 38 | if sets_sum == 0: 39 | sets_sum = 2 * inter 40 | 41 | d += (2 * inter + epsilon) / (sets_sum + epsilon) 42 | 43 | return d / batch_size 44 | 45 | 46 | def multiclass_dice_coeff(x, target, ignore_index=-100, epsilon=1e-6): 47 | dice = 0. 48 | for channel in range(x.shape[1]): 49 | dice += dice_coeff(x[:, channel, ...], target[:, channel, ...], ignore_index, epsilon) 50 | 51 | 52 | def dice_loss(x, target, multiclass=False, ignore_index=-100): 53 | x = torch.sigmoid(x) 54 | fn = multiclass_dice_coeff if multiclass else dice_coeff 55 | return 1 - fn(x, target, ignore_index=ignore_index) 56 | 57 | 58 | def criterion(x_mask, x_iou, target, alpha=1, gamma=2): 59 | """ 60 | :param x_mask: [B, H, W]SAM输出的掩码 61 | :param x_iou: [B,]掩码对应的预测iou 62 | :param target: [B, H, W]标签 63 | :param alpha: focalloss参数 64 | :param gamma: focalloss参数 65 | :return: 返回综合损失 66 | """ 67 | batch_size, h, w = x_mask.shape 68 | binary_mask = (x_mask > 0).float() 69 | 70 | ce_loss = F.binary_cross_entropy_with_logits(x_mask, target, reduction='mean') 71 | pt = torch.exp(-ce_loss) 72 | focal_loss = alpha * (1 - pt) ** gamma * ce_loss 73 | 74 | Dice_loss = dice_loss(x_mask, target) 75 | 76 | binary_mask = binary_mask.reshape(batch_size, -1) 77 | target = target.reshape(batch_size, -1) 78 | intersection = torch.sum(torch.logical_and(binary_mask, target).float(), dim=1) 79 | union = torch.sum(torch.logical_or(binary_mask, target).float(), dim=1) 80 | iou = intersection / (union + 1e-6) 81 | binary_mask = binary_mask.reshape(batch_size, h, w) 82 | target = target.reshape(batch_size, h, w) 83 | 84 | mse_Loss = F.mse_loss(x_iou, iou, reduction='mean') 85 | 86 | return 20 * focal_loss + Dice_loss -------------------------------------------------------------------------------- /PlaneSAM/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | def box_cxcywh_to_xywh(x): 24 | x_c, y_c, w, h = x.unbind(-1) 25 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 26 | w, h] 27 | return torch.stack(b, dim=-1) 28 | 29 | 30 | def box_xyxy_to_xywh(x): 31 | x0, y0, x1, y1 = x.unbind(-1) 32 | b = [x0, y0, 33 | (x1 - x0), (y1 - y0)] 34 | return torch.stack(b, dim=-1) 35 | 36 | 37 | # modified from torchvision to also return the union 38 | def box_iou(boxes1, boxes2): 39 | area1 = box_area(boxes1) 40 | area2 = box_area(boxes2) 41 | 42 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 43 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 44 | 45 | wh = (rb - lt).clamp(min=0) # [N,M,2] 46 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 47 | 48 | union = area1[:, None] + area2 - inter 49 | 50 | iou = inter / union 51 | return iou, union 52 | 53 | 54 | def generalized_box_iou(boxes1, boxes2): 55 | """ 56 | Generalized IoU from https://giou.stanford.edu/ 57 | 58 | The boxes should be in [x0, y0, x1, y1] format 59 | 60 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 61 | and M = len(boxes2) 62 | """ 63 | # degenerate boxes gives inf / nan results 64 | # so do an early check 65 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 66 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 67 | iou, union = box_iou(boxes1, boxes2) 68 | 69 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 70 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 71 | 72 | wh = (rb - lt).clamp(min=0) # [N,M,2] 73 | area = wh[:, :, 0] * wh[:, :, 1] 74 | 75 | return iou - (area - union) / area 76 | 77 | 78 | def masks_to_boxes(masks): 79 | """Compute the bounding boxes around the provided masks 80 | 81 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
82 | 83 | Returns a [N, 4] tensors, with the boxes in xyxy format 84 | """ 85 | if masks.numel() == 0: 86 | return torch.zeros((0, 4), device=masks.device) 87 | 88 | h, w = masks.shape[-2:] 89 | 90 | y = torch.arange(0, h, dtype=torch.float, device=masks.device) 91 | x = torch.arange(0, w, dtype=torch.float, device=masks.device) 92 | y, x = torch.meshgrid(y, x) 93 | 94 | x_mask = (masks * x.unsqueeze(0)) 95 | x_max = x_mask.flatten(1).max(-1)[0] 96 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 97 | 98 | y_mask = (masks * y.unsqueeze(0)) 99 | y_max = y_mask.flatten(1).max(-1)[0] 100 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 101 | 102 | return torch.stack([x_min, y_min, x_max, y_max], 1) 103 | -------------------------------------------------------------------------------- /PlaneSAM/utils/train_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from sklearn.metrics import adjusted_rand_score 5 | 6 | 7 | def print_log(global_step, epoch, local_count, count_inter, dataset_size, loss, time_inter): 8 | print('Step: {:>5} Train Epoch: {:>3} [{:>4}/{:>4} ({:3.1f}%)] ' 9 | 'Loss: {:.6f} [{:.2f}s every {:>4} data]'.format( 10 | global_step, epoch, local_count, dataset_size, 11 | 100. * local_count / dataset_size, loss.data, time_inter, count_inter)) 12 | 13 | 14 | def save_ckpt(ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train): 15 | # usually this happens only on the start of a epoch 16 | epoch_float = epoch + (local_count / num_train) 17 | state = { 18 | 'global_step': global_step, 19 | 'epoch': epoch_float, 20 | 'state_dict': model.state_dict(), 21 | 'optimizer': optimizer.state_dict(), 22 | } 23 | ckpt_model_filename = "ckpt_epoch_{:0.2f}.pth".format(epoch_float) 24 | path = os.path.join(ckpt_dir, ckpt_model_filename) 25 | torch.save(state, path) 26 | print('{:>2} has been successfully saved'.format(path)) 27 | 28 | 29 | def load_ckpt(model, optimizer, model_file, device, NUM_GPUS): 30 | if os.path.isfile(model_file): 31 | print("=> loading checkpoint '{}'".format(model_file)) 32 | if device.type == 'cuda': 33 | checkpoint = torch.load(model_file) 34 | else: 35 | checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage) 36 | model_dict = checkpoint['state_dict'] 37 | model_dict_ = {} 38 | if NUM_GPUS > 1: 39 | model.load_state_dict(model_dict) 40 | else: 41 | for k, v in model_dict.items(): 42 | k_ = k.replace('module.', '') 43 | model_dict_[k_] = v 44 | model.load_state_dict(model_dict_) 45 | if optimizer: 46 | optimizer.load_state_dict(checkpoint['optimizer']) 47 | else: 48 | print("=> no checkpoint found at '{}'".format(model_file)) 49 | exit(0) 50 | 51 | 52 | # 计算指标 53 | def compute_iou(x, target): 54 | """ 55 | :param x: [B, H, W] 56 | :param target: [B, H, W] 57 | :return: iou 58 | """ 59 | b, h, w = x.shape 60 | x = x.reshape(b, -1) 61 | target = target.reshape(b, -1) 62 | intersection = torch.sum(torch.logical_and(x, target).float(), dim=1) 63 | union = torch.sum(torch.logical_or(x, target).float(), dim=1) 64 | # [B] 65 | iou = intersection / (union + 1e-9) 66 | return iou 67 | 68 | 69 | def compute_iou_sc(x, target): 70 | """ 71 | :param x: [B, H, W] 72 | :param target: [B, H, W] 73 | :return: iou 74 | """ 75 | b, h, w = x.shape 76 | x = x.reshape(b, -1) 77 | target = target.reshape(b, -1) 78 | intersection = np.sum(np.logical_and(x, target).astype(np.float32), axis=1) 79 | union = 
np.sum(np.logical_or(x, target).astype(np.float32), axis=1) 80 | iou = np.max(intersection / (union + 1e-9)) 81 | return iou 82 | 83 | 84 | def compute_acc(x, target): 85 | """ 86 | :param x: [B, H, W] 87 | :param target: [B, H, W] 88 | :return: acc 89 | """ 90 | b, h, w = x.shape 91 | x = x.reshape(b, -1) 92 | target = target.reshape(b, -1) 93 | 94 | acc = torch.mean((x == target).float(), dim=1) 95 | return acc 96 | 97 | 98 | def compute_RI(x, target): 99 | """ 100 | :param x: [B, H, W] 101 | :param target: [B, H, W] 102 | :return: RI 103 | """ 104 | num_planes, h, w = x.shape 105 | pred_map = np.zeros((h, w), dtype=np.uint8) 106 | gt_map = np.zeros((h, w), dtype=np.uint8) 107 | 108 | # [H, W] 109 | for i in range(num_planes): 110 | pred_map[x[i] == 1] = i + 1 111 | gt_map[target[i] == 1] = i + 1 112 | 113 | pred_map = pred_map.flatten() 114 | gt_map = gt_map.flatten() 115 | RI = adjusted_rand_score(pred_map, gt_map) 116 | 117 | return RI 118 | 119 | 120 | def compute_SC(x, target): 121 | """ 122 | :param x: [B, H, W] 123 | :param target: [B, H, W] 124 | :return: RI 125 | """ 126 | iou_1 = [] 127 | iou_2 = [] 128 | num_planes, h, w = x.shape 129 | for per_plane in x: 130 | plane = np.repeat(per_plane[np.newaxis, ...], num_planes, axis=0) 131 | iou_1.append(compute_iou_sc(plane, target)) 132 | iou_1 = sum(iou_1) / len(iou_1) 133 | 134 | for per_plane in target: 135 | plane = np.repeat(per_plane[np.newaxis, ...], num_planes, axis=0) 136 | iou_2.append(compute_iou_sc(plane, x)) 137 | iou_2 = sum(iou_2) / len(iou_2) 138 | 139 | return (iou_1 + iou_2) / 2 -------------------------------------------------------------------------------- /PlaneSAM/utils/eval_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import adjusted_rand_score, mutual_info_score 4 | from torchvision.ops import box_iou 5 | 6 | 7 | def compute_iou(x, target): 8 | """ 9 | :param x: [B, H, W] 10 | :param target: [B, H, W] 11 | :return: iou 12 | """ 13 | b, h, w = x.shape 14 | x = x.reshape(b, -1) 15 | target = target.reshape(b, -1) 16 | intersection = torch.sum(torch.logical_and(x, target).float(), dim=1) 17 | union = torch.sum(torch.logical_or(x, target).float(), dim=1) 18 | # [B] 19 | iou = intersection / (union + 1e-9) 20 | return iou 21 | 22 | 23 | def compute_iou_sc(x, target): 24 | """ 25 | :param x: [B, H, W] 26 | :param target: [B, H, W] 27 | :return: iou 28 | """ 29 | b, h, w = x.shape 30 | x = x.reshape(b, -1) 31 | target = target.reshape(b, -1) 32 | intersection = np.sum(np.logical_and(x, target).astype(np.float32), axis=1) 33 | union = np.sum(np.logical_or(x, target).astype(np.float32), axis=1) 34 | iou = np.max(intersection / (union + 1e-9)) 35 | return iou 36 | 37 | 38 | def compute_acc(x, target): 39 | """ 40 | :param x: [B, H, W] 41 | :param target: [B, H, W] 42 | :return: acc 43 | """ 44 | b, h, w = x.shape 45 | x = x.reshape(b, -1) 46 | target = target.reshape(b, -1) 47 | 48 | acc = torch.mean((x == target).float(), dim=1) 49 | return acc 50 | 51 | 52 | def masks_to_map(masks): 53 | """ 54 | :param masks: [B, H, W] 55 | :return: [H, W] 56 | """ 57 | num_planes, h, w = masks.shape 58 | semantic_map = np.zeros((h, w), dtype=np.uint8) 59 | # [H, W] 60 | for i in range(num_planes): 61 | semantic_map[masks[i] == 1] = i + 1 62 | 63 | return semantic_map 64 | 65 | 66 | def match_boxes_gt(pred_boxes, gt_boxes): 67 | # pred_boxes has been sorted by score 68 | num_gts = gt_boxes.shape[0] 69 | num_preds = 
pred_boxes.shape[0] 70 | device = gt_boxes.device 71 | batched_points = torch.full((num_gts, 2, 2), fill_value=-1, dtype=torch.float32, device=device) 72 | batched_labels = torch.full((num_gts, 2), fill_value=-1, dtype=torch.float32, device=device) 73 | iou_matrix = box_iou(pred_boxes, gt_boxes) 74 | 75 | _, best_match_indices = iou_matrix.max(dim=1) 76 | unique_masks = torch.zeros(num_preds, dtype=torch.bool, device=device) 77 | tmp = [] 78 | for i, t in enumerate(best_match_indices): 79 | if t not in tmp: 80 | tmp.append(t) 81 | unique_masks[i] = True 82 | 83 | pred_boxes = pred_boxes[unique_masks] 84 | best_match_indices = best_match_indices[unique_masks] 85 | 86 | batched_points[best_match_indices, 0] = pred_boxes[:, :2] 87 | batched_points[best_match_indices, 1] = pred_boxes[:, 2:] 88 | batched_labels[best_match_indices, 0] = 2 89 | batched_labels[best_match_indices, 1] = 3 90 | 91 | return batched_points.unsqueeze(1), batched_labels.unsqueeze(1) 92 | 93 | 94 | def evaluateMasks(predMasks, gtMasks): 95 | """ 96 | :param predMasks: [N, H, W] 97 | :param gtMasks: [N, H, W] 98 | :return: 99 | """ 100 | valid_mask = (gtMasks.max(0)[0]).unsqueeze(0) 101 | 102 | gtMasks = torch.cat([gtMasks, torch.clamp(1 - gtMasks.sum(0, keepdim=True), min=0)], dim=0) # M+1, H, W 103 | predMasks = torch.cat([predMasks, torch.clamp(1 - predMasks.sum(0, keepdim=True), min=0)], dim=0) # N+1, H, W 104 | 105 | intersection = (gtMasks.unsqueeze(1) * predMasks * valid_mask).sum(-1).sum(-1).float() 106 | union = (torch.max(gtMasks.unsqueeze(1), predMasks) * valid_mask).sum(-1).sum(-1).float() 107 | 108 | N = intersection.sum() 109 | 110 | RI = 1 - ((intersection.sum(0).pow(2).sum() + intersection.sum(1).pow(2).sum()) / 2 - intersection.pow(2).sum()) / ( 111 | N * (N - 1) / 2) 112 | joint = intersection / N 113 | marginal_2 = joint.sum(0) 114 | marginal_1 = joint.sum(1) 115 | H_1 = (-marginal_1 * torch.log2(marginal_1 + (marginal_1 == 0).float())).sum() 116 | H_2 = (-marginal_2 * torch.log2(marginal_2 + (marginal_2 == 0).float())).sum() 117 | 118 | B = (marginal_1.unsqueeze(-1) * marginal_2) 119 | log2_quotient = torch.log2(torch.clamp(joint, 1e-8) / torch.clamp(B, 1e-8)) * (torch.min(joint, B) > 1e-8).float() 120 | MI = (joint * log2_quotient).sum() 121 | voi = H_1 + H_2 - 2 * MI 122 | 123 | IOU = intersection / torch.clamp(union, min=1) 124 | SC = ((IOU.max(-1)[0] * torch.clamp((gtMasks * valid_mask).sum(-1).sum(-1), min=1e-4)).sum() / N + ( 125 | IOU.max(0)[0] * torch.clamp((predMasks * valid_mask).sum(-1).sum(-1), min=1e-4)).sum() / N) / 2 126 | info = [RI.item(), voi.item(), SC.item()] 127 | 128 | return info -------------------------------------------------------------------------------- /PlaneSAM/backbone/mobilenetv2_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision.ops import misc 4 | 5 | 6 | def _make_divisible(ch, divisor=8, min_ch=None): 7 | """ 8 | This function is taken from the original tf repo. 9 | It ensures that all layers have a channel number that is divisible by 8 10 | It can be seen here: 11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 12 | """ 13 | if min_ch is None: 14 | min_ch = divisor 15 | new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor) 16 | # Make sure that round down does not go down by more than 10%. 
17 | if new_ch < 0.9 * ch: 18 | new_ch += divisor 19 | return new_ch 20 | 21 | 22 | class ConvBNReLU(nn.Sequential): 23 | def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1, norm_layer=None): 24 | padding = (kernel_size - 1) // 2 25 | if norm_layer is None: 26 | norm_layer = nn.BatchNorm2d 27 | super(ConvBNReLU, self).__init__( 28 | nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False), 29 | norm_layer(out_channel), 30 | nn.ReLU6(inplace=True) 31 | ) 32 | 33 | 34 | class InvertedResidual(nn.Module): 35 | def __init__(self, in_channel, out_channel, stride, expand_ratio, norm_layer=None): 36 | super(InvertedResidual, self).__init__() 37 | hidden_channel = in_channel * expand_ratio 38 | self.use_shortcut = stride == 1 and in_channel == out_channel 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | 42 | layers = [] 43 | if expand_ratio != 1: 44 | # 1x1 pointwise conv 45 | layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1, norm_layer=norm_layer)) 46 | layers.extend([ 47 | # 3x3 depthwise conv 48 | ConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel, norm_layer=norm_layer), 49 | # 1x1 pointwise conv(linear) 50 | nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False), 51 | norm_layer(out_channel), 52 | ]) 53 | 54 | self.conv = nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | if self.use_shortcut: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class MobileNetV2(nn.Module): 64 | def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8, weights_path=None, norm_layer=None): 65 | super(MobileNetV2, self).__init__() 66 | block = InvertedResidual 67 | input_channel = _make_divisible(32 * alpha, round_nearest) 68 | last_channel = _make_divisible(1280 * alpha, round_nearest) 69 | 70 | if norm_layer is None: 71 | norm_layer = nn.BatchNorm2d 72 | 73 | inverted_residual_setting = [ 74 | # t, c, n, s 75 | [1, 16, 1, 1], 76 | [6, 24, 2, 2], 77 | [6, 32, 3, 2], 78 | [6, 64, 4, 2], 79 | [6, 96, 3, 1], 80 | [6, 160, 3, 2], 81 | [6, 320, 1, 1], 82 | ] 83 | 84 | features = [] 85 | # conv1 layer 86 | features.append(ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)) 87 | # building inverted residual residual blockes 88 | for t, c, n, s in inverted_residual_setting: 89 | output_channel = _make_divisible(c * alpha, round_nearest) 90 | for i in range(n): 91 | stride = s if i == 0 else 1 92 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 93 | input_channel = output_channel 94 | # building last several layers 95 | features.append(ConvBNReLU(input_channel, last_channel, 1, norm_layer=norm_layer)) 96 | # combine feature layers 97 | self.features = nn.Sequential(*features) 98 | 99 | # building classifier 100 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 101 | self.classifier = nn.Sequential( 102 | nn.Dropout(0.2), 103 | nn.Linear(last_channel, num_classes) 104 | ) 105 | 106 | if weights_path is None: 107 | # weight initialization 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 111 | if m.bias is not None: 112 | nn.init.zeros_(m.bias) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | nn.init.ones_(m.weight) 115 | nn.init.zeros_(m.bias) 116 | elif isinstance(m, nn.Linear): 117 | nn.init.normal_(m.weight, 0, 0.01) 118 | nn.init.zeros_(m.bias) 119 | else: 120 | self.load_state_dict(torch.load(weights_path)) 
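# NOTE (added comment): torch.load(weights_path) restores tensors to the device the checkpoint
# was saved on; torch.load(weights_path, map_location="cpu") is the usual safeguard when loading
# on a machine without that device. Left here as a hint only, not applied above.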
121 | 122 | def forward(self, x): 123 | x = self.features(x) 124 | x = self.avgpool(x) 125 | x = torch.flatten(x, 1) 126 | x = self.classifier(x) 127 | return x 128 | -------------------------------------------------------------------------------- /PlaneSAM/S2D3DSDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import random 5 | import numpy as np 6 | from torchvision import transforms 7 | import skimage 8 | from pycocotools.coco import COCO 9 | from PIL import Image 10 | 11 | class S2d3dsDataset(Dataset): 12 | def __init__(self, datafolder, subset='train', transform=None): 13 | self.datafolder = datafolder 14 | self.subset = subset 15 | self.transform = transform 16 | assert subset in ['train', 'val'] 17 | if subset == 'train': 18 | self.coco = COCO(annotation_file=os.path.join(datafolder, 's2d3ds_train.json')) 19 | else: 20 | self.coco = COCO(annotation_file=os.path.join(datafolder, 's2d3ds_val.json')) 21 | self.image_ids = self.coco.getImgIds() 22 | 23 | def __len__(self): 24 | return len(self.image_ids) 25 | 26 | ''' 27 | return: 28 | image: [3, H, W] 29 | depth: [H, W] 30 | segmentation: [H, W] 31 | instance: [num_planes, H, W] 32 | ''' 33 | 34 | def __getitem__(self, index): 35 | image_id = self.image_ids[index] 36 | if self.subset == 'train': 37 | image_path = os.path.join(self.datafolder, 'images', self.coco.loadImgs(image_id)[0]['file_name']) 38 | else: 39 | image_path = os.path.join(self.datafolder, 'images_val', self.coco.loadImgs(image_id)[0]['file_name']) 40 | depth_path = image_path.replace('images', 'depths').replace('rgb', 'depth').replace('.jpg', '.png') 41 | image = np.array(Image.open(image_path), dtype=np.uint8) 42 | depth = np.array(Image.open(depth_path), dtype=np.float32) / 1000.0 43 | 44 | segmentation = np.zeros(image.shape[:2], dtype=np.uint8) 45 | annotation_id = self.coco.getAnnIds(imgIds=image_id) 46 | annotation = self.coco.loadAnns(annotation_id) 47 | for idx, i in enumerate(annotation): 48 | segmentation[self.coco.annToMask(i) > 0] = idx + 1 49 | idx += 1 50 | 51 | sample = {} 52 | if self.transform: 53 | sample = self.transform({ 54 | 'image': image, 55 | 'depth': depth, 56 | 'segmentation': segmentation 57 | }) 58 | image = sample['image'] 59 | depth = sample['depth'] 60 | segmentation = sample['segmentation'] 61 | 62 | mask = [] 63 | unique_idx = torch.unique(segmentation) 64 | unique_idx = [x for x in unique_idx if x] 65 | for i in unique_idx: 66 | mask.append((segmentation == i).float()) 67 | mask = torch.stack(mask) 68 | bbox = self.masks_to_bboxes(mask) 69 | num_planes = len(unique_idx) 70 | 71 | masks = torch.zeros(30, image.shape[1], image.shape[2], dtype=torch.float32) 72 | masks[:num_planes] = mask 73 | 74 | sample.update({ 75 | 'instance': masks, 76 | 'num_planes': num_planes, 77 | "data_path": image_path 78 | }) 79 | 80 | return sample 81 | 82 | def masks_to_bboxes(self, masks): 83 | """ 84 | 从掩码张量中计算边界框的左上和右下坐标 85 | 参数: 86 | masks: 形状为 [B, H, W] 的二进制掩码张量 87 | 返回值: 88 | bounding_boxes: 形状为 [B, 4] 的边界框坐标张量,包含左上和右下坐标 89 | """ 90 | batch_size, height, width = masks.size() 91 | device = masks 92 | bounding_boxes = torch.zeros((batch_size, 4), dtype=torch.float32) 93 | 94 | for b in range(batch_size): 95 | mask = masks[b] 96 | 97 | # 找到掩码的非零元素索引 98 | nonzero_indices = torch.nonzero(mask) 99 | 100 | if nonzero_indices.size(0) == 0: 101 | # 如果掩码中没有非零元素,则边界框坐标为零 102 | assert "no mask!" 
103 | else: 104 | # 计算边界框的左上和右下坐标 105 | ymin = torch.min(nonzero_indices[:, 0]) 106 | xmin = torch.min(nonzero_indices[:, 1]) 107 | ymax = torch.max(nonzero_indices[:, 0]) 108 | xmax = torch.max(nonzero_indices[:, 1]) 109 | bounding_boxes[b] = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32) 110 | 111 | return bounding_boxes 112 | 113 | 114 | class ToTensor(object): 115 | def __call__(self, sample): 116 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 117 | # [H, W, C] -> [C, H, W], 像素值归一化到0-1之间 118 | image = transforms.ToTensor()(image) 119 | # [1, H, W] 120 | depth = transforms.ToTensor()(depth) 121 | return { 122 | 'image': image, 123 | 'depth': depth, 124 | 'segmentation': torch.from_numpy(segmentation.astype(np.int16)).float() 125 | } 126 | 127 | class RandomFlip(object): 128 | def __call__(self, sample): 129 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 130 | if random.random() > 0.5: 131 | image = np.fliplr(image).copy() 132 | depth = np.fliplr(depth).copy() 133 | segmentation = np.fliplr(segmentation).copy() 134 | 135 | return { 136 | 'image': image, 137 | 'depth': depth, 138 | 'segmentation': segmentation 139 | } -------------------------------------------------------------------------------- /PlaneSAM/Nyuv2Dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import random 5 | import numpy as np 6 | from torchvision import transforms 7 | import skimage 8 | 9 | class Nyuv2Dataset(Dataset): 10 | def __init__(self, datafolder, subset='train', transform=None): 11 | self.datafolder = datafolder 12 | self.subset = subset 13 | self.transform = transform 14 | assert subset in ['train', 'val'] 15 | self.data_list = [os.path.join(datafolder, subset, x) for x in os.listdir(os.path.join(datafolder, subset))] 16 | 17 | def __len__(self): 18 | return len(self.data_list) 19 | 20 | ''' 21 | return: 22 | image: [3, H, W] 23 | depth: [H, W] 24 | segmentation: [H, W] 25 | instance: [num_planes, H, W] 26 | ''' 27 | def __getitem__(self, index): 28 | data_path = self.data_list[index] 29 | data = np.load(data_path, allow_pickle=True) 30 | data = np.load(self.data_list[index], allow_pickle=True) 31 | image = data[:, :, :3].astype(np.uint8) 32 | depth = data[:, :, 3] 33 | segmentation = data[:, :, 4].astype(np.uint8) 34 | 35 | sample = {} 36 | if self.transform: 37 | sample = self.transform({ 38 | 'image': image, 39 | 'depth': depth, 40 | 'segmentation': segmentation 41 | }) 42 | image = sample['image'] 43 | depth = sample['depth'] 44 | segmentation = sample['segmentation'] 45 | 46 | mask = [] 47 | unique_idx = torch.unique(segmentation) 48 | unique_idx = [x for x in unique_idx if x] 49 | for i in unique_idx: 50 | mask.append((segmentation == i).float()) 51 | mask = torch.stack(mask) 52 | bbox = self.masks_to_bboxes(mask) 53 | num_planes = len(unique_idx) 54 | 55 | masks = torch.zeros(30, image.shape[1], image.shape[2], dtype=torch.float32) 56 | masks[:num_planes] = mask 57 | 58 | sample.update({ 59 | 'instance': masks, 60 | 'num_planes': num_planes, 61 | 'data_path': data_path 62 | }) 63 | 64 | return sample 65 | 66 | def masks_to_bboxes(self, masks): 67 | """ 68 | 从掩码张量中计算边界框的左上和右下坐标 69 | 参数: 70 | masks: 形状为 [B, H, W] 的二进制掩码张量 71 | 返回值: 72 | bounding_boxes: 形状为 [B, 4] 的边界框坐标张量,包含左上和右下坐标 73 | """ 74 | batch_size, height, width = masks.size() 75 | device = masks 76 | bounding_boxes = 
torch.zeros((batch_size, 4), dtype=torch.float32) 77 | 78 | for b in range(batch_size): 79 | mask = masks[b] 80 | 81 | # 找到掩码的非零元素索引 82 | nonzero_indices = torch.nonzero(mask) 83 | 84 | if nonzero_indices.size(0) == 0: 85 | # 如果掩码中没有非零元素,则边界框坐标为零 86 | assert "no mask!" 87 | else: 88 | # 计算边界框的左上和右下坐标 89 | ymin = torch.min(nonzero_indices[:, 0]) 90 | xmin = torch.min(nonzero_indices[:, 1]) 91 | ymax = torch.max(nonzero_indices[:, 0]) 92 | xmax = torch.max(nonzero_indices[:, 1]) 93 | bounding_boxes[b] = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32) 94 | 95 | return bounding_boxes 96 | 97 | 98 | class ToTensor(object): 99 | def __call__(self, sample): 100 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 101 | # [H, W, C] -> [C, H, W], 像素值归一化到0-1之间 102 | image = transforms.ToTensor()(image) 103 | # [1, H, W] 104 | depth = transforms.ToTensor()(depth) 105 | return { 106 | 'image': image, 107 | 'depth': depth, 108 | 'segmentation': torch.from_numpy(segmentation.astype(np.int16)).float() 109 | } 110 | 111 | class RandomFlip(object): 112 | def __call__(self, sample): 113 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 114 | if random.random() > 0.5: 115 | image = np.fliplr(image).copy() 116 | depth = np.fliplr(depth).copy() 117 | segmentation = np.fliplr(segmentation).copy() 118 | 119 | return { 120 | 'image': image, 121 | 'depth': depth, 122 | 'segmentation': segmentation 123 | } 124 | 125 | # 随机缩放 126 | class RandomScale(object): 127 | def __init__(self, scale): 128 | self.scale_low = min(scale) 129 | self.scale_high = max(scale) 130 | 131 | def __call__(self, sample): 132 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 133 | 134 | target_scale = random.uniform(self.scale_low, self.scale_high) 135 | # (H, W, C) 136 | target_height = int(round(target_scale * image.shape[0])) 137 | target_width = int(round(target_scale * image.shape[1])) 138 | # Bi-linear 139 | image = skimage.transform.resize(image, (target_height, target_width), 140 | order=1, mode='reflect', preserve_range=True).astype(np.uint8) 141 | # Nearest-neighbor 142 | depth = skimage.transform.resize(depth, (target_height, target_width), 143 | order=0, mode='reflect', preserve_range=True).astype(np.uint8) 144 | segmentation = skimage.transform.resize(segmentation, (target_height, target_width), 145 | order=0, mode='reflect', preserve_range=True) 146 | 147 | return {'image': image, 'depth': depth, 'segmentation': segmentation} 148 | 149 | # 随机裁剪 150 | class RandomCrop(object): 151 | def __init__(self, th, tw): 152 | self.th = th 153 | self.tw = tw 154 | 155 | def __call__(self, sample): 156 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 157 | h = image.shape[0] 158 | w = image.shape[1] 159 | i = random.randint(0, h - self.th) 160 | j = random.randint(0, w - self.tw) 161 | 162 | return {'image': image[i:i + image_h, j:j + image_w, :], 163 | 'depth': depth[i:i + image_h, j:j + image_w], 164 | 'segmentation': segmentation[i:i + image_h, j:j + image_w]} -------------------------------------------------------------------------------- /PlaneSAM/utils/make_prompt.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | def make_point_pt(masks, num_samples=1): 5 | """ 6 | masks: [B, H, W] 7 | 8 | return: 9 | batched_points: [B, num_pt, per_pt_points, 2] 10 | batched_label: [B, num_pt, per_pt_points] 11 | 
""" 12 | device = masks.device 13 | batched_points = [] 14 | batched_labels = [] 15 | 16 | for mask in masks: 17 | # instance_ids, counts = torch.unique(mask, return_counts=True) 18 | index = torch.nonzero(mask, as_tuple=True) 19 | num_points = len(index[0]) 20 | if num_points >= num_samples: 21 | # diff = num_samples - num_points if num_samples > num_points else 0 22 | # num_samples = num_points if num_samples > num_points else num_samples 23 | select_idx = torch.tensor(random.sample(range(num_points), num_samples)) 24 | # [num_samples, 2] 25 | point_coords = torch.zeros((num_samples, 2)) 26 | # h -> y and w -> x 27 | point_coords[..., 1], point_coords[..., 0] = index[0][select_idx], index[1][select_idx] 28 | point_coords = point_coords[None, ...] 29 | # 前景是1,背景为0 30 | point_labels = torch.ones(1, num_samples) 31 | 32 | # 如果diff存在,就要确保batch的格式一致 33 | # if diff: 34 | # diff_point_coords = torch.full((1, diff, 2), fill_value=-1) 35 | # diff_point_labels = torch.full((1, diff), fill_value=-1) 36 | # point_coords = torch.cat((point_coords, diff_point_coords), dim=1) 37 | # point_labels = torch.cat((point_labels, diff_point_labels), dim=1) 38 | 39 | # 没有实例 40 | else: 41 | # 点提示全部用-1填充 42 | point_coords = torch.full((1, num_samples, 2), fill_value=-1) 43 | point_labels = torch.full((1, num_samples), fill_value=-1) 44 | 45 | batched_points.append(point_coords) 46 | batched_labels.append(point_labels) 47 | 48 | batched_points = torch.stack(batched_points, dim=0).to(device) 49 | batched_labels = torch.stack(batched_labels, dim=0).to(device) 50 | 51 | return batched_points, batched_labels 52 | 53 | 54 | def make_box_pt(batched_masks, noise_ratio=0.): 55 | """ 56 | batched_mask: [B, H, W] 57 | 58 | return: 59 | 左上点的标签是2,右下点的标签是3 60 | batched_points: [B, num_pt, per_pt_points, 2] 61 | batched_label: [B, num_pt, per_pt_points] 62 | """ 63 | device = batched_masks.device 64 | 65 | batched_points = [] 66 | batched_labels = [] 67 | 68 | for per_mask in batched_masks: 69 | indices = torch.nonzero(per_mask) 70 | h, w = per_mask.shape 71 | 72 | # 存在前景 73 | if indices.numel(): 74 | # 计算边界框的左下和右上点的坐标 75 | # bbox_min:[x_min, y_min] 76 | bbox_min = torch.min(indices, dim=0).values.flip(0) 77 | bbox_max = torch.max(indices, dim=0).values.flip(0) 78 | 79 | # 加10%边界框长度的噪声 80 | bbox_w, bbox_h = bbox_max - bbox_min 81 | noise_x1 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_w) 82 | noise_y1 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_h) 83 | noise_x2 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_w) 84 | noise_y2 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_h) 85 | bbox_min[0] = bbox_min[0] + noise_x1 86 | bbox_min[1] = bbox_min[1] + noise_y1 87 | bbox_max[0] = bbox_max[0] + noise_x2 88 | bbox_max[1] = bbox_max[1] + noise_y2 89 | bbox_min[0] = torch.clip(bbox_min[0], 0, w) 90 | bbox_min[1] = torch.clip(bbox_min[1], 0, h) 91 | bbox_max[0] = torch.clip(bbox_max[0], 0, w) 92 | bbox_max[1] = torch.clip(bbox_max[1], 0, h) 93 | 94 | # [num_pt, 2, 2] 95 | per_box_pt = torch.cat((bbox_min[None, None, :], bbox_max[None, None, :]), dim=1) 96 | 97 | # 左上对应最小点,右下对应最大点 98 | bottomright_label = torch.full((1, 1), fill_value=3).to(device) 99 | topleft_label = torch.full((1, 1), fill_value=2).to(device) 100 | # [1, 2, 1] 101 | per_box_label = torch.cat((topleft_label, bottomright_label), dim=1) 102 | 103 | batched_points.append(per_box_pt) 104 | batched_labels.append(per_box_label) 105 | else: 106 | batched_points.append(torch.full((1, 2, 2), fill_value=-1).to(device)) 107 | 
batched_labels.append(torch.full((1, 2), fill_value=-1).to(device)) 108 | 109 | batched_points = torch.stack(batched_points, dim=0) 110 | batched_labels = torch.stack(batched_labels, dim=0) 111 | 112 | return batched_points, batched_labels 113 | 114 | 115 | def preprocess(masks, num_points=1, box=False, noise_ratio=0.): 116 | """ 117 | masks: [B, H, W] 118 | 119 | return: 120 | point_prompts: [B, num_pts, num_points, 2] 121 | prompt_labels: [B, num_pts, num_points] 122 | """ 123 | assert 0 <= num_points + 2 * box <= 6, "num_points shouldn't be greater than 6 or less than 0!" 124 | 125 | device = masks.device 126 | 127 | batched_points = None 128 | batched_pt_labels = None 129 | batched_boxes = None 130 | batched_box_labels = None 131 | 132 | if num_points: 133 | batched_points, batched_pt_labels = make_point_pt(masks, num_samples=num_points) 134 | if box: 135 | batched_boxes, batched_box_labels = make_box_pt(masks, noise_ratio) 136 | 137 | # 分配到指定设备上 138 | if batched_points is not None: 139 | batched_points = batched_points.to(device) 140 | batched_pt_labels = batched_pt_labels.to(device) 141 | if batched_boxes is not None: 142 | batched_boxes = batched_boxes.to(device) 143 | batched_box_labels = batched_box_labels.to(device) 144 | 145 | # 分情况返回值 146 | if batched_points is not None and batched_boxes is not None: 147 | # [B, num_pts, num_points, 2] 148 | # [B, num_pts, num_points] 149 | return torch.cat((batched_boxes, batched_points), dim=2), torch.cat( 150 | (batched_box_labels, batched_pt_labels), dim=2) 151 | elif batched_points is not None: 152 | return batched_points, batched_pt_labels 153 | elif batched_boxes is not None: 154 | return batched_boxes, batched_box_labels -------------------------------------------------------------------------------- /PlaneSAM/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim 9 | import torchvision.transforms as transforms 10 | from model import build_efficient_sam_vitt 11 | from FasterRCNN.build_FasterRCNN import create_model 12 | from PlaneDataset import PlaneDataset 13 | from Nyuv2Dataset import Nyuv2Dataset, ToTensor 14 | from S2D3DSDataset import S2d3dsDataset, ToTensor 15 | from torch.utils.data import DataLoader 16 | from utils.make_prompt import preprocess 17 | from utils.train_tools import load_ckpt 18 | from utils.eval_tools import evaluateMasks, match_boxes_gt 19 | from utils.box_ops import masks_to_boxes 20 | 21 | parser = argparse.ArgumentParser(description='Segment Any Planes') 22 | parser.add_argument('--data-dir', default='ScanNet', metavar='DIR', 23 | help='path to dataset') 24 | parser.add_argument('--num-workers', default=8, type=int, metavar='N', 25 | help='number of data loading workers (default: 8)') 26 | parser.add_argument('-o', '--output', default='result', metavar='DIR', 27 | help='path to output') 28 | parser.add_argument('--cuda', action='store_true', default=True, 29 | help='enables CUDA training') 30 | parser.add_argument('--last-ckpt', default='model/PlaneSAM.pth', type=str, metavar='PATH', 31 | help='path to latest checkpoint (default: none)') 32 | parser.add_argument('--detector-ckpt', default='weights/FasterRCNN.pth', type=str, metavar='PATH', 33 | help='path to detector checkpoint (default: none)') 34 | 35 | args = parser.parse_args() 36 | device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu") 37 | 38 | 39 | 
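# Usage sketch (added comment; flags mirror the argparse defaults above, and the paths are
# placeholders to be adapted to your own dataset/checkpoint layout):
#   python eval.py --data-dir ScanNet --last-ckpt model/PlaneSAM.pth --detector-ckpt weights/FasterRCNN.pth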
def inference(): 40 | model = build_efficient_sam_vitt() 41 | detector = create_model(num_classes=2, load_pretrain_weights=False) 42 | detector.load_state_dict(torch.load(args.detector_ckpt)['model']) 43 | 44 | # if torch.cuda.device_count() > 1: 45 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 46 | # model = nn.DataParallel(model) 47 | # detector = nn.DataParallel(detector) 48 | 49 | model.to(device) 50 | detector.to(device) 51 | 52 | # eval scannet 53 | val_data = PlaneDataset(subset="val", 54 | transform=transforms.Compose([ 55 | transforms.ToTensor()]), 56 | root_dir=args.data_dir 57 | ) 58 | val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 59 | pin_memory=False) 60 | 61 | # eval mp3d and synthetic 62 | # val_data = Nyuv2Dataset(subset="val", 63 | # transform=transforms.Compose([ 64 | # ToTensor()]), 65 | # datafolder=args.data_dir 66 | # ) 67 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 68 | # pin_memory=False) 69 | 70 | # eval s2d3ds 71 | # val_data = S2d3dsDataset(subset="val", 72 | # transform=transforms.Compose([ 73 | # ToTensor()]), 74 | # datafolder=args.data_dir 75 | # ) 76 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 77 | # pin_memory=False) 78 | 79 | if args.last_ckpt: 80 | load_ckpt(model, None, args.last_ckpt, device, NUM_GPUS=1) 81 | else: 82 | print('no ckpt!') 83 | 84 | mRI = .0 85 | mSC = .0 86 | mVoI = .0 87 | num = 0 88 | total_time = 0 89 | model.eval() 90 | detector.eval() 91 | with torch.no_grad(): 92 | for batch_idx, sample in enumerate(val_loader): 93 | num_planes = sample['num_planes'][0] 94 | image = sample['image'].to(device) 95 | target = sample['instance'].to(device) 96 | target = target[:, :num_planes, :, :].permute(1, 0, 2, 3) # [num_planes, B, H, W] 97 | gt_boxes = masks_to_boxes(target) 98 | depth = sample['depth'].to(device) 99 | 100 | # use gt box 101 | # batch_points, batch_labels = preprocess(target.flatten(0, 1), num_points=0, box=True) 102 | 103 | # 预测边界框 to prompt 104 | start_time = time.time() 105 | outputs = detector(image)[0] 106 | boxes, scores = outputs['boxes'], outputs['scores'] 107 | batch_points, batch_labels = match_boxes_gt(boxes, gt_boxes) 108 | 109 | # union prompt 110 | pred_masks, pred_ious = model(image, depth, batch_points, batch_labels) 111 | end_time = time.time() 112 | total_time += end_time - start_time 113 | pred_masks = pred_masks.permute(1, 0, 2, 3, 4) 114 | pred_ious = pred_ious.permute(1, 0, 2) 115 | 116 | # # use preprocess 117 | # batch_points, batch_labels = preprocess(target.flatten(0, 1), num_points=0, box=True) 118 | 119 | # 二值掩码 120 | pred_masks = (pred_masks > 0.).float() 121 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 122 | # [3, B, H, W] 123 | # [3, B] 124 | pred_masks = pred_masks.view(b, -1, h, w).permute(1, 0, 2, 3) 125 | pred_ious = pred_ious.view(b, -1).permute(1, 0) 126 | best_id = torch.argmax(pred_ious, dim=0) 127 | # [num_preds, H, W] 128 | best_mask = pred_masks[best_id, torch.arange(b)] 129 | target = target.squeeze(1) 130 | 131 | # 计算指标 132 | RI, VoI, SC = evaluateMasks(best_mask, target) 133 | mRI += RI 134 | mSC += SC 135 | mVoI += VoI 136 | num += 1 137 | 138 | print(f"iter: {batch_idx} mRI: {RI:.4f} mSC: {SC:.4f} VoI: {VoI:.4f}") 139 | 140 | mRI = mRI / num 141 | mSC = mSC / num 142 | mVoI = mVoI / num 143 | print(f"mRI: {mRI:.3f} mSC: {mSC:.3f} VoI: {mVoI:.3f}") 144 | print("img/s:", len(val_loader) / total_time) 145 | 146 
| 147 | if __name__ == '__main__': 148 | if not os.path.exists(args.output): 149 | os.mkdir(args.output) 150 | inference() 151 | -------------------------------------------------------------------------------- /PlaneSAM/FasterRCNN/boxes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | from torch import Tensor 4 | import torchvision 5 | 6 | 7 | def nms(boxes, scores, iou_threshold): 8 | # type: (Tensor, Tensor, float) -> Tensor 9 | """ 10 | Performs non-maximum suppression (NMS) on the boxes according 11 | to their intersection-over-union (IoU). 12 | 13 | NMS iteratively removes lower scoring boxes which have an 14 | IoU greater than iou_threshold with another (higher scoring) 15 | box. 16 | 17 | Parameters 18 | ---------- 19 | boxes : Tensor[N, 4]) 20 | boxes to perform NMS on. They 21 | are expected to be in (x1, y1, x2, y2) format 22 | scores : Tensor[N] 23 | scores for each one of the boxes 24 | iou_threshold : float 25 | discards all overlapping 26 | boxes with IoU > iou_threshold 27 | 28 | Returns 29 | ------- 30 | keep : Tensor 31 | int64 tensor with the indices 32 | of the elements that have been kept 33 | by NMS, sorted in decreasing order of scores 34 | """ 35 | return torch.ops.torchvision.nms(boxes, scores, iou_threshold) 36 | 37 | 38 | def batched_nms(boxes, scores, idxs, iou_threshold): 39 | # type: (Tensor, Tensor, Tensor, float) -> Tensor 40 | """ 41 | Performs non-maximum suppression in a batched fashion. 42 | 43 | Each index value correspond to a category, and NMS 44 | will not be applied between elements of different categories. 45 | 46 | Parameters 47 | ---------- 48 | boxes : Tensor[N, 4] 49 | boxes where NMS will be performed. They 50 | are expected to be in (x1, y1, x2, y2) format 51 | scores : Tensor[N] 52 | scores for each one of the boxes 53 | idxs : Tensor[N] 54 | indices of the categories for each one of the boxes. 55 | iou_threshold : float 56 | discards all overlapping boxes 57 | with IoU < iou_threshold 58 | 59 | Returns 60 | ------- 61 | keep : Tensor 62 | int64 tensor with the indices of 63 | the elements that have been kept by NMS, sorted 64 | in decreasing order of scores 65 | """ 66 | if boxes.numel() == 0: 67 | return torch.empty((0,), dtype=torch.int64, device=boxes.device) 68 | 69 | # strategy: in order to perform NMS independently per class. 70 | # we add an offset to all the boxes. The offset is dependent 71 | # only on the class idx, and is large enough so that boxes 72 | # from different classes do not overlap 73 | # 获取所有boxes中最大的坐标值(xmin, ymin, xmax, ymax) 74 | max_coordinate = boxes.max() 75 | 76 | # to(): Performs Tensor dtype and/or device conversion 77 | # 为每一个类别/每一层生成一个很大的偏移量 78 | # 这里的to只是让生成tensor的dytpe和device与boxes保持一致 79 | offsets = idxs.to(boxes) * (max_coordinate + 1) 80 | # boxes加上对应层的偏移量后,保证不同类别/层之间boxes不会有重合的现象 81 | boxes_for_nms = boxes + offsets[:, None] 82 | keep = nms(boxes_for_nms, scores, iou_threshold) 83 | return keep 84 | 85 | 86 | def remove_small_boxes(boxes, min_size): 87 | # type: (Tensor, float) -> Tensor 88 | """ 89 | Remove boxes which contains at least one side smaller than min_size. 
90 | 移除宽高小于指定阈值的索引 91 | Arguments: 92 | boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format 93 | min_size (float): minimum size 94 | 95 | Returns: 96 | keep (Tensor[K]): indices of the boxes that have both sides 97 | larger than min_size 98 | """ 99 | ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] # 预测boxes的宽和高 100 | # keep = (ws >= min_size) & (hs >= min_size) # 当满足宽,高都大于给定阈值时为True 101 | keep = torch.logical_and(torch.ge(ws, min_size), torch.ge(hs, min_size)) 102 | # nonzero(): Returns a tensor containing the indices of all non-zero elements of input 103 | # keep = keep.nonzero().squeeze(1) 104 | keep = torch.where(keep)[0] 105 | return keep 106 | 107 | 108 | def clip_boxes_to_image(boxes, size): 109 | # type: (Tensor, Tuple[int, int]) -> Tensor 110 | """ 111 | Clip boxes so that they lie inside an image of size `size`. 112 | 裁剪预测的boxes信息,将越界的坐标调整到图片边界上 113 | 114 | Arguments: 115 | boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format 116 | size (Tuple[height, width]): size of the image 117 | 118 | Returns: 119 | clipped_boxes (Tensor[N, 4]) 120 | """ 121 | dim = boxes.dim() 122 | boxes_x = boxes[..., 0::2] # x1, x2 123 | boxes_y = boxes[..., 1::2] # y1, y2 124 | height, width = size 125 | 126 | if torchvision._is_tracing(): 127 | boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device)) 128 | boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device)) 129 | boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device)) 130 | boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device)) 131 | else: 132 | boxes_x = boxes_x.clamp(min=0, max=width) # 限制x坐标范围在[0,width]之间 133 | boxes_y = boxes_y.clamp(min=0, max=height) # 限制y坐标范围在[0,height]之间 134 | 135 | clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim) 136 | return clipped_boxes.reshape(boxes.shape) 137 | 138 | 139 | def box_area(boxes): 140 | """ 141 | Computes the area of a set of bounding boxes, which are specified by its 142 | (x1, y1, x2, y2) coordinates. 143 | 144 | Arguments: 145 | boxes (Tensor[N, 4]): boxes for which the area will be computed. They 146 | are expected to be in (x1, y1, x2, y2) format 147 | 148 | Returns: 149 | area (Tensor[N]): area for each box 150 | """ 151 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 152 | 153 | 154 | def box_iou(boxes1, boxes2): 155 | """ 156 | Return intersection-over-union (Jaccard index) of boxes. 157 | 158 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
159 | 160 | Arguments: 161 | boxes1 (Tensor[N, 4]) 162 | boxes2 (Tensor[M, 4]) 163 | 164 | Returns: 165 | iou (Tensor[N, M]): the NxM matrix containing the pairwise 166 | IoU values for every element in boxes1 and boxes2 167 | """ 168 | area1 = box_area(boxes1) 169 | area2 = box_area(boxes2) 170 | 171 | # When the shapes do not match, 172 | # the shape of the returned output tensor follows the broadcasting rules 173 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # left-top [N,M,2] 174 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # right-bottom [N,M,2] 175 | 176 | wh = (rb - lt).clamp(min=0) # [N,M,2] 177 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 178 | 179 | iou = inter / (area1[:, None] + area2 - inter) 180 | return iou 181 | 182 | -------------------------------------------------------------------------------- /PlaneSAM/PlaneDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | from PIL import Image 7 | from random import randint 8 | 9 | 10 | 11 | class PlaneDataset(Dataset): 12 | def __init__(self, subset='train', transform=None, root_dir=None): 13 | assert subset in ['train', 'val'] 14 | self.subset = subset 15 | self.transform = transform 16 | self.root_dir = os.path.join(root_dir, subset) 17 | self.txt_file = os.path.join(root_dir, subset + '.txt') 18 | 19 | self.data_list = [line.strip() for line in open(self.txt_file, 'r').readlines()] 20 | self.precompute_K_inv_dot_xy_1() 21 | 22 | def get_plane_parameters(self, plane, plane_nums, segmentation): 23 | valid_region = segmentation != 20 24 | 25 | plane = plane[:plane_nums] 26 | 27 | tmp = plane[:, 1].copy() 28 | plane[:, 1] = -plane[:, 2] 29 | plane[:, 2] = tmp 30 | 31 | # convert plane from n * d to n / d 32 | plane_d = np.linalg.norm(plane, axis=1) 33 | # normalize 34 | plane /= plane_d.reshape(-1, 1) 35 | # n / d 36 | plane /= plane_d.reshape(-1, 1) 37 | 38 | h, w = segmentation.shape 39 | plane_parameters = np.ones((3, h, w)) 40 | for i in range(h): 41 | for j in range(w): 42 | d = segmentation[i, j] 43 | if d >= 20: continue 44 | plane_parameters[:, i, j] = plane[d, :] 45 | 46 | # plane_instance parameter, padding zero to fix size 47 | plane_instance_parameter = np.concatenate((plane, np.zeros((20 - plane.shape[0], 3))), axis=0) 48 | return plane_parameters, valid_region, plane_instance_parameter 49 | 50 | def precompute_K_inv_dot_xy_1(self, h=192, w=256): 51 | focal_length = 517.97 52 | offset_x = 320 53 | offset_y = 240 54 | 55 | K = [[focal_length, 0, offset_x], 56 | [0, focal_length, offset_y], 57 | [0, 0, 1]] 58 | 59 | K_inv = np.linalg.inv(np.array(K)) 60 | self.K_inv = K_inv 61 | 62 | K_inv_dot_xy_1 = np.zeros((3, h, w)) 63 | for y in range(h): 64 | for x in range(w): 65 | yy = float(y) / h * 480 66 | xx = float(x) / w * 640 67 | 68 | ray = np.dot(self.K_inv, 69 | np.array([xx, yy, 1]).reshape(3, 1)) 70 | K_inv_dot_xy_1[:, y, x] = ray[:, 0] 71 | 72 | # precompute to speed up processing 73 | self.K_inv_dot_xy_1 = K_inv_dot_xy_1 74 | 75 | def plane2depth(self, plane_parameters, num_planes, segmentation, gt_depth, h=192, w=256): 76 | 77 | depth_map = 1. 
/ np.sum(self.K_inv_dot_xy_1.reshape(3, -1) * plane_parameters.reshape(3, -1), axis=0) 78 | depth_map = depth_map.reshape(h, w) 79 | 80 | # replace non planer region depth using sensor depth map 81 | # 做了一个深度修复,并把非平面区域的深度设为0 82 | depth_map[segmentation == 20] = gt_depth[segmentation == 20] 83 | return depth_map 84 | 85 | def __getitem__(self, index): 86 | if self.subset == 'train': 87 | data_path = self.data_list[index] 88 | else: 89 | data_path = str(index) + '.npz' 90 | data_path = os.path.join(self.root_dir, data_path) 91 | data = np.load(data_path) 92 | 93 | image = data['image'] 94 | image_path = data['image_path'] 95 | info = data['info'] 96 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 97 | image = Image.fromarray(image) 98 | 99 | if self.transform is not None: 100 | image = self.transform(image) 101 | 102 | plane = data['plane'] 103 | num_planes = data['num_planes'][0] 104 | 105 | gt_segmentation = data['segmentation'] 106 | gt_segmentation = gt_segmentation.reshape((192, 256)) 107 | segmentation = np.zeros([21, 192, 256], dtype=np.uint8) 108 | 109 | _, h, w = segmentation.shape 110 | for i in range(num_planes + 1): 111 | # deal with backgroud 112 | if i == num_planes: 113 | seg = gt_segmentation == 20 114 | else: 115 | seg = gt_segmentation == i 116 | 117 | segmentation[i, :, :] = seg.reshape(h, w) 118 | 119 | # surface plane parameters 120 | plane_parameters, valid_region, plane_instance_parameter = \ 121 | self.get_plane_parameters(plane, num_planes, gt_segmentation) 122 | 123 | # since some depth is missing, we use plane to recover those depth following PlaneNet 124 | gt_depth = data['depth'].reshape(192, 256) 125 | depth = self.plane2depth(plane_parameters, num_planes, gt_segmentation, gt_depth).reshape(192, 256) 126 | 127 | # Depth图像需要归一化 128 | if self.transform is not None: 129 | depth = self.transform(depth) 130 | 131 | sample = { 132 | 'image': image, 133 | 'num_planes': num_planes, 134 | 'instance': torch.FloatTensor(segmentation), 135 | # one for planar and zero for non-planar 136 | 'semantic': 1 - torch.FloatTensor(segmentation[num_planes, :, :]).unsqueeze(0), 137 | 'gt_seg': torch.LongTensor(gt_segmentation), 138 | 'depth': depth.to(torch.float32), 139 | 'plane_parameters': torch.FloatTensor(plane_parameters), 140 | 'valid_region': torch.ByteTensor(valid_region.astype(np.uint8)).unsqueeze(0), 141 | 'plane_instance_parameter': torch.FloatTensor(plane_instance_parameter), 142 | 'data_path': data_path 143 | } 144 | 145 | return sample 146 | 147 | def __len__(self): 148 | return len(self.data_list) 149 | 150 | def masks_to_bboxes(self, masks): 151 | """ 152 | 从掩码张量中计算边界框的左上和右下坐标 153 | 参数: 154 | masks: 形状为 [N, H, W] 的二进制掩码张量 155 | 返回值: 156 | bounding_boxes: 形状为 [B, 4] 的边界框坐标张量,包含左上和右下坐标 157 | """ 158 | batch_size, h, w = masks.size() 159 | bounding_boxes = torch.zeros((batch_size, 4), dtype=torch.float32) 160 | 161 | for b in range(batch_size): 162 | mask = masks[b] 163 | 164 | # 找到掩码的非零元素索引 165 | nonzero_indices = torch.nonzero(mask) 166 | 167 | if nonzero_indices.size(0) == 0: 168 | # 如果掩码中没有非零元素,则边界框坐标为零 169 | assert "no mask!" 
170 | else: 171 | # 计算边界框的左上和右下坐标 172 | ymin = torch.min(nonzero_indices[:, 0]) / h 173 | xmin = torch.min(nonzero_indices[:, 1]) / w 174 | ymax = torch.max(nonzero_indices[:, 0]) / h 175 | xmax = torch.max(nonzero_indices[:, 1]) / w 176 | bounding_boxes[b] = torch.tensor([(xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin], dtype=torch.float32) 177 | 178 | return bounding_boxes -------------------------------------------------------------------------------- /PlaneSAM/backbone/resnet101_fpn_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | from .feature_pyramid_network import BackboneWithFPN, LastLevelMaxPool 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None): 14 | super().__init__() 15 | if norm_layer is None: 16 | norm_layer = nn.BatchNorm2d 17 | 18 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, 19 | kernel_size=1, stride=1, bias=False) # squeeze channels 20 | self.bn1 = norm_layer(out_channel) 21 | # ----------------------------------------- 22 | self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, 23 | kernel_size=3, stride=stride, bias=False, padding=1) 24 | self.bn2 = norm_layer(out_channel) 25 | # ----------------------------------------- 26 | self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion, 27 | kernel_size=1, stride=1, bias=False) # unsqueeze channels 28 | self.bn3 = norm_layer(out_channel * self.expansion) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | 32 | def forward(self, x): 33 | identity = x 34 | if self.downsample is not None: 35 | identity = self.downsample(x) 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class ResNet(nn.Module): 55 | 56 | def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None): 57 | super().__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | self._norm_layer = norm_layer 61 | 62 | self.include_top = include_top 63 | self.in_channel = 64 64 | 65 | self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, 66 | padding=3, bias=False) 67 | self.bn1 = norm_layer(self.in_channel) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 70 | self.layer1 = self._make_layer(block, 64, blocks_num[0]) 71 | self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2) 72 | self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2) 73 | self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2) 74 | if self.include_top: 75 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1) 76 | self.fc = nn.Linear(512 * block.expansion, num_classes) 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 81 | 82 | def _make_layer(self, block, channel, block_num, stride=1): 83 | norm_layer = self._norm_layer 84 | downsample = None 85 | if stride != 1 or self.in_channel != channel * 
block.expansion: 86 | downsample = nn.Sequential( 87 | nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False), 88 | norm_layer(channel * block.expansion)) 89 | 90 | layers = [] 91 | layers.append(block(self.in_channel, channel, downsample=downsample, 92 | stride=stride, norm_layer=norm_layer)) 93 | self.in_channel = channel * block.expansion 94 | 95 | for _ in range(1, block_num): 96 | layers.append(block(self.in_channel, channel, norm_layer=norm_layer)) 97 | 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | x = self.conv1(x) 102 | x = self.bn1(x) 103 | x = self.relu(x) 104 | x = self.maxpool(x) 105 | 106 | x = self.layer1(x) 107 | x = self.layer2(x) 108 | x = self.layer3(x) 109 | x = self.layer4(x) 110 | 111 | if self.include_top: 112 | x = self.avgpool(x) 113 | x = torch.flatten(x, 1) 114 | x = self.fc(x) 115 | 116 | return x 117 | 118 | 119 | def overwrite_eps(model, eps): 120 | """ 121 | This method overwrites the default eps values of all the 122 | FrozenBatchNorm2d layers of the model with the provided value. 123 | This is necessary to address the BC-breaking change introduced 124 | by the bug-fix at pytorch/vision#2933. The overwrite is applied 125 | only when the pretrained weights are loaded to maintain compatibility 126 | with previous versions. 127 | 128 | Args: 129 | model (nn.Module): The model on which we perform the overwrite. 130 | eps (float): The new value of eps. 131 | """ 132 | for module in model.modules(): 133 | if isinstance(module, FrozenBatchNorm2d): 134 | module.eps = eps 135 | 136 | 137 | def resnet101_fpn_backbone(pretrain_path="", 138 | norm_layer=FrozenBatchNorm2d, # FrozenBatchNorm2d的功能与BatchNorm2d类似,但参数无法更新 139 | trainable_layers=3, 140 | returned_layers=None, 141 | extra_blocks=None): 142 | """ 143 | 搭建resnet50_fpn——backbone 144 | Args: 145 | pretrain_path: resnet50的预训练权重,如果不使用就默认为空 146 | norm_layer: 官方默认的是FrozenBatchNorm2d,即不会更新参数的bn层(因为如果batch_size设置的很小会导致效果更差,还不如不用bn层) 147 | 如果自己的GPU显存很大可以设置很大的batch_size,那么自己可以传入正常的BatchNorm2d层 148 | (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) 149 | trainable_layers: 指定训练哪些层结构 150 | returned_layers: 指定哪些层的输出需要返回 151 | extra_blocks: 在输出的特征层基础上额外添加的层结构 152 | 153 | Returns: 154 | 155 | """ 156 | resnet_backbone = ResNet(Bottleneck, [3, 4, 23, 3], 157 | include_top=False, 158 | norm_layer=norm_layer) 159 | 160 | if isinstance(norm_layer, FrozenBatchNorm2d): 161 | overwrite_eps(resnet_backbone, 0.0) 162 | 163 | if pretrain_path != "": 164 | assert os.path.exists(pretrain_path), "{} is not exist.".format(pretrain_path) 165 | # 载入预训练权重 166 | print(resnet_backbone.load_state_dict(torch.load(pretrain_path), strict=False)) 167 | 168 | # select layers that wont be frozen 169 | assert 0 <= trainable_layers <= 5 170 | layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] 171 | 172 | # 如果要训练所有层结构的话,不要忘了conv1后还有一个bn1 173 | if trainable_layers == 5: 174 | layers_to_train.append("bn1") 175 | 176 | # freeze layers 177 | for name, parameter in resnet_backbone.named_parameters(): 178 | # 只训练不在layers_to_train列表中的层结构 179 | if all([not name.startswith(layer) for layer in layers_to_train]): 180 | parameter.requires_grad_(False) 181 | 182 | if extra_blocks is None: 183 | extra_blocks = LastLevelMaxPool() 184 | 185 | if returned_layers is None: 186 | returned_layers = [1, 2, 3, 4] 187 | # 返回的特征层个数肯定大于0小于5 188 | assert min(returned_layers) > 0 and max(returned_layers) < 5 189 | 190 | # return_layers = {'layer1': '0', 'layer2': 
'1', 'layer3': '2', 'layer4': '3'} 191 | return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} 192 | 193 | # in_channel 为layer4的输出特征矩阵channel = 2048 194 | in_channels_stage2 = resnet_backbone.in_channel // 8 # 256 195 | # 记录resnet50提供给fpn的每个特征层channel 196 | in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] 197 | # 通过fpn后得到的每个特征层的channel 198 | out_channels = 256 199 | return BackboneWithFPN(resnet_backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) 200 | -------------------------------------------------------------------------------- /PlaneSAM/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from PlaneDataset import PlaneDataset 9 | from model import build_efficient_sam_vitt 10 | from utils.train_tools import save_ckpt, load_ckpt 11 | from utils.eval_tools import compute_iou 12 | from utils.loss import criterion 13 | from utils.make_prompt import preprocess 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | from random import randint 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Segment Any Planes') 19 | parser.add_argument('--data-dir', default='ScanNet', metavar='DIR', 20 | help='path to dataset-D') 21 | parser.add_argument('--cuda', action='store_true', default=True, 22 | help='enables CUDA training') 23 | parser.add_argument('--num-workers', default=8, type=int, metavar='N', 24 | help='number of data loading workers (default: 8)') 25 | parser.add_argument('--epochs', default=15, type=int, metavar='N', 26 | help='number of total epochs to run (default: 150)') 27 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 28 | help='manual epoch number (useful on restarts)') 29 | parser.add_argument('-b', '--batch-size', default=8, type=int, 30 | metavar='N', help='mini-batch size (default: 10)') 31 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 32 | metavar='LR', help='initial learning rate') 33 | parser.add_argument('--weight-decay', '--wd', default=0.01, type=float, 34 | metavar='W', help='weight decay') 35 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 36 | help='momentum') 37 | parser.add_argument('--save-epoch-freq', '-s', default=1, type=int, 38 | metavar='N', help='save epoch frequency (default: 5)') 39 | parser.add_argument('--last-ckpt', default='weights/pre.pth', type=str, metavar='PATH') 40 | parser.add_argument('--lr-decay-rate', default=0.8, type=float, 41 | help='decay rate of learning rate (default: 0.8)') 42 | parser.add_argument('--lr-epoch-per-decay', default=10, type=int, 43 | help='epoch of per decay of learning rate (default: 10)') 44 | parser.add_argument('--ckpt-dir', default='weights', metavar='DIR', 45 | help='path to save checkpoints') 46 | 47 | args = parser.parse_args() 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | 50 | image_w = 256 51 | image_h = 192 52 | 53 | def train(): 54 | train_data = PlaneDataset(subset="train", 55 | transform=transforms.Compose([ 56 | transforms.ToTensor()]), 57 | root_dir=args.data_dir 58 | ) 59 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, 60 | pin_memory=False) 61 | val_data = PlaneDataset(subset="val", 62 | 
transform=transforms.Compose([ 63 | transforms.ToTensor()]), 64 | root_dir=args.data_dir 65 | ) 66 | val_loader = DataLoader(val_data, batch_size=1, shuffle=True, num_workers=args.num_workers, 67 | pin_memory=False) 68 | 69 | num_train = len(train_data) 70 | 71 | model = build_efficient_sam_vitt() 72 | 73 | for p in model.prompt_encoder.parameters(): 74 | p.requires_grad = False 75 | 76 | if torch.cuda.device_count() >= 1: 77 | print("Let's use", torch.cuda.device_count(), "GPUs!") 78 | model = nn.DataParallel(model) 79 | 80 | model.to(device) 81 | 82 | optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr, 83 | weight_decay=args.weight_decay) 84 | 85 | global_step = 0 86 | 87 | if args.last_ckpt: 88 | global_step, _ = load_ckpt(model, None, args.last_ckpt, device, NUM_GPUS=2) 89 | 90 | # 余弦学习率调度 91 | scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0) 92 | 93 | for _ in range(args.start_epoch): 94 | optimizer.step() 95 | scheduler.step() 96 | 97 | for epoch in range(int(args.start_epoch), args.epochs): 98 | 99 | local_count = 0 100 | if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch: 101 | save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch, 102 | local_count, num_train) 103 | 104 | # 训练 105 | model.train() 106 | num_batch = 0 107 | mean_loss = 0 108 | for sample in tqdm(train_loader, desc=f"Train Epoch [{epoch + 1}/{args.epochs}]"): 109 | image = sample['image'].to(device) # [B, C, H, W] 110 | num_planes = sample['num_planes'] 111 | sel_planes = [randint(0, x - 1) for x in num_planes] 112 | target = sample['instance'].to(device) # [B, 21, H, W],第0-num_planes-1为平面掩码,第num_planes为背景(非平面) 113 | target = target[torch.arange(len(num_planes)), sel_planes, ...] 114 | depth = sample['depth'].to(device) # [B, 1, H, W] 115 | 116 | optimizer.zero_grad() 117 | 118 | input_points, input_labels = preprocess(target, num_points=0, box=True) 119 | 120 | # [B, num_prompt, per_prompt_mask, H, W] 121 | pred_masks, pred_ious = model(image, depth, input_points, input_labels) 122 | # 训练只给一个提示 123 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 124 | # 125 | pred_masks = pred_masks.view(b, -1, h, w) 126 | pred_ious = pred_ious.view(b, -1) 127 | # 遍历3个掩码 128 | pred_masks = pred_masks.permute(1, 0, 2, 3) 129 | pred_ious = pred_ious.permute(1, 0) 130 | 131 | loss = [] 132 | for mask, iou in zip(pred_masks, pred_ious): 133 | loss.append(criterion(mask, iou, target)) 134 | loss = min(loss) 135 | 136 | num_batch += 1 137 | mean_loss += loss 138 | 139 | loss.backward() 140 | optimizer.step() 141 | 142 | local_count += image.data.shape[0] 143 | global_step += 1 144 | 145 | scheduler.step() 146 | 147 | mean_loss /= num_batch 148 | print('Epoch: {} mean_loss: {}'.format(epoch + 1, mean_loss)) 149 | 150 | # 评估(use box to test) 151 | model.eval() 152 | totoal_iou = .0 153 | totoal_num = 0 154 | with torch.no_grad(): 155 | for sample in tqdm(val_loader, desc=f"Eval Epoch[{epoch + 1}/{args.epochs}]"): 156 | image = sample['image'].to(device) 157 | num_planes = sample['num_planes'][0] 158 | target = sample['instance'].to(device) 159 | target = target[:, randint(0, num_planes - 1), ...] 
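# (added comment) `target` above is a single randomly sampled plane mask per validation image,
# i.e. shape [B, H, W] with B == 1 for this validation loader.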
160 | depth = sample['depth'].to(device) 161 | 162 | # 预处理 163 | input_points, input_labels = preprocess(target, num_points=0, box=True, noise_ratio=0.1) 164 | 165 | # [B, num_prompt, per_prompt_mask, H, W] 166 | pred_masks, pred_ious = model(image, depth, input_points, input_labels) 167 | # 二值掩码 168 | pred_masks = (pred_masks > 0.).float() 169 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 170 | # [num_masks, B, H, W] 171 | pred_masks = pred_masks.view(b, -1, h, w).permute(1, 0, 2, 3) 172 | pred_ious = pred_ious.view(b, -1).permute(1, 0) 173 | 174 | # 计算iou 175 | best_id = torch.argmax(pred_ious, dim=0) 176 | best_mask = pred_masks[best_id, torch.arange(b)] 177 | best_iou = compute_iou(best_mask, target) 178 | totoal_num += len(best_iou) 179 | totoal_iou += sum(best_iou) 180 | 181 | mIou = float(totoal_iou / totoal_num) 182 | print('Epoch: {} mIou: {:.2}\n'.format(epoch + 1, mIou)) 183 | with open('log/logger.txt', 'a') as f: 184 | f.write('Epoch: {} mIou: {:.4} mean_loss: {}\n'.format(epoch + 1, mIou, mean_loss)) 185 | 186 | save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs, 187 | 0, num_train) 188 | 189 | print("Training completed ") 190 | 191 | 192 | if __name__ == '__main__': 193 | 194 | if not os.path.exists(args.ckpt_dir): 195 | os.mkdir(args.ckpt_dir) 196 | 197 | train() 198 | -------------------------------------------------------------------------------- /PlaneSAM/visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tifffile 4 | import time 5 | import numpy as np 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | from efficient_sam import build_efficient_sam_vitt 11 | # from raw_efficient_sam import build_efficient_sam_vitt 12 | from FasterRCNN.build_FasterRCNN import create_model 13 | from PlaneDataset import PlaneDataset 14 | from Nyuv2Dataset import Nyuv2Dataset, ToTensor 15 | # from S2D3DSDataset import S2d3dsDataset, ToTensor 16 | from torch.utils.data import DataLoader 17 | from utils.make_prompt import preprocess 18 | from utils.utils import load_ckpt 19 | from utils.visual_tools import map_masks_to_colors 20 | from utils.eval_tools import box_to_prompt, MatchSegmentation, evaluateMasks, match_boxes_gt 21 | from utils.box_ops import masks_to_boxes 22 | from PIL import Image 23 | from tqdm import tqdm 24 | 25 | 26 | parser = argparse.ArgumentParser(description='Segment Any Planes') 27 | parser.add_argument('--data-dir', default='mp3d-plane', metavar='DIR', 28 | help='path to dataset') 29 | parser.add_argument('--num-workers', default=8, type=int, metavar='N', 30 | help='number of data loading workers (default: 8)') 31 | parser.add_argument('-o', '--output', default='result', metavar='DIR', 32 | help='path to output') 33 | parser.add_argument('--cuda', action='store_true', default=True, 34 | help='enables CUDA training') 35 | parser.add_argument('--last-ckpt', default='./model/usepre_best.pth', type=str, metavar='PATH', 36 | help='path to latest checkpoint (default: none)') 37 | parser.add_argument('--detector-ckpt', default='FasterRCNN_weights/resNet101Fpn-model-9.pth', type=str, metavar='PATH', 38 | help='path to detector checkpoint (default: none)') 39 | 40 | args = parser.parse_args() 41 | device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu") 42 | 43 | 44 | def inference(): 45 | model = build_efficient_sam_vitt() 46 | detector = create_model(num_classes=2, 
load_pretrain_weights=False) 47 | detector.load_state_dict(torch.load(args.detector_ckpt)['model']) 48 | 49 | if torch.cuda.device_count() >= 1: 50 | print("Let's use", torch.cuda.device_count(), "GPUs!") 51 | model = nn.DataParallel(model) 52 | 53 | model.to(device) 54 | detector.to(device) 55 | 56 | # val_data = PlaneDataset(subset="train", 57 | # transform=transforms.Compose([ 58 | # transforms.ToTensor()]), 59 | # root_dir=args.data_dir 60 | # ) 61 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 62 | # pin_memory=False) 63 | 64 | val_data = Nyuv2Dataset(subset="val", 65 | transform=transforms.Compose([ 66 | ToTensor()]), 67 | datafolder=args.data_dir 68 | ) 69 | val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 70 | pin_memory=False) 71 | 72 | # val_data = S2d3dsDataset(subset="val", 73 | # transform=transforms.Compose([ 74 | # ToTensor()]), 75 | # datafolder=args.data_dir 76 | # ) 77 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 78 | # pin_memory=False) 79 | 80 | if args.last_ckpt: 81 | load_ckpt(model, None, args.last_ckpt, device) 82 | else: 83 | print('no ckpt!') 84 | 85 | model.eval() 86 | detector.eval() 87 | 88 | total_time = 0 89 | 90 | with torch.no_grad(): 91 | for batch_idx, sample in enumerate(tqdm(val_loader)): 92 | num_planes = sample['num_planes'][0] 93 | image = sample['image'].to(device) 94 | target = sample['instance'].to(device) 95 | target = target[:, :num_planes, :, :].permute(1, 0, 2, 3) # [num_planes, B, H, W] 96 | gt_boxes = masks_to_boxes(target) 97 | depth = sample['depth'].to(device) 98 | 99 | # # 预测边界框 to prompt 100 | # outputs = detector(image)[0] 101 | # boxes = outputs['boxes'] 102 | # batch_points, batch_labels = match_boxes_gt(boxes, gt_boxes) 103 | 104 | # # use preprocess 105 | # batch_points, batch_labels = preprocess(target.flatten(0, 1), num_points=0, box=True) 106 | 107 | pred_masks = [] 108 | pred_ious = [] 109 | 110 | start_time = time.time() 111 | 112 | for input_points, input_labels in zip(batch_points, batch_labels): 113 | input_points, input_labels = input_points.unsqueeze(0), input_labels.unsqueeze(0) 114 | # [B, num_prompt, per_prompt_mask, H, W] 115 | pred_mask, pred_iou = model(image, depth, input_points, input_labels) 116 | pred_masks.append(pred_mask) 117 | pred_ious.append(pred_iou) 118 | 119 | end_time = time.time() 120 | total_time += end_time - start_time 121 | 122 | # 不增加新的维度 123 | pred_masks = torch.cat(pred_masks, dim=0) 124 | pred_ious = torch.cat(pred_ious, dim=0) 125 | # 二值掩码 126 | pred_masks = (pred_masks > 0.).float() 127 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 128 | # [3, B, H, W] 129 | # [3, B] 130 | pred_masks = pred_masks.view(b, -1, h, w).permute(1, 0, 2, 3) 131 | pred_ious = pred_ious.view(b, -1).permute(1, 0) 132 | best_id = torch.argmax(pred_ious, dim=0) 133 | # [num_planes, H, W] 134 | best_mask = pred_masks[best_id, torch.arange(b)] 135 | 136 | # # 将pred与gt匹配 137 | # target = target.squeeze(1) 138 | # matching = MatchSegmentation(best_mask, target) 139 | # matched_pred_indices = [] 140 | # matched_gt_indices = [] 141 | # used = [] 142 | # for i, a in enumerate(matching): 143 | # if a not in used: 144 | # matched_pred_indices.append(i) 145 | # matched_gt_indices.append(a) 146 | # used.append(a) 147 | # matched_pred_indices = torch.as_tensor(matched_pred_indices) 148 | # matched_gt_indices = torch.as_tensor(matched_gt_indices) 149 | # prediction = 
torch.zeros_like(target) 150 | # prediction[matched_gt_indices] = best_mask[matched_pred_indices] 151 | 152 | # RI, VoI, SC = evaluateMasks(best_mask, target) 153 | # if SC > 0.7: 154 | # continue 155 | 156 | # visual 157 | prediction = best_mask.cpu().numpy().astype(np.uint8) 158 | # pred 159 | rgb_image = map_masks_to_colors(prediction) 160 | # gt_rgb 161 | image = image.squeeze(0).cpu().numpy() 162 | image *= 255 163 | image = image.astype(np.uint8).transpose(1, 2, 0) 164 | image = np.clip(image, 0, 255) 165 | # gt 166 | target = target.squeeze(1) 167 | target = target.cpu().numpy().astype(np.uint8) 168 | gt = map_masks_to_colors(target) 169 | # depth 170 | depth = depth.squeeze(0).squeeze(0).cpu().numpy() 171 | depth = (depth * 255 / (depth.max())).astype(np.uint8) 172 | depth = cv2.applyColorMap(depth, cv2.COLORMAP_JET) 173 | # # 拼接 174 | # rgbd_image = np.concatenate((image, depth, rgb_image, gt), axis=1) 175 | # rgbd_image = Image.fromarray(rgbd_image) 176 | # rgbd_image.save("raw_sam_result/all/" + str(batch_idx) + '.jpg') 177 | # print(f"save {batch_idx}.jpg") 178 | 179 | tifffile.imwrite("mp3d_result/input/input_" + str(batch_idx) + '.tif', image, resolution=(600, 600)) 180 | tifffile.imwrite("mp3d_result/gt/gt_" + str(batch_idx) + '.tif', gt, resolution=(600, 600)) 181 | tifffile.imwrite("mp3d_result/predict/predict_" + str(batch_idx) + '.tif', rgb_image, resolution=(600, 600)) 182 | tifffile.imwrite("mp3d_result/depth/depth_" + str(batch_idx) + '.tif', depth, resolution=(600, 600)) 183 | print(f"save {batch_idx}") 184 | 185 | print("total_time: {}".format(total_time)) 186 | 187 | if __name__ == '__main__': 188 | if not os.path.exists(args.output): 189 | os.mkdir(args.output) 190 | inference() 191 | -------------------------------------------------------------------------------- /PlaneSAM/backbone/feature_pyramid_network.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | 8 | from torch.jit.annotations import Tuple, List, Dict 9 | 10 | 11 | class IntermediateLayerGetter(nn.ModuleDict): 12 | """ 13 | Module wrapper that returns intermediate layers from a model 14 | It has a strong assumption that the modules have been registered 15 | into the model in the same order as they are used. 16 | This means that one should **not** reuse the same nn.Module 17 | twice in the forward if you want this to work. 18 | Additionally, it is only able to query submodules that are directly 19 | assigned to the model. So if `model` is passed, `model.feature1` can 20 | be returned, but not `model.feature1.layer2`. 21 | Arguments: 22 | model (nn.Module): model on which we will extract the features 23 | return_layers (Dict[name, new_name]): a dict containing the names 24 | of the modules for which the activations will be returned as 25 | the key of the dict, and the value of the dict is the name 26 | of the returned activation (which the user can specify). 
27 | """ 28 | __annotations__ = { 29 | "return_layers": Dict[str, str], 30 | } 31 | 32 | def __init__(self, model, return_layers): 33 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 34 | raise ValueError("return_layers are not present in model") 35 | 36 | orig_return_layers = return_layers 37 | return_layers = {str(k): str(v) for k, v in return_layers.items()} 38 | layers = OrderedDict() 39 | 40 | # 遍历模型子模块按顺序存入有序字典 41 | # 只保存layer4及其之前的结构,舍去之后不用的结构 42 | for name, module in model.named_children(): 43 | layers[name] = module 44 | if name in return_layers: 45 | del return_layers[name] 46 | if not return_layers: 47 | break 48 | 49 | super().__init__(layers) 50 | self.return_layers = orig_return_layers 51 | 52 | def forward(self, x): 53 | out = OrderedDict() 54 | # 依次遍历模型的所有子模块,并进行正向传播, 55 | # 收集layer1, layer2, layer3, layer4的输出 56 | for name, module in self.items(): 57 | x = module(x) 58 | if name in self.return_layers: 59 | out_name = self.return_layers[name] 60 | out[out_name] = x 61 | return out 62 | 63 | 64 | class FeaturePyramidNetwork(nn.Module): 65 | """ 66 | Module that adds a FPN from on top of a set of feature maps. This is based on 67 | `"Feature Pyramid Network for Object Detection" `_. 68 | The feature maps are currently supposed to be in increasing depth 69 | order. 70 | The input to the model is expected to be an OrderedDict[Tensor], containing 71 | the feature maps on top of which the FPN will be added. 72 | Arguments: 73 | in_channels_list (list[int]): number of channels for each feature map that 74 | is passed to the module 75 | out_channels (int): number of channels of the FPN representation 76 | extra_blocks (ExtraFPNBlock or None): if provided, extra operations will 77 | be performed. It is expected to take the fpn features, the original 78 | features and the names of the original features as input, and returns 79 | a new list of feature maps and their corresponding names 80 | """ 81 | 82 | def __init__(self, in_channels_list, out_channels, extra_blocks=None): 83 | super().__init__() 84 | # 用来调整resnet特征矩阵(layer1,2,3,4)的channel(kernel_size=1) 85 | self.inner_blocks = nn.ModuleList() 86 | # 对调整后的特征矩阵使用3x3的卷积核来得到对应的预测特征矩阵 87 | self.layer_blocks = nn.ModuleList() 88 | for in_channels in in_channels_list: 89 | if in_channels == 0: 90 | continue 91 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1) 92 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1) 93 | self.inner_blocks.append(inner_block_module) 94 | self.layer_blocks.append(layer_block_module) 95 | 96 | # initialize parameters now to avoid modifying the initialization of top_blocks 97 | for m in self.children(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_uniform_(m.weight, a=1) 100 | nn.init.constant_(m.bias, 0) 101 | 102 | self.extra_blocks = extra_blocks 103 | 104 | def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor: 105 | """ 106 | This is equivalent to self.inner_blocks[idx](x), 107 | but torchscript doesn't support this yet 108 | """ 109 | num_blocks = len(self.inner_blocks) 110 | if idx < 0: 111 | idx += num_blocks 112 | i = 0 113 | out = x 114 | for module in self.inner_blocks: 115 | if i == idx: 116 | out = module(x) 117 | i += 1 118 | return out 119 | 120 | def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor: 121 | """ 122 | This is equivalent to self.layer_blocks[idx](x), 123 | but torchscript doesn't support this yet 124 | """ 125 | num_blocks = len(self.layer_blocks) 126 | if idx < 0: 127 | 
idx += num_blocks 128 | i = 0 129 | out = x 130 | for module in self.layer_blocks: 131 | if i == idx: 132 | out = module(x) 133 | i += 1 134 | return out 135 | 136 | def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: 137 | """ 138 | Computes the FPN for a set of feature maps. 139 | Arguments: 140 | x (OrderedDict[Tensor]): feature maps for each feature level. 141 | Returns: 142 | results (OrderedDict[Tensor]): feature maps after FPN layers. 143 | They are ordered from highest resolution first. 144 | """ 145 | # unpack OrderedDict into two lists for easier handling 146 | names = list(x.keys()) 147 | x = list(x.values()) 148 | 149 | # 将resnet layer4的channel调整到指定的out_channels 150 | # last_inner = self.inner_blocks[-1](x[-1]) 151 | last_inner = self.get_result_from_inner_blocks(x[-1], -1) 152 | # result中保存着每个预测特征层 153 | results = [] 154 | # 将layer4调整channel后的特征矩阵,通过3x3卷积后得到对应的预测特征矩阵 155 | # results.append(self.layer_blocks[-1](last_inner)) 156 | results.append(self.get_result_from_layer_blocks(last_inner, -1)) 157 | 158 | for idx in range(len(x) - 2, -1, -1): 159 | inner_lateral = self.get_result_from_inner_blocks(x[idx], idx) 160 | feat_shape = inner_lateral.shape[-2:] 161 | inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest") 162 | last_inner = inner_lateral + inner_top_down 163 | results.insert(0, self.get_result_from_layer_blocks(last_inner, idx)) 164 | 165 | # 在layer4对应的预测特征层基础上生成预测特征矩阵5 166 | if self.extra_blocks is not None: 167 | results, names = self.extra_blocks(results, x, names) 168 | 169 | # make it back an OrderedDict 170 | out = OrderedDict([(k, v) for k, v in zip(names, results)]) 171 | 172 | return out 173 | 174 | 175 | class LastLevelMaxPool(torch.nn.Module): 176 | """ 177 | Applies a max_pool2d on top of the last feature map 178 | """ 179 | 180 | def forward(self, x: List[Tensor], y: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]: 181 | names.append("pool") 182 | x.append(F.max_pool2d(x[-1], 1, 2, 0)) # input, kernel_size, stride, padding 183 | return x, names 184 | 185 | 186 | class BackboneWithFPN(nn.Module): 187 | """ 188 | Adds a FPN on top of a model. 189 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to 190 | extract a submodel that returns the feature maps specified in return_layers. 191 | The same limitations of IntermediatLayerGetter apply here. 192 | Arguments: 193 | backbone (nn.Module) 194 | return_layers (Dict[name, new_name]): a dict containing the names 195 | of the modules for which the activations will be returned as 196 | the key of the dict, and the value of the dict is the name 197 | of the returned activation (which the user can specify). 198 | in_channels_list (List[int]): number of channels for each feature map 199 | that is returned, in the order they are present in the OrderedDict 200 | out_channels (int): number of channels in the FPN. 
201 | extra_blocks: ExtraFPNBlock 202 | Attributes: 203 | out_channels (int): the number of channels in the FPN 204 | """ 205 | 206 | def __init__(self, 207 | backbone: nn.Module, 208 | return_layers=None, 209 | in_channels_list=None, 210 | out_channels=256, 211 | extra_blocks=None, 212 | re_getter=True): 213 | super().__init__() 214 | 215 | if extra_blocks is None: 216 | extra_blocks = LastLevelMaxPool() 217 | 218 | if re_getter is True: 219 | assert return_layers is not None 220 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 221 | else: 222 | self.body = backbone 223 | 224 | self.fpn = FeaturePyramidNetwork( 225 | in_channels_list=in_channels_list, 226 | out_channels=out_channels, 227 | extra_blocks=extra_blocks, 228 | ) 229 | 230 | self.out_channels = out_channels 231 | 232 | def forward(self, x): 233 | x = self.body(x) 234 | x = self.fpn(x) 235 | return x 236 | -------------------------------------------------------------------------------- /PlaneSAM/model/two_way_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Type 3 | import torch 4 | from torch import nn, Tensor 5 | from .mlp import MLPBlock 6 | 7 | 8 | 9 | 10 | class TwoWayTransformer(nn.Module): 11 | def __init__( 12 | self, 13 | depth: int, 14 | embedding_dim: int, 15 | num_heads: int, 16 | mlp_dim: int, 17 | activation: Type[nn.Module], 18 | normalize_before_activation: bool, 19 | attention_downsample_rate: int = 2, 20 | ) -> None: 21 | """ 22 | A transformer decoder that attends to an input image using 23 | queries whose positional embedding is supplied. 24 | 25 | Args: 26 | depth (int): number of layers in the transformer 27 | embedding_dim (int): the channel dimension for the input embeddings 28 | num_heads (int): the number of heads for multihead attention. Must 29 | divide embedding_dim 30 | mlp_dim (int): the channel dimension internal to the MLP block 31 | activation (nn.Module): the activation to use in the MLP block 32 | """ 33 | super().__init__() 34 | self.depth = depth 35 | self.embedding_dim = embedding_dim 36 | self.num_heads = num_heads 37 | self.mlp_dim = mlp_dim 38 | self.layers = nn.ModuleList() 39 | 40 | for i in range(depth): 41 | curr_layer = TwoWayAttentionBlock( 42 | embedding_dim=embedding_dim, 43 | num_heads=num_heads, 44 | mlp_dim=mlp_dim, 45 | activation=activation, 46 | normalize_before_activation=normalize_before_activation, 47 | attention_downsample_rate=attention_downsample_rate, 48 | skip_first_layer_pe=(i == 0), 49 | ) 50 | self.layers.append(curr_layer) 51 | 52 | self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock( 53 | embedding_dim, 54 | num_heads, 55 | downsample_rate=attention_downsample_rate, 56 | ) 57 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 58 | 59 | def forward( 60 | self, 61 | image_embedding: Tensor, 62 | image_pe: Tensor, 63 | point_embedding: Tensor, 64 | ) -> Tuple[Tensor, Tensor]: 65 | """ 66 | Args: 67 | image_embedding (torch.Tensor): image to attend to. Should be shape 68 | B x embedding_dim x h x w for any h and w. 69 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 70 | have the same shape as image_embedding. 71 | point_embedding (torch.Tensor): the embedding to add to the query points. 72 | Must have shape B x N_points x embedding_dim for any N_points. 
73 | 74 | Returns: 75 | torch.Tensor: the processed point_embedding 76 | torch.Tensor: the processed image_embedding 77 | """ 78 | 79 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 80 | bs, c, h, w = image_embedding.shape 81 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 82 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 83 | 84 | # Prepare queries 85 | queries = point_embedding 86 | keys = image_embedding 87 | 88 | # Apply transformer blocks and final layernorm 89 | for idx, layer in enumerate(self.layers): 90 | queries, keys = layer( 91 | queries=queries, 92 | keys=keys, 93 | query_pe=point_embedding, 94 | key_pe=image_pe, 95 | ) 96 | 97 | # Apply the final attention layer from the points to the image 98 | q = queries + point_embedding 99 | k = keys + image_pe 100 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 101 | queries = queries + attn_out 102 | queries = self.norm_final_attn(queries) 103 | return queries, keys 104 | 105 | 106 | class TwoWayAttentionBlock(nn.Module): 107 | def __init__( 108 | self, 109 | embedding_dim: int, 110 | num_heads: int, 111 | mlp_dim: int, 112 | activation: Type[nn.Module], 113 | normalize_before_activation: bool, 114 | attention_downsample_rate: int = 2, 115 | skip_first_layer_pe: bool = False, 116 | ) -> None: 117 | """ 118 | A transformer block with four layers: (1) self-attention of sparse 119 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 120 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 121 | inputs. 122 | 123 | Arguments: 124 | embedding_dim (int): the channel dimension of the embeddings 125 | num_heads (int): the number of heads in the attention layers 126 | mlp_dim (int): the hidden dimension of the mlp block 127 | activation (nn.Module): the activation of the mlp block 128 | skip_first_layer_pe (bool): skip the PE on the first layer 129 | """ 130 | super().__init__() 131 | self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) 132 | self.norm1 = nn.LayerNorm(embedding_dim) 133 | 134 | self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( 135 | embedding_dim, 136 | num_heads, 137 | downsample_rate=attention_downsample_rate, 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock( 142 | embedding_dim, 143 | mlp_dim, 144 | embedding_dim, 145 | 1, 146 | activation, 147 | ) 148 | 149 | self.norm3 = nn.LayerNorm(embedding_dim) 150 | 151 | self.norm4 = nn.LayerNorm(embedding_dim) 152 | self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( 153 | embedding_dim, 154 | num_heads, 155 | downsample_rate=attention_downsample_rate, 156 | ) 157 | 158 | self.skip_first_layer_pe = skip_first_layer_pe 159 | 160 | def forward( 161 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 162 | ) -> Tuple[Tensor, Tensor]: 163 | # Self attention block 164 | if not self.skip_first_layer_pe: 165 | queries = queries + query_pe 166 | attn_out = self.self_attn(q=queries, k=queries, v=queries) 167 | queries = queries + attn_out 168 | queries = self.norm1(queries) 169 | 170 | # Cross attention block, tokens attending to image embedding 171 | q = queries + query_pe 172 | k = keys + key_pe 173 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 174 | queries = queries + attn_out 175 | queries = self.norm2(queries) 176 | 177 | # MLP block 178 | mlp_out = self.mlp(queries) 179 | queries = queries + mlp_out 180 | queries = self.norm3(queries) 181 | 182 | # Cross attention block, image 
embedding attending to tokens 183 | q = queries + query_pe 184 | k = keys + key_pe 185 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 186 | keys = keys + attn_out 187 | keys = self.norm4(keys) 188 | 189 | return queries, keys 190 | 191 | 192 | class AttentionForTwoWayAttentionBlock(nn.Module): 193 | """ 194 | An attention layer that allows for downscaling the size of the embedding 195 | after projection to queries, keys, and values. 196 | """ 197 | 198 | def __init__( 199 | self, 200 | embedding_dim: int, 201 | num_heads: int, 202 | downsample_rate: int = 1, 203 | ) -> None: 204 | super().__init__() 205 | self.embedding_dim = embedding_dim 206 | self.internal_dim = embedding_dim // downsample_rate 207 | self.num_heads = num_heads 208 | assert ( 209 | self.internal_dim % num_heads == 0 210 | ), "num_heads must divide embedding_dim." 211 | self.c_per_head = self.internal_dim / num_heads 212 | self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head) 213 | 214 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 215 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 216 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 217 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 218 | self._reset_parameters() 219 | 220 | def _reset_parameters(self) -> None: 221 | # The fan_out is incorrect, but matches pytorch's initialization 222 | # for which qkv is a single 3*embedding_dim x embedding_dim matrix 223 | fan_in = self.embedding_dim 224 | fan_out = 3 * self.internal_dim 225 | # Xavier uniform with our custom fan_out 226 | bnd = math.sqrt(6 / (fan_in + fan_out)) 227 | nn.init.uniform_(self.q_proj.weight, -bnd, bnd) 228 | nn.init.uniform_(self.k_proj.weight, -bnd, bnd) 229 | nn.init.uniform_(self.v_proj.weight, -bnd, bnd) 230 | # out_proj.weight is left with default initialization, like pytorch attention 231 | nn.init.zeros_(self.q_proj.bias) 232 | nn.init.zeros_(self.k_proj.bias) 233 | nn.init.zeros_(self.v_proj.bias) 234 | nn.init.zeros_(self.out_proj.bias) 235 | 236 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 237 | b, n, c = x.shape 238 | x = x.reshape(b, n, num_heads, c // num_heads) 239 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 240 | 241 | def _recombine_heads(self, x: Tensor) -> Tensor: 242 | b, n_heads, n_tokens, c_per_head = x.shape 243 | x = x.transpose(1, 2) 244 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 245 | 246 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 247 | # Input projections 248 | q = self.q_proj(q) 249 | k = self.k_proj(k) 250 | v = self.v_proj(v) 251 | 252 | # Separate into heads 253 | q = self._separate_heads(q, self.num_heads) 254 | k = self._separate_heads(k, self.num_heads) 255 | v = self._separate_heads(v, self.num_heads) 256 | 257 | # Attention 258 | _, _, _, c_per_head = q.shape 259 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 260 | attn = attn * self.inv_sqrt_c_per_head 261 | attn = torch.softmax(attn, dim=-1) 262 | # Get output 263 | out = attn @ v 264 | out = self._recombine_heads(out) 265 | out = self.out_proj(out) 266 | return out 267 | -------------------------------------------------------------------------------- /PlaneSAM/FasterRCNN/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Tuple, Dict, Optional 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | import torchvision 7 | 8 | from .image_list 
import ImageList 9 | 10 | 11 | @torch.jit.unused 12 | def _resize_image_onnx(image, self_min_size, self_max_size): 13 | # type: (Tensor, float, float) -> Tensor 14 | from torch.onnx import operators 15 | im_shape = operators.shape_as_tensor(image)[-2:] 16 | min_size = torch.min(im_shape).to(dtype=torch.float32) 17 | max_size = torch.max(im_shape).to(dtype=torch.float32) 18 | scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size) 19 | 20 | image = torch.nn.functional.interpolate( 21 | image[None], scale_factor=scale_factor, mode="bilinear", recompute_scale_factor=True, 22 | align_corners=False)[0] 23 | 24 | return image 25 | 26 | 27 | def _resize_image(image, self_min_size, self_max_size): 28 | # type: (Tensor, float, float) -> Tensor 29 | im_shape = torch.tensor(image.shape[-2:]) 30 | min_size = float(torch.min(im_shape)) # 获取高宽中的最小值 31 | max_size = float(torch.max(im_shape)) # 获取高宽中的最大值 32 | scale_factor = self_min_size / min_size # 根据指定最小边长和图片最小边长计算缩放比例 33 | 34 | # 如果使用该缩放比例计算的图片最大边长大于指定的最大边长 35 | if max_size * scale_factor > self_max_size: 36 | scale_factor = self_max_size / max_size # 将缩放比例设为指定最大边长和图片最大边长之比 37 | 38 | # interpolate利用插值的方法缩放图片 39 | # image[None]操作是在最前面添加batch维度[C, H, W] -> [1, C, H, W] 40 | # bilinear只支持4D Tensor 41 | image = torch.nn.functional.interpolate( 42 | image[None], scale_factor=scale_factor, mode="bilinear", recompute_scale_factor=True, 43 | align_corners=False)[0] 44 | 45 | return image 46 | 47 | 48 | class GeneralizedRCNNTransform(nn.Module): 49 | """ 50 | Performs input / target transformation before feeding the data to a GeneralizedRCNN 51 | model. 52 | 53 | The transformations it perform are: 54 | - input normalization (mean subtraction and std division) 55 | - input / target resizing to match min_size / max_size 56 | 57 | It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets 58 | """ 59 | 60 | def __init__(self, min_size, max_size, image_mean, image_std): 61 | super(GeneralizedRCNNTransform, self).__init__() 62 | if not isinstance(min_size, (list, tuple)): 63 | min_size = (min_size,) 64 | self.min_size = min_size # 指定图像的最小边长范围 65 | self.max_size = max_size # 指定图像的最大边长范围 66 | self.image_mean = image_mean # 指定图像在标准化处理中的均值 67 | self.image_std = image_std # 指定图像在标准化处理中的方差 68 | 69 | def normalize(self, image): 70 | """标准化处理""" 71 | dtype, device = image.dtype, image.device 72 | mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device) 73 | std = torch.as_tensor(self.image_std, dtype=dtype, device=device) 74 | # [:, None, None]: shape [3] -> [3, 1, 1] 75 | return (image - mean[:, None, None]) / std[:, None, None] 76 | 77 | def torch_choice(self, k): 78 | # type: (List[int]) -> int 79 | """ 80 | Implements `random.choice` via torch ops so it can be compiled with 81 | TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 82 | is fixed. 
83 | """ 84 | index = int(torch.empty(1).uniform_(0., float(len(k))).item()) 85 | return k[index] 86 | 87 | def resize(self, image, target): 88 | # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]] 89 | """ 90 | 将图片缩放到指定的大小范围内,并对应缩放bboxes信息 91 | Args: 92 | image: 输入的图片 93 | target: 输入图片的相关信息(包括bboxes信息) 94 | 95 | Returns: 96 | image: 缩放后的图片 97 | target: 缩放bboxes后的图片相关信息 98 | """ 99 | # image shape is [channel, height, width] 100 | h, w = image.shape[-2:] 101 | 102 | if self.training: 103 | size = float(self.torch_choice(self.min_size)) # 指定输入图片的最小边长,注意是self.min_size不是min_size 104 | else: 105 | # FIXME assume for now that testing uses the largest scale 106 | size = float(self.min_size[-1]) # 指定输入图片的最小边长,注意是self.min_size不是min_size 107 | 108 | if torchvision._is_tracing(): 109 | image = _resize_image_onnx(image, size, float(self.max_size)) 110 | else: 111 | image = _resize_image(image, size, float(self.max_size)) 112 | 113 | if target is None: 114 | return image, target 115 | 116 | bbox = target["boxes"] 117 | # 根据图像的缩放比例来缩放bbox 118 | bbox = resize_boxes(bbox, [h, w], image.shape[-2:]) 119 | target["boxes"] = bbox 120 | 121 | return image, target 122 | 123 | # _onnx_batch_images() is an implementation of 124 | # batch_images() that is supported by ONNX tracing. 125 | @torch.jit.unused 126 | def _onnx_batch_images(self, images, size_divisible=32): 127 | # type: (List[Tensor], int) -> Tensor 128 | max_size = [] 129 | for i in range(images[0].dim()): 130 | max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64) 131 | max_size.append(max_size_i) 132 | stride = size_divisible 133 | max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64) 134 | max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64) 135 | max_size = tuple(max_size) 136 | 137 | # work around for 138 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 139 | # which is not yet supported in onnx 140 | padded_imgs = [] 141 | for img in images: 142 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 143 | padded_img = torch.nn.functional.pad(img, [0, padding[2], 0, padding[1], 0, padding[0]]) 144 | padded_imgs.append(padded_img) 145 | 146 | return torch.stack(padded_imgs) 147 | 148 | def max_by_axis(self, the_list): 149 | # type: (List[List[int]]) -> List[int] 150 | maxes = the_list[0] 151 | for sublist in the_list[1:]: 152 | for index, item in enumerate(sublist): 153 | maxes[index] = max(maxes[index], item) 154 | return maxes 155 | 156 | def batch_images(self, images, size_divisible=32): 157 | # type: (List[Tensor], int) -> Tensor 158 | """ 159 | 将一批图像打包成一个batch返回(注意batch中每个tensor的shape是相同的) 160 | Args: 161 | images: 输入的一批图片 162 | size_divisible: 将图像高和宽调整到该数的整数倍 163 | 164 | Returns: 165 | batched_imgs: 打包成一个batch后的tensor数据 166 | """ 167 | 168 | if torchvision._is_tracing(): 169 | # batch_images() does not export well to ONNX 170 | # call _onnx_batch_images() instead 171 | return self._onnx_batch_images(images, size_divisible) 172 | 173 | # 分别计算一个batch中所有图片中的最大channel, height, width 174 | max_size = self.max_by_axis([list(img.shape) for img in images]) 175 | 176 | stride = float(size_divisible) 177 | # max_size = list(max_size) 178 | # 将height向上调整到stride的整数倍 179 | max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride) 180 | # 将width向上调整到stride的整数倍 181 | max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride) 182 | 183 
| # [batch, channel, height, width] 184 | batch_shape = [len(images)] + max_size 185 | 186 | # 创建shape为batch_shape且值全部为0的tensor 187 | batched_imgs = images[0].new_full(batch_shape, 0) 188 | for img, pad_img in zip(images, batched_imgs): 189 | # 将输入images中的每张图片复制到新的batched_imgs的每张图片中,对齐左上角,保证bboxes的坐标不变 190 | # 这样保证输入到网络中一个batch的每张图片的shape相同 191 | # copy_: Copies the elements from src into self tensor and returns self 192 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 193 | 194 | return batched_imgs 195 | 196 | def postprocess(self, 197 | result, # type: List[Dict[str, Tensor]] 198 | image_shapes, # type: List[Tuple[int, int]] 199 | original_image_sizes # type: List[Tuple[int, int]] 200 | ): 201 | # type: (...) -> List[Dict[str, Tensor]] 202 | """ 203 | 对网络的预测结果进行后处理(主要将bboxes还原到原图像尺度上) 204 | Args: 205 | result: list(dict), 网络的预测结果, len(result) == batch_size 206 | image_shapes: list(torch.Size), 图像预处理缩放后的尺寸, len(image_shapes) == batch_size 207 | original_image_sizes: list(torch.Size), 图像的原始尺寸, len(original_image_sizes) == batch_size 208 | 209 | Returns: 210 | 211 | """ 212 | if self.training: 213 | return result 214 | 215 | # 遍历每张图片的预测信息,将boxes信息还原回原尺度 216 | for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): 217 | boxes = pred["boxes"] 218 | boxes = resize_boxes(boxes, im_s, o_im_s) # 将bboxes缩放回原图像尺度上 219 | result[i]["boxes"] = boxes 220 | return result 221 | 222 | def __repr__(self): 223 | """自定义输出实例化对象的信息,可通过print打印实例信息""" 224 | format_string = self.__class__.__name__ + '(' 225 | _indent = '\n ' 226 | format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std) 227 | format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size, 228 | self.max_size) 229 | format_string += '\n)' 230 | return format_string 231 | 232 | def forward(self, 233 | images, # type: List[Tensor] 234 | targets=None # type: Optional[List[Dict[str, Tensor]]] 235 | ): 236 | # type: (...) 
-> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]] 237 | images = [img for img in images] 238 | for i in range(len(images)): 239 | image = images[i] 240 | target_index = targets[i] if targets is not None else None 241 | 242 | if image.dim() != 3: 243 | raise ValueError("images is expected to be a list of 3d tensors " 244 | "of shape [C, H, W], got {}".format(image.shape)) 245 | image = self.normalize(image) # 对图像进行标准化处理 246 | image, target_index = self.resize(image, target_index) # 对图像和对应的bboxes缩放到指定范围 247 | images[i] = image 248 | if targets is not None and target_index is not None: 249 | targets[i] = target_index 250 | 251 | # 记录resize后的图像尺寸 252 | image_sizes = [img.shape[-2:] for img in images] 253 | images = self.batch_images(images) # 将images打包成一个batch 254 | image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], []) 255 | 256 | for image_size in image_sizes: 257 | assert len(image_size) == 2 258 | image_sizes_list.append((image_size[0], image_size[1])) 259 | 260 | image_list = ImageList(images, image_sizes_list) 261 | return image_list, targets 262 | 263 | 264 | def resize_boxes(boxes, original_size, new_size): 265 | # type: (Tensor, List[int], List[int]) -> Tensor 266 | """ 267 | 将boxes参数根据图像的缩放情况进行相应缩放 268 | 269 | Arguments: 270 | original_size: 图像缩放前的尺寸 271 | new_size: 图像缩放后的尺寸 272 | """ 273 | ratios = [ 274 | torch.tensor(s, dtype=torch.float32, device=boxes.device) / 275 | torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) 276 | for s, s_orig in zip(new_size, original_size) 277 | ] 278 | ratios_height, ratios_width = ratios 279 | # Removes a tensor dimension, boxes [minibatch, 4] 280 | # Returns a tuple of all slices along a given dimension, already without it. 281 | xmin, ymin, xmax, ymax = boxes.unbind(1) 282 | xmin = xmin * ratios_width 283 | xmax = xmax * ratios_width 284 | ymin = ymin * ratios_height 285 | ymax = ymax * ratios_height 286 | return torch.stack((xmin, ymin, xmax, ymax), dim=1) 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /PlaneSAM/model/efficient_sam_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
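# Overview of this module: the ImageEncoderViT defined below extends the EfficientSAM
# ViT image encoder with a parallel depth branch. The depth map gets its own patch
# embedding (patch_embed_D, 1 input channel) and positional embedding (pos_embed_D);
# before every transformer block an MPG module fuses the depth tokens into the RGB
# tokens, and the final neck (1x1 and 3x3 convolutions with LayerNorm2d) produces the
# image embedding consumed by the mask decoder.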
6 | 7 | import math 8 | from typing import List, Optional, Tuple, Type 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class LayerNorm2d(nn.Module): 16 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 17 | super().__init__() 18 | self.weight = nn.Parameter(torch.ones(num_channels)) 19 | self.bias = nn.Parameter(torch.zeros(num_channels)) 20 | self.eps = eps 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | u = x.mean(1, keepdim=True) 24 | s = (x - u).pow(2).mean(1, keepdim=True) 25 | x = (x - u) / torch.sqrt(s + self.eps) 26 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 27 | return x 28 | 29 | 30 | class PatchEmbed(nn.Module): 31 | """2D Image to Patch Embedding""" 32 | 33 | def __init__( 34 | self, 35 | img_size, 36 | patch_size, 37 | in_chans, 38 | embed_dim, 39 | ): 40 | super().__init__() 41 | self.proj = nn.Conv2d( 42 | in_chans, 43 | embed_dim, 44 | kernel_size=(patch_size, patch_size), 45 | stride=(patch_size, patch_size), 46 | bias=True, 47 | ) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | x = self.proj(x) 52 | return x 53 | 54 | 55 | # # 可学习的tokens[B, 4096, 32] 56 | # class MFA(nn.Module): 57 | # 58 | # def __init__(self, 59 | # theta, 60 | # patch_embed_dim, 61 | # num_patches, 62 | # num_heads, 63 | # qk_scale=None, 64 | # token_dim=32 65 | # ): 66 | # super().__init__() 67 | # self.num_heads = num_heads 68 | # self.hidden_dim = patch_embed_dim // theta # 6 69 | # head_dim = self.hidden_dim // num_heads # 2 70 | # self.scale = qk_scale or head_dim ** -0.5 71 | # self.tokens = nn.Parameter(torch.randn(1, num_patches * num_patches, token_dim), requires_grad=True) 72 | # self.down1 = nn.Linear(token_dim, self.hidden_dim) 73 | # self.down2 = nn.Linear(token_dim, self.hidden_dim) 74 | # self.down3 = nn.Linear(patch_embed_dim, self.hidden_dim) 75 | # self.LN = nn.LayerNorm(token_dim, eps=1e-6) 76 | # self.up = nn.Linear(self.hidden_dim, patch_embed_dim) 77 | # 78 | # def forward(self, q: torch.Tensor): 79 | # B, N, C = q.shape 80 | # tokens = self.tokens.expand(B, N, 32) 81 | # tokens = self.LN(tokens) 82 | # v = self.down1(tokens).reshape(B, N, self.num_heads, self.hidden_dim // self.num_heads).permute(0, 2, 1, 3) 83 | # k = self.down2(tokens).reshape(B, N, self.num_heads, self.hidden_dim // self.num_heads).permute(0, 2, 1, 3) 84 | # q = self.down3(q).reshape(B, N, self.num_heads, self.hidden_dim // self.num_heads).permute(0, 2, 1, 3) 85 | # attn = (q @ k.transpose(-2, -1)) * self.scale 86 | # attn = attn.softmax(dim=-1) 87 | # q = (attn @ v).transpose(1, 2).reshape(B, N, self.hidden_dim) 88 | # q = self.up(q) 89 | # 90 | # return q 91 | 92 | 93 | class Attention(nn.Module): 94 | def __init__( 95 | self, 96 | dim, 97 | num_heads, 98 | qkv_bias, 99 | qk_scale=None, 100 | ): 101 | super().__init__() 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | self.scale = qk_scale or head_dim**-0.5 105 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 106 | self.proj = nn.Linear(dim, dim) 107 | # self.MFA = MFA(32, dim, 64, num_heads) 108 | 109 | def forward(self, x): 110 | B, N, C = x.shape 111 | qkv = ( 112 | self.qkv(x) 113 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 114 | .permute(2, 0, 3, 1, 4) 115 | ) 116 | # [B, num_heads, N, C // num_heads] 117 | q, k, v = ( 118 | qkv[0], 119 | qkv[1], 120 | qkv[2], 121 | ) 122 | # x_r = q.permute(0, 2, 1, 3).reshape(B, N, C) 123 | # x_r = self.MFA(x_r) 124 | attn = (q @ k.transpose(-2, -1)) * 
self.scale 125 | attn = attn.softmax(dim=-1) 126 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 127 | x = self.proj(x) 128 | # x = x + x_r 129 | return x 130 | 131 | 132 | class Mlp(nn.Module): 133 | def __init__( 134 | self, 135 | in_features, 136 | hidden_features=None, 137 | out_features=None, 138 | act_layer=nn.GELU, 139 | ): 140 | super().__init__() 141 | out_features = out_features or in_features 142 | hidden_features = hidden_features or in_features 143 | self.fc1 = nn.Linear(in_features, hidden_features) 144 | self.act = act_layer() 145 | self.fc2 = nn.Linear(hidden_features, out_features) 146 | 147 | def forward(self, x): 148 | x = self.fc1(x) 149 | x = self.act(x) 150 | x = self.fc2(x) 151 | return x 152 | 153 | 154 | class Block(nn.Module): 155 | def __init__( 156 | self, 157 | dim, 158 | num_heads, 159 | mlp_ratio=4.0, 160 | qkv_bias=False, 161 | qk_scale=None, 162 | act_layer=nn.GELU, 163 | ): 164 | super().__init__() 165 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 166 | self.attn = Attention( 167 | dim, 168 | num_heads=num_heads, 169 | qkv_bias=qkv_bias, 170 | qk_scale=qk_scale, 171 | ) 172 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 173 | mlp_hidden_dim = int(dim * mlp_ratio) 174 | self.mlp = Mlp( 175 | in_features=dim, 176 | hidden_features=mlp_hidden_dim, 177 | act_layer=act_layer, 178 | ) 179 | 180 | def forward(self, x): 181 | x = x + self.attn(self.norm1(x)) 182 | x = x + self.mlp(self.norm2(x)) 183 | return x 184 | 185 | 186 | @torch.jit.export 187 | def get_abs_pos( 188 | abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int] 189 | ) -> torch.Tensor: 190 | """ 191 | Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token 192 | dimension for the original embeddings. 193 | Args: 194 | abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). 195 | has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. 196 | hw (Tuple): size of input image tokens. 
197 | 198 | Returns: 199 | Absolute positional embeddings after processing with shape (1, H, W, C) 200 | """ 201 | h = hw[0] 202 | w = hw[1] 203 | if has_cls_token: 204 | abs_pos = abs_pos[:, 1:] 205 | xy_num = abs_pos.shape[1] 206 | size = int(math.sqrt(xy_num)) 207 | assert size * size == xy_num 208 | 209 | if size != h or size != w: 210 | new_abs_pos = F.interpolate( 211 | abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), 212 | size=(h, w), 213 | mode="bicubic", 214 | align_corners=False, 215 | ) 216 | return new_abs_pos.permute(0, 2, 3, 1) 217 | else: 218 | return abs_pos.reshape(1, h, w, -1) 219 | 220 | 221 | class MPG(nn.Module): 222 | 223 | def __init__( 224 | self, 225 | beta, 226 | in_chans, 227 | kernel_size, 228 | use_patch_embed=True, 229 | ): 230 | super().__init__() 231 | self.patch_embed = None 232 | if use_patch_embed: 233 | self.patch_embed = nn.Conv2d( 234 | in_chans, 235 | in_chans, 236 | kernel_size=kernel_size, 237 | stride=1, 238 | padding=1, 239 | bias=True, 240 | ) 241 | hidden_chans = int(in_chans / beta) 242 | self.down1 = nn.Linear(in_chans, hidden_chans) 243 | self.down2 = nn.Linear(in_chans, hidden_chans) 244 | self.up = nn.Linear(hidden_chans, in_chans) 245 | 246 | def forward(self, x, p, num_patches): 247 | """ 248 | :param x: RGB嵌入[B, N, C] 249 | :param p: D嵌入[B, N, C] 250 | :return: [B, N, C] 251 | """ 252 | if self.patch_embed: 253 | b, n, c = p.shape 254 | p = p.permute(0, 2, 1).reshape(b, c, num_patches, -1) 255 | p = self.patch_embed(p) 256 | p = p.reshape(b, c, -1).permute(0, 2, 1) 257 | 258 | x = self.down1(x) 259 | p = self.down2(p) 260 | p = self.up(x + p) 261 | 262 | return p 263 | 264 | 265 | # Image encoder for efficient SAM. 266 | class ImageEncoderViT(nn.Module): 267 | def __init__( 268 | self, 269 | img_size: int, 270 | patch_size: int, 271 | in_chans: int, 272 | patch_embed_dim: int, 273 | normalization_type: str, 274 | depth: int, 275 | num_heads: int, 276 | mlp_ratio: float, 277 | neck_dims: List[int], 278 | act_layer: Type[nn.Module], 279 | ) -> None: 280 | """ 281 | Args: 282 | img_size (int): Input image size. 283 | patch_size (int): Patch size. 284 | in_chans (int): Number of input image channels. 285 | patch_embed_dim (int): Patch embedding dimension. 286 | depth (int): Depth of ViT. 287 | num_heads (int): Number of attention heads in each ViT block. 288 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 289 | act_layer (nn.Module): Activation layer. 290 | """ 291 | super().__init__() 292 | 293 | self.img_size = img_size 294 | self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1)) 295 | self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1] 296 | self.pretrain_use_cls_token = True 297 | pretrain_img_size = 224 298 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim) 299 | self.patch_embed_D = PatchEmbed(img_size, patch_size, 1, patch_embed_dim) 300 | # Initialize absolute positional embedding with pretrain image size. 
301 | num_patches = (pretrain_img_size // patch_size) * ( 302 | pretrain_img_size // patch_size 303 | ) 304 | num_positions = num_patches + 1 305 | self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim)) 306 | self.pos_embed_D = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim)) 307 | self.blocks = nn.ModuleList() 308 | for i in range(depth): 309 | vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True) 310 | self.blocks.append(vit_block) 311 | self.MPGs = nn.ModuleList() 312 | for i in range(depth): 313 | # 第一个MPG是没有嵌入的 314 | mpg_block = MPG(4, patch_embed_dim, 3, bool(i)) 315 | self.MPGs.append(mpg_block) 316 | self.neck = nn.Sequential( 317 | nn.Conv2d( 318 | patch_embed_dim, 319 | neck_dims[0], 320 | kernel_size=1, 321 | bias=False, 322 | ), 323 | LayerNorm2d(neck_dims[0]), 324 | nn.Conv2d( 325 | neck_dims[0], 326 | neck_dims[0], 327 | kernel_size=3, 328 | padding=1, 329 | bias=False, 330 | ), 331 | LayerNorm2d(neck_dims[0]), 332 | ) 333 | 334 | def forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: 335 | assert ( 336 | x.shape[2] == self.img_size and x.shape[3] == self.img_size 337 | ), "input image size must match self.img_size" 338 | x = self.patch_embed(x) 339 | p = self.patch_embed_D(p) 340 | # B C H W -> B H W C 341 | x = x.permute(0, 2, 3, 1) 342 | x = x + get_abs_pos( 343 | self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]] 344 | ) 345 | p = p.permute(0, 2, 3, 1) 346 | p = p + get_abs_pos( 347 | self.pos_embed_D, self.pretrain_use_cls_token, [p.shape[1], p.shape[2]]) 348 | num_patches = x.shape[1] 349 | assert x.shape[2] == num_patches 350 | x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3]) 351 | p = p.reshape(p.shape[0], num_patches * num_patches, p.shape[3]) 352 | for blk, mpg in zip(self.blocks, self.MPGs): 353 | p = mpg(x, p, num_patches) 354 | x = x + p 355 | x = blk(x) 356 | x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2]) 357 | x = self.neck(x.permute(0, 3, 1, 2)) 358 | return x 359 | -------------------------------------------------------------------------------- /PlaneSAM/model/efficient_sam_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Tuple, Type 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .mlp import MLPBlock 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | ) -> None: 24 | """ 25 | Encodes prompts for input to SAM's mask decoder. 26 | 27 | Arguments: 28 | embed_dim (int): The prompts' embedding dimension 29 | image_embedding_size (tuple(int, int)): The spatial size of the 30 | image embedding, as (H, W). 31 | input_image_size (int): The padded size of the image as input 32 | to the image encoder, as (H, W). 
33 | """ 34 | super().__init__() 35 | self.embed_dim = embed_dim 36 | self.input_image_size = input_image_size 37 | self.image_embedding_size = image_embedding_size 38 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 39 | self.invalid_points = nn.Embedding(1, embed_dim) 40 | self.point_embeddings = nn.Embedding(1, embed_dim) 41 | self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim) 42 | self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim) 43 | 44 | def get_dense_pe(self) -> torch.Tensor: 45 | """ 46 | Returns the positional encoding used to encode point prompts, 47 | applied to a dense set of points the shape of the image encoding. 48 | 49 | Returns: 50 | torch.Tensor: Positional encoding with shape 51 | 1x(embed_dim)x(embedding_h)x(embedding_w) 52 | """ 53 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 54 | 55 | def _embed_points( 56 | self, 57 | points: torch.Tensor, 58 | labels: torch.Tensor, 59 | ) -> torch.Tensor: 60 | """Embeds point prompts.""" 61 | points = points + 0.5 # Shift to center of pixel 62 | point_embedding = self.pe_layer.forward_with_coords( 63 | points, self.input_image_size 64 | ) 65 | invalid_label_ids = torch.eq(labels, -1)[:,:,None] 66 | point_label_ids = torch.eq(labels, 1)[:,:,None] 67 | topleft_label_ids = torch.eq(labels, 2)[:,:,None] 68 | bottomright_label_ids = torch.eq(labels, 3)[:,:,None] 69 | point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids 70 | point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids 71 | point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids 72 | point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids 73 | return point_embedding 74 | 75 | def forward( 76 | self, 77 | coords, 78 | labels, 79 | ) -> torch.Tensor: 80 | """ 81 | Embeds different types of prompts, returning both sparse and dense 82 | embeddings. 83 | 84 | Arguments: 85 | points: A tensor of shape [B, 2] 86 | labels: An integer tensor of shape [B] where each element is 1,2 or 3. 87 | 88 | Returns: 89 | torch.Tensor: sparse embeddings for the points and boxes, with shape 90 | BxNx(embed_dim), where N is determined by the number of input points 91 | and boxes. 92 | """ 93 | return self._embed_points(coords, labels) 94 | 95 | 96 | class PositionEmbeddingRandom(nn.Module): 97 | """ 98 | Positional encoding using random spatial frequencies. 99 | """ 100 | 101 | def __init__(self, num_pos_feats: int) -> None: 102 | super().__init__() 103 | self.register_buffer( 104 | "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats)) 105 | ) 106 | 107 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 108 | """Positionally encode points that are normalized to [0,1].""" 109 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 110 | coords = 2 * coords - 1 111 | coords = coords @ self.positional_encoding_gaussian_matrix 112 | coords = 2 * np.pi * coords 113 | # outputs d_1 x ... 
x d_n x C shape 114 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 115 | 116 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 117 | """Generate positional encoding for a grid of the specified size.""" 118 | h, w = size 119 | device = self.positional_encoding_gaussian_matrix.device 120 | grid = torch.ones([h, w], device=device, dtype=torch.float32) 121 | y_embed = grid.cumsum(dim=0) - 0.5 122 | x_embed = grid.cumsum(dim=1) - 0.5 123 | y_embed = y_embed / h 124 | x_embed = x_embed / w 125 | 126 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 127 | return pe.permute(2, 0, 1) # C x H x W 128 | 129 | def forward_with_coords( 130 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 131 | ) -> torch.Tensor: 132 | """Positionally encode points that are not normalized to [0,1].""" 133 | coords = coords_input.clone() 134 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 135 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 136 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 137 | 138 | 139 | class MaskDecoder(nn.Module): 140 | def __init__( 141 | self, 142 | *, 143 | transformer_dim: int, 144 | transformer: nn.Module, 145 | num_multimask_outputs: int, 146 | activation: Type[nn.Module], 147 | normalization_type: str, 148 | normalize_before_activation: bool, 149 | iou_head_depth: int, 150 | iou_head_hidden_dim: int, 151 | upscaling_layer_dims: List[int], 152 | ) -> None: 153 | """ 154 | Predicts masks given an image and prompt embeddings, using a 155 | transformer architecture. 156 | 157 | Arguments: 158 | transformer_dim (int): the channel dimension of the transformer 159 | transformer (nn.Module): the transformer used to predict masks 160 | num_multimask_outputs (int): the number of masks to predict 161 | when disambiguating masks 162 | activation (nn.Module): the type of activation to use when 163 | upscaling masks 164 | iou_head_depth (int): the depth of the MLP used to predict 165 | mask quality 166 | iou_head_hidden_dim (int): the hidden dimension of the MLP 167 | used to predict mask quality 168 | """ 169 | super().__init__() 170 | self.transformer_dim = transformer_dim 171 | self.transformer = transformer 172 | 173 | self.num_multimask_outputs = num_multimask_outputs 174 | 175 | self.iou_token = nn.Embedding(1, transformer_dim) 176 | if num_multimask_outputs > 1: 177 | self.num_mask_tokens = num_multimask_outputs + 1 178 | else: 179 | self.num_mask_tokens = 1 180 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 181 | output_dim_after_upscaling = transformer_dim 182 | 183 | self.final_output_upscaling_layers = nn.ModuleList([]) 184 | for idx, layer_dims in enumerate(upscaling_layer_dims): 185 | self.final_output_upscaling_layers.append( 186 | nn.Sequential( 187 | nn.ConvTranspose2d( 188 | output_dim_after_upscaling, 189 | layer_dims, 190 | kernel_size=2, 191 | stride=2, 192 | ), 193 | nn.GroupNorm(1, layer_dims) 194 | if idx < len(upscaling_layer_dims) - 1 195 | else nn.Identity(), 196 | activation(), 197 | ) 198 | ) 199 | output_dim_after_upscaling = layer_dims 200 | 201 | self.output_hypernetworks_mlps = nn.ModuleList( 202 | [ 203 | MLPBlock( 204 | input_dim=transformer_dim, 205 | hidden_dim=transformer_dim, 206 | output_dim=output_dim_after_upscaling, 207 | num_layers=2, 208 | act=activation, 209 | ) 210 | for i in range(self.num_mask_tokens) 211 | ] 212 | ) 213 | 214 | self.iou_prediction_head = MLPBlock( 215 | input_dim=transformer_dim, 216 | hidden_dim=iou_head_hidden_dim, 217 | 
output_dim=self.num_mask_tokens, 218 | num_layers=iou_head_depth, 219 | act=activation, 220 | ) 221 | 222 | def forward( 223 | self, 224 | image_embeddings: torch.Tensor, 225 | image_pe: torch.Tensor, 226 | sparse_prompt_embeddings: torch.Tensor, 227 | multimask_output: bool, 228 | ) -> Tuple[torch.Tensor, torch.Tensor]: 229 | """ 230 | Predict masks given image and prompt embeddings. 231 | 232 | Arguments: 233 | image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] 234 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable). 235 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 236 | multimask_output (bool): Whether to return multiple masks or a single 237 | mask. 238 | 239 | Returns: 240 | torch.Tensor: batched predicted masks 241 | torch.Tensor: batched predictions of mask quality 242 | """ 243 | 244 | ( 245 | batch_size, 246 | max_num_queries, 247 | sparse_embed_dim_1, 248 | sparse_embed_dim_2, 249 | ) = sparse_prompt_embeddings.shape 250 | 251 | ( 252 | _, 253 | image_embed_dim_c, 254 | image_embed_dim_h, 255 | image_embed_dim_w, 256 | ) = image_embeddings.shape 257 | 258 | # Tile the image embedding for all queries. 259 | image_embeddings_tiled = torch.tile( 260 | image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1] 261 | ).view( 262 | batch_size * max_num_queries, 263 | image_embed_dim_c, 264 | image_embed_dim_h, 265 | image_embed_dim_w, 266 | ) 267 | sparse_prompt_embeddings = sparse_prompt_embeddings.reshape( 268 | batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2 269 | ) 270 | masks, iou_pred = self.predict_masks( 271 | image_embeddings=image_embeddings_tiled, 272 | image_pe=image_pe, 273 | sparse_prompt_embeddings=sparse_prompt_embeddings, 274 | ) 275 | if multimask_output and self.num_multimask_outputs > 1: 276 | return masks[:, 1:, :], iou_pred[:, 1:] 277 | else: 278 | return masks[:, :1, :], iou_pred[:, :1] 279 | 280 | def predict_masks( 281 | self, 282 | image_embeddings: torch.Tensor, 283 | image_pe: torch.Tensor, 284 | sparse_prompt_embeddings: torch.Tensor, 285 | ) -> Tuple[torch.Tensor, torch.Tensor]: 286 | """Predicts masks. 
See 'forward' for more details.""" 287 | # Concatenate output tokens 288 | output_tokens = torch.cat( 289 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 290 | ) 291 | output_tokens = output_tokens.unsqueeze(0).expand( 292 | sparse_prompt_embeddings.size(0), -1, -1 293 | ) 294 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 295 | # Expand per-image data in batch direction to be per-mask 296 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 297 | b, c, h, w = image_embeddings.shape 298 | hs, src = self.transformer(image_embeddings, pos_src, tokens) 299 | iou_token_out = hs[:, 0, :] 300 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 301 | 302 | # Upscale mask embeddings and predict masks using the mask tokens 303 | upscaled_embedding = src.transpose(1, 2).view(b, c, h, w) 304 | 305 | for upscaling_layer in self.final_output_upscaling_layers: 306 | upscaled_embedding = upscaling_layer(upscaled_embedding) 307 | hyper_in_list: List[torch.Tensor] = [] 308 | for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps): 309 | hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :])) 310 | hyper_in = torch.stack(hyper_in_list, dim=1) 311 | b, c, h, w = upscaled_embedding.shape 312 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 313 | # Generate mask quality predictions 314 | iou_pred = self.iou_prediction_head(iou_token_out) 315 | return masks, iou_pred 316 | -------------------------------------------------------------------------------- /PlaneSAM/model/efficient_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Any, List, Tuple, Type 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from torch import nn, Tensor 14 | 15 | from .efficient_sam_decoder import MaskDecoder, PromptEncoder 16 | from .efficient_sam_encoder import ImageEncoderViT 17 | from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer 18 | 19 | class EfficientSam(nn.Module): 20 | mask_threshold: float = 0.0 21 | image_format: str = "RGB" 22 | 23 | def __init__( 24 | self, 25 | image_encoder: ImageEncoderViT, 26 | prompt_encoder: PromptEncoder, 27 | decoder_max_num_input_points: int, 28 | mask_decoder: MaskDecoder, 29 | pixel_mean: List[float] = [0.485, 0.456, 0.406, 0.473], 30 | pixel_std: List[float] = [0.229, 0.224, 0.225, 0.189], 31 | ) -> None: 32 | """ 33 | SAM predicts object masks from an image and input prompts. 34 | 35 | Arguments: 36 | image_encoder (ImageEncoderViT): The backbone used to encode the 37 | image into image embeddings that allow for efficient mask prediction. 38 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 39 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 40 | and encoded prompts. 41 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 42 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 
43 | """ 44 | super().__init__() 45 | self.image_encoder = image_encoder 46 | self.prompt_encoder = prompt_encoder 47 | self.decoder_max_num_input_points = decoder_max_num_input_points 48 | self.mask_decoder = mask_decoder 49 | self.register_buffer( 50 | "pixel_mean", torch.Tensor(pixel_mean).view(1, 4, 1, 1), False 51 | ) 52 | self.register_buffer( 53 | "pixel_std", torch.Tensor(pixel_std).view(1, 4, 1, 1), False 54 | ) 55 | 56 | @torch.jit.export 57 | def predict_masks( 58 | self, 59 | image_embeddings: torch.Tensor, 60 | batched_points: torch.Tensor, 61 | batched_point_labels: torch.Tensor, 62 | multimask_output: bool, 63 | input_h: int, 64 | input_w: int, 65 | output_h: int = -1, 66 | output_w: int = -1, 67 | ) -> Tuple[torch.Tensor, torch.Tensor]: 68 | """ 69 | Predicts masks given image embeddings and prompts. This only runs the decoder. 70 | 71 | Arguments: 72 | image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] 73 | batched_points: A tensor of shape [B, max_num_queries, num_pts, 2] 74 | batched_point_labels: A tensor of shape [B, max_num_queries, num_pts] 75 | Returns: 76 | A tuple of two tensors: 77 | low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks 78 | iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores 79 | """ 80 | 81 | batch_size, max_num_queries, num_pts, _ = batched_points.shape 82 | num_pts = batched_points.shape[2] 83 | # Rescale the coordinates to the encoder input size; input_h/input_w is the original image size 84 | rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w) 85 | 86 | # Each prompt can hold at most 6 points; prompts with fewer than 6 points are padded to 6 with -1 87 | if num_pts > self.decoder_max_num_input_points: 88 | rescaled_batched_points = rescaled_batched_points[ 89 | :, :, : self.decoder_max_num_input_points, : 90 | ] 91 | batched_point_labels = batched_point_labels[ 92 | :, :, : self.decoder_max_num_input_points 93 | ] 94 | elif num_pts < self.decoder_max_num_input_points: 95 | # In F.pad the values are grouped in pairs, giving the padding added before and after each dimension (starting from the last) 96 | rescaled_batched_points = F.pad( 97 | rescaled_batched_points, 98 | (0, 0, 0, self.decoder_max_num_input_points - num_pts), 99 | value=-1.0, 100 | ) 101 | batched_point_labels = F.pad( 102 | batched_point_labels, 103 | (0, self.decoder_max_num_input_points - num_pts), 104 | value=-1.0, 105 | ) 106 | 107 | sparse_embeddings = self.prompt_encoder( 108 | rescaled_batched_points.reshape( 109 | batch_size * max_num_queries, self.decoder_max_num_input_points, 2 110 | ), 111 | batched_point_labels.reshape( 112 | batch_size * max_num_queries, self.decoder_max_num_input_points 113 | ), 114 | ) 115 | sparse_embeddings = sparse_embeddings.view( 116 | batch_size, 117 | max_num_queries, 118 | sparse_embeddings.shape[1], 119 | sparse_embeddings.shape[2], 120 | ) 121 | # [B, 3, 256, 256] and [B, 3] 122 | low_res_masks, iou_predictions = self.mask_decoder( 123 | image_embeddings, 124 | self.prompt_encoder.get_dense_pe(), 125 | sparse_prompt_embeddings=sparse_embeddings, 126 | multimask_output=multimask_output, 127 | ) 128 | _, num_predictions, low_res_size, _ = low_res_masks.shape 129 | 130 | if output_w > 0 and output_h > 0: 131 | # Bicubic interpolation to the requested output size 132 | output_masks = F.interpolate( 133 | low_res_masks, (output_h, output_w), mode="bicubic" 134 | ) 135 | output_masks = torch.reshape( 136 | output_masks, 137 | (batch_size, max_num_queries, num_predictions, output_h, output_w), 138 | ) 139 | else: 140 | # Effectively a no-op: the masks stay at 256*256 pixels 141 | output_masks = torch.reshape( 142 | low_res_masks, 143 | ( 144 | batch_size, 145 | max_num_queries, 146 | num_predictions, 147 | low_res_size, 148
| low_res_size, 149 | ), 150 | ) 151 | iou_predictions = torch.reshape( 152 | iou_predictions, (batch_size, max_num_queries, num_predictions) 153 | ) 154 | return output_masks, iou_predictions 155 | 156 | def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int): 157 | return torch.stack( 158 | [ 159 | torch.where( 160 | batched_points[..., 0] >= 0, 161 | batched_points[..., 0] * self.image_encoder.img_size / input_w, 162 | -1.0, 163 | ), 164 | torch.where( 165 | batched_points[..., 1] >= 0, 166 | batched_points[..., 1] * self.image_encoder.img_size / input_h, 167 | -1.0, 168 | ), 169 | ], 170 | dim=-1, 171 | ) 172 | 173 | @torch.jit.export 174 | def get_image_embeddings(self, batched_images, batched_depths) -> torch.Tensor: 175 | """ 176 | Predicts masks end-to-end from provided images and prompts. 177 | If prompts are not known in advance, using SamPredictor is 178 | recommended over calling the model directly. 179 | 180 | Arguments: 181 | batched_images: A tensor of shape [B, 3, H, W] 182 | batched_images: A tensor of shape [B, 1, H, W] 183 | Returns: 184 | List of image embeddings each of of shape [B, C(i), H(i), W(i)]. 185 | The last embedding corresponds to the final layer. 186 | """ 187 | batched_rgbd = self.preprocess(torch.cat((batched_images, batched_depths), dim=1)) 188 | batched_images = batched_rgbd[:, :3, ...] 189 | batched_depths = batched_rgbd[:, 3:, ...] 190 | return self.image_encoder(batched_images, batched_depths) 191 | 192 | def forward( 193 | self, 194 | batched_images: torch.Tensor, 195 | batched_depths: torch.Tensor, 196 | batched_points: torch.Tensor, 197 | batched_point_labels: torch.Tensor, 198 | scale_to_original_image_size: bool = True, 199 | ) -> Tuple[torch.Tensor, torch.Tensor]: 200 | """ 201 | Predicts masks end-to-end from provided images and prompts. 202 | If prompts are not known in advance, using SamPredictor is 203 | recommended over calling the model directly. 204 | 205 | Arguments: 206 | batched_images: A tensor of shape [B, 3, H, W] 207 | batched_depths: A tensor of shape [B, 1, H, W] 208 | batched_points: A tensor of shape [B, num_queries, max_num_pts, 2],每一个提示可以对应多个点 209 | batched_point_labels: A tensor of shape [B, num_queries, max_num_pts] 210 | 211 | Returns: 212 | A list tuples of two tensors where the ith element is by considering the first i+1 points. 
213 | low_res_mask: A tensor of shape [B, 256, 256] of predicted masks 214 | iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores 215 | """ 216 | batch_size, _, input_h, input_w = batched_images.shape 217 | # [1, 256, 64, 64] 218 | image_embeddings = self.get_image_embeddings(batched_images, batched_depths) 219 | return self.predict_masks( 220 | image_embeddings, 221 | batched_points, 222 | batched_point_labels, 223 | multimask_output=True, 224 | input_h=input_h, 225 | input_w=input_w, 226 | output_h=input_h if scale_to_original_image_size else -1, 227 | output_w=input_w if scale_to_original_image_size else -1, 228 | ) 229 | 230 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 231 | """Normalize pixel values and pad to a square input.""" 232 | if ( 233 | x.shape[2] != self.image_encoder.img_size 234 | or x.shape[3] != self.image_encoder.img_size 235 | ): 236 | x = F.interpolate( 237 | x, 238 | (self.image_encoder.img_size, self.image_encoder.img_size), 239 | mode="bilinear", 240 | ) 241 | return (x - self.pixel_mean) / self.pixel_std 242 | 243 | 244 | def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None): 245 | img_size = 1024 246 | encoder_patch_size = 16 247 | encoder_depth = 12 248 | encoder_mlp_ratio = 4.0 249 | encoder_neck_dims = [256, 256] 250 | decoder_max_num_input_points = 6 251 | decoder_transformer_depth = 2 252 | decoder_transformer_mlp_dim = 2048 253 | decoder_num_heads = 8 254 | decoder_upscaling_layer_dims = [64, 32] 255 | num_multimask_outputs = 3 256 | iou_head_depth = 3 257 | iou_head_hidden_dim = 256 258 | activation = "gelu" 259 | normalization_type = "layer_norm" 260 | normalize_before_activation = False 261 | 262 | assert activation == "relu" or activation == "gelu" 263 | if activation == "relu": 264 | activation_fn = nn.ReLU 265 | else: 266 | activation_fn = nn.GELU 267 | 268 | image_encoder = ImageEncoderViT( 269 | img_size=img_size, 270 | patch_size=encoder_patch_size, 271 | in_chans=3, 272 | patch_embed_dim=encoder_patch_embed_dim, 273 | normalization_type=normalization_type, 274 | depth=encoder_depth, 275 | num_heads=encoder_num_heads, 276 | mlp_ratio=encoder_mlp_ratio, 277 | neck_dims=encoder_neck_dims, 278 | act_layer=activation_fn, 279 | ) 280 | 281 | image_embedding_size = image_encoder.image_embedding_size 282 | encoder_transformer_output_dim = image_encoder.transformer_output_dim 283 | 284 | sam = EfficientSam( 285 | image_encoder=image_encoder, 286 | prompt_encoder=PromptEncoder( 287 | embed_dim=encoder_transformer_output_dim, 288 | image_embedding_size=(image_embedding_size, image_embedding_size), 289 | input_image_size=(img_size, img_size), 290 | ), 291 | decoder_max_num_input_points=decoder_max_num_input_points, 292 | mask_decoder=MaskDecoder( 293 | transformer_dim=encoder_transformer_output_dim, 294 | transformer=TwoWayTransformer( 295 | depth=decoder_transformer_depth, 296 | embedding_dim=encoder_transformer_output_dim, 297 | num_heads=decoder_num_heads, 298 | mlp_dim=decoder_transformer_mlp_dim, 299 | activation=activation_fn, 300 | normalize_before_activation=normalize_before_activation, 301 | ), 302 | num_multimask_outputs=num_multimask_outputs, 303 | activation=activation_fn, 304 | normalization_type=normalization_type, 305 | normalize_before_activation=normalize_before_activation, 306 | iou_head_depth=iou_head_depth - 1, 307 | iou_head_hidden_dim=iou_head_hidden_dim, 308 | upscaling_layer_dims=decoder_upscaling_layer_dims, 309 | ), 310 | pixel_mean=[0.485, 0.456, 0.406, 0.473], 
311 |         pixel_std=[0.229, 0.224, 0.225, 0.189],
312 |     )
313 |     if checkpoint is not None:
314 |         with open(checkpoint, "rb") as f:
315 |             state_dict = torch.load(f)["model"]
316 |         sam.load_state_dict(state_dict, strict=False)
317 | 
318 |     return sam
319 | 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/det_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from typing import List, Tuple
4 | from torch import Tensor
5 | 
6 | 
7 | class BalancedPositiveNegativeSampler(object):
8 |     """
9 |     This class samples batches, ensuring that they contain a fixed proportion of positives
10 |     """
11 | 
12 |     def __init__(self, batch_size_per_image, positive_fraction):
13 |         # type: (int, float) -> None
14 |         """
15 |         Arguments:
16 |             batch_size_per_image (int): number of elements to be selected per image
17 |             positive_fraction (float): percentage of positive elements per batch
18 |         """
19 |         self.batch_size_per_image = batch_size_per_image
20 |         self.positive_fraction = positive_fraction
21 | 
22 |     def __call__(self, matched_idxs):
23 |         # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
24 |         """
25 |         Arguments:
26 |             matched_idxs: list of tensors containing -1, 0 or positive values.
27 |                 Each tensor corresponds to a specific image.
28 |                 -1 values are ignored, 0 are considered as negatives and > 0 as
29 |                 positives.
30 | 
31 |         Returns:
32 |             pos_idx (list[tensor])
33 |             neg_idx (list[tensor])
34 | 
35 |         Returns two lists of binary masks for each image.
36 |         The first list contains the positive elements that were selected,
37 |         and the second list the negative examples.
38 |         """
39 |         pos_idx = []
40 |         neg_idx = []
41 |         # Iterate over the matched_idxs of each image
42 |         for matched_idxs_per_image in matched_idxs:
43 |             # values >= 1 are positive samples; nonzero returns the indices of the non-zero elements
44 |             # positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
45 |             positive = torch.where(torch.ge(matched_idxs_per_image, 1))[0]
46 |             # values == 0 are negative samples
47 |             # negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
48 |             negative = torch.where(torch.eq(matched_idxs_per_image, 0))[0]
49 | 
50 |             # target number of positive samples
51 |             num_pos = int(self.batch_size_per_image * self.positive_fraction)
52 |             # protect against not enough positive examples
53 |             # if there are not enough positives, simply take all of them
54 |             num_pos = min(positive.numel(), num_pos)
55 |             # target number of negative samples
56 |             num_neg = self.batch_size_per_image - num_pos
57 |             # protect against not enough negative examples
58 |             # if there are not enough negatives, simply take all of them
59 |             num_neg = min(negative.numel(), num_neg)
60 | 
61 |             # randomly select positive and negative examples
62 |             # Returns a random permutation of integers from 0 to n - 1.
63 | # 随机选择指定数量的正负样本 64 | perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] 65 | perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] 66 | 67 | pos_idx_per_image = positive[perm1] 68 | neg_idx_per_image = negative[perm2] 69 | 70 | # create binary mask from indices 71 | pos_idx_per_image_mask = torch.zeros_like( 72 | matched_idxs_per_image, dtype=torch.uint8 73 | ) 74 | neg_idx_per_image_mask = torch.zeros_like( 75 | matched_idxs_per_image, dtype=torch.uint8 76 | ) 77 | 78 | pos_idx_per_image_mask[pos_idx_per_image] = 1 79 | neg_idx_per_image_mask[neg_idx_per_image] = 1 80 | 81 | pos_idx.append(pos_idx_per_image_mask) 82 | neg_idx.append(neg_idx_per_image_mask) 83 | 84 | return pos_idx, neg_idx 85 | 86 | 87 | @torch.jit._script_if_tracing 88 | def encode_boxes(reference_boxes, proposals, weights): 89 | # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 90 | """ 91 | Encode a set of proposals with respect to some 92 | reference boxes 93 | 94 | Arguments: 95 | reference_boxes (Tensor): reference boxes(gt) 96 | proposals (Tensor): boxes to be encoded(anchors) 97 | weights: 98 | """ 99 | 100 | # perform some unpacking to make it JIT-fusion friendly 101 | wx = weights[0] 102 | wy = weights[1] 103 | ww = weights[2] 104 | wh = weights[3] 105 | 106 | # unsqueeze() 107 | # Returns a new tensor with a dimension of size one inserted at the specified position. 108 | proposals_x1 = proposals[:, 0].unsqueeze(1) 109 | proposals_y1 = proposals[:, 1].unsqueeze(1) 110 | proposals_x2 = proposals[:, 2].unsqueeze(1) 111 | proposals_y2 = proposals[:, 3].unsqueeze(1) 112 | 113 | reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1) 114 | reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1) 115 | reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1) 116 | reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1) 117 | 118 | # implementation starts here 119 | # parse widths and heights 120 | ex_widths = proposals_x2 - proposals_x1 121 | ex_heights = proposals_y2 - proposals_y1 122 | # parse coordinate of center point 123 | ex_ctr_x = proposals_x1 + 0.5 * ex_widths 124 | ex_ctr_y = proposals_y1 + 0.5 * ex_heights 125 | 126 | gt_widths = reference_boxes_x2 - reference_boxes_x1 127 | gt_heights = reference_boxes_y2 - reference_boxes_y1 128 | gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths 129 | gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights 130 | 131 | targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths 132 | targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights 133 | targets_dw = ww * torch.log(gt_widths / ex_widths) 134 | targets_dh = wh * torch.log(gt_heights / ex_heights) 135 | 136 | targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) 137 | return targets 138 | 139 | 140 | class BoxCoder(object): 141 | """ 142 | This class encodes and decodes a set of bounding boxes into 143 | the representation used for training the regressors. 144 | """ 145 | 146 | def __init__(self, weights, bbox_xform_clip=math.log(1000. 
/ 16)):
147 |         # type: (Tuple[float, float, float, float], float) -> None
148 |         """
149 |         Arguments:
150 |             weights (4-element tuple)
151 |             bbox_xform_clip (float)
152 |         """
153 |         self.weights = weights
154 |         self.bbox_xform_clip = bbox_xform_clip
155 | 
156 |     def encode(self, reference_boxes, proposals):
157 |         # type: (List[Tensor], List[Tensor]) -> List[Tensor]
158 |         """
159 |         Compute the regression parameters that map each anchor/proposal to its matched gt box
160 |         Args:
161 |             reference_boxes: List[Tensor] the gt_boxes matched to each proposal/anchor
162 |             proposals: List[Tensor] anchors/proposals
163 | 
164 |         Returns: regression parameters
165 | 
166 |         """
167 |         # Count the anchors of each image so the concatenated result can be split apart again afterwards
168 |         # reference_boxes and proposals share the same structure
169 |         boxes_per_image = [len(b) for b in reference_boxes]
170 |         reference_boxes = torch.cat(reference_boxes, dim=0)
171 |         proposals = torch.cat(proposals, dim=0)
172 | 
173 |         # targets_dx, targets_dy, targets_dw, targets_dh
174 |         targets = self.encode_single(reference_boxes, proposals)
175 |         return targets.split(boxes_per_image, 0)
176 | 
177 |     def encode_single(self, reference_boxes, proposals):
178 |         """
179 |         Encode a set of proposals with respect to some
180 |         reference boxes
181 | 
182 |         Arguments:
183 |             reference_boxes (Tensor): reference boxes
184 |             proposals (Tensor): boxes to be encoded
185 |         """
186 |         dtype = reference_boxes.dtype
187 |         device = reference_boxes.device
188 |         weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
189 |         targets = encode_boxes(reference_boxes, proposals, weights)
190 | 
191 |         return targets
192 | 
193 |     def decode(self, rel_codes, boxes):
194 |         # type: (Tensor, List[Tensor]) -> Tensor
195 |         """
196 | 
197 |         Args:
198 |             rel_codes: bbox regression parameters
199 |             boxes: anchors/proposals
200 | 
201 |         Returns:
202 | 
203 |         """
204 |         assert isinstance(boxes, (list, tuple))
205 |         assert isinstance(rel_codes, torch.Tensor)
206 |         boxes_per_image = [b.size(0) for b in boxes]
207 |         concat_boxes = torch.cat(boxes, dim=0)
208 | 
209 |         box_sum = 0
210 |         for val in boxes_per_image:
211 |             box_sum += val
212 | 
213 |         # Apply the predicted bbox regression parameters to the corresponding anchors to obtain the predicted box coordinates
214 |         pred_boxes = self.decode_single(
215 |             rel_codes, concat_boxes
216 |         )
217 | 
218 |         # Guard against pred_boxes being empty, which would make the reshape fail
219 |         if box_sum > 0:
220 |             pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
221 | 
222 |         return pred_boxes
223 | 
224 |     def decode_single(self, rel_codes, boxes):
225 |         """
226 |         From a set of original boxes and encoded relative box offsets,
227 |         get the decoded boxes.
228 | 
229 |         Arguments:
230 |             rel_codes (Tensor): encoded boxes (bbox regression parameters)
231 |             boxes (Tensor): reference boxes (anchors/proposals)
232 |         """
233 |         boxes = boxes.to(rel_codes.dtype)
234 | 
235 |         # xmin, ymin, xmax, ymax
236 |         widths = boxes[:, 2] - boxes[:, 0]   # anchor/proposal widths
237 |         heights = boxes[:, 3] - boxes[:, 1]  # anchor/proposal heights
238 |         ctr_x = boxes[:, 0] + 0.5 * widths   # anchor/proposal center x
239 |         ctr_y = boxes[:, 1] + 0.5 * heights  # anchor/proposal center y
240 | 
241 |         wx, wy, ww, wh = self.weights  # [1, 1, 1, 1] in the RPN, [10, 10, 5, 5] in Fast R-CNN
242 |         dx = rel_codes[:, 0::4] / wx   # predicted center-x regression parameters
243 |         dy = rel_codes[:, 1::4] / wy   # predicted center-y regression parameters
244 |         dw = rel_codes[:, 2::4] / ww   # predicted width regression parameters
245 |         dh = rel_codes[:, 3::4] / wh   # predicted height regression parameters
246 | 
247 |         # limit max value, prevent sending too large values into torch.exp()
248 |         # self.bbox_xform_clip = math.log(1000. / 16), i.e. about 4.135
249 |         dw = torch.clamp(dw, max=self.bbox_xform_clip)
250 |         dh = torch.clamp(dh, max=self.bbox_xform_clip)
251 | 
252 |         pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
253 |         pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
254 |         pred_w = torch.exp(dw) * widths[:, None]
255 |         pred_h = torch.exp(dh) * heights[:, None]
256 | 
257 |         # xmin
258 |         pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
259 |         # ymin
260 |         pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
261 |         # xmax
262 |         pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
263 |         # ymax
264 |         pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
265 | 
266 |         pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
267 |         return pred_boxes
268 | 
269 | 
270 | class Matcher(object):
271 |     BELOW_LOW_THRESHOLD = -1
272 |     BETWEEN_THRESHOLDS = -2
273 | 
274 |     __annotations__ = {
275 |         'BELOW_LOW_THRESHOLD': int,
276 |         'BETWEEN_THRESHOLDS': int,
277 |     }
278 | 
279 |     def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
280 |         # type: (float, float, bool) -> None
281 |         """
282 |         Args:
283 |             high_threshold (float): quality values greater than or equal to
284 |                 this value are candidate matches.
285 |             low_threshold (float): a lower quality threshold used to stratify
286 |                 matches into three levels:
287 |                 1) matches >= high_threshold
288 |                 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
289 |                 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
290 |             allow_low_quality_matches (bool): if True, produce additional matches
291 |                 for predictions that have only low-quality match candidates. See
292 |                 set_low_quality_matches_ for more details.
293 |         """
294 |         self.BELOW_LOW_THRESHOLD = -1
295 |         self.BETWEEN_THRESHOLDS = -2
296 |         assert low_threshold <= high_threshold
297 |         self.high_threshold = high_threshold  # 0.7
298 |         self.low_threshold = low_threshold    # 0.3
299 |         self.allow_low_quality_matches = allow_low_quality_matches
300 | 
301 |     def __call__(self, match_quality_matrix):
302 |         """
303 |         Compute, for every anchor, the maximum IoU over all gt boxes and record the index of the matched gt;
304 |         anchors with IoU below low_threshold get index -1, anchors with IoU in [low_threshold, high_threshold) get index -2.
305 |         Args:
306 |             match_quality_matrix (Tensor[float]): an MxN tensor, containing the
307 |                 pairwise quality between M ground-truth elements and N predicted elements.
308 | 
309 |         Returns:
310 |             matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
311 |                 [0, M - 1] or a negative value indicating that prediction i could not
312 |                 be matched.
313 |         """
314 |         if match_quality_matrix.numel() == 0:
315 |             # empty targets or proposals not supported during training
316 |             if match_quality_matrix.shape[0] == 0:
317 |                 raise ValueError(
318 |                     "No ground-truth boxes available for one of the images "
319 |                     "during training")
320 |             else:
321 |                 raise ValueError(
322 |                     "No proposal boxes available for one of the images "
323 |                     "during training")
324 | 
325 |         # match_quality_matrix is M (gt) x N (predicted)
326 |         # Max over gt elements (dim 0) to find best gt candidate for each prediction
327 |         # every column holds the IoU of one anchor against all gt boxes;
328 |         # matched_vals is the per-column maximum (the best IoU of each anchor)
329 |         # and matches is the index of the gt box that attains it
330 |         matched_vals, matches = match_quality_matrix.max(dim=0)  # the dimension to reduce.
331 |         if self.allow_low_quality_matches:
332 |             all_matches = matches.clone()
333 |         else:
334 |             all_matches = None
335 | 
336 |         # Assign candidate matches with low quality to negative (unassigned) values
337 |         # indices whose IoU is below low_threshold
338 |         below_low_threshold = matched_vals < self.low_threshold
339 |         # indices whose IoU lies between low_threshold and high_threshold
340 |         between_thresholds = (matched_vals >= self.low_threshold) & (
341 |             matched_vals < self.high_threshold
342 |         )
343 |         # matches with IoU below low_threshold are set to -1
344 |         matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD  # -1
345 | 
346 |         # matches with IoU in [low_threshold, high_threshold) are set to -2
347 |         matches[between_thresholds] = self.BETWEEN_THRESHOLDS  # -2
348 | 
349 |         if self.allow_low_quality_matches:
350 |             assert all_matches is not None
351 |             self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
352 | 
353 |         return matches
354 | 
355 |     def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
356 |         """
357 |         Produce additional matches for predictions that have only low-quality matches.
358 |         Specifically, for each ground-truth find the set of predictions that have
359 |         maximum overlap with it (including ties); for each prediction in that set, if
360 |         it is unmatched, then match it to the ground-truth with which it has the highest
361 |         quality value.
362 |         """
363 |         # For each gt, find the prediction with which it has highest quality
364 |         # For every gt box, find the anchor with the largest IoU;
365 |         # highest_quality_foreach_gt holds that maximum IoU value
366 |         highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)  # the dimension to reduce.
367 | 
368 |         # Find highest quality match available, even if it is low, including ties
369 |         # For every gt box, find the indices of the anchors with that maximum IoU; one gt may be tied with several anchors
370 |         # gt_pred_pairs_of_highest_quality = torch.nonzero(
371 |         #     match_quality_matrix == highest_quality_foreach_gt[:, None]
372 |         # )
373 |         gt_pred_pairs_of_highest_quality = torch.where(
374 |             torch.eq(match_quality_matrix, highest_quality_foreach_gt[:, None])
375 |         )
376 |         # Example gt_pred_pairs_of_highest_quality:
377 |         # tensor([[    0, 39796],
378 |         #         [    1, 32055],
379 |         #         [    1, 32070],
380 |         #         [    2, 39190],
381 |         #         [    2, 40255],
382 |         #         [    3, 40390],
383 |         #         [    3, 41455],
384 |         #         [    4, 45470],
385 |         #         [    5, 45325],
386 |         #         [    5, 46390]])
387 |         # Each row is a (gt index, prediction index)
388 |         # Note how gt items 1, 2, 3, and 5 each have two ties
389 | 
390 |         # gt_pred_pairs_of_highest_quality[:, 0] would be the corresponding gt index (not needed here)
391 |         # pre_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
392 |         pre_inds_to_update = gt_pred_pairs_of_highest_quality[1]
393 |         # Keep the index of the best-matching gt for these anchors, even if the IoU is below the configured threshold
394 |         matches[pre_inds_to_update] = all_matches[pre_inds_to_update]
395 | 
396 | 
397 | def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
398 |     """
399 |     very similar to the smooth_l1_loss from pytorch, but with
400 |     the extra beta parameter
401 |     """
402 |     n = torch.abs(input - target)
403 |     # cond = n < beta
404 |     cond = torch.lt(n, beta)
405 |     loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
406 |     if size_average:
407 |         return loss.mean()
408 |     return loss.sum()
409 | 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/faster_rcnn_framework.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import OrderedDict
3 | from typing import Tuple, List, Dict, Optional, Union
4 | 
5 | import torch
6 | from torch import nn, Tensor
7 | import torch.nn.functional as F
8 | from torchvision.ops import MultiScaleRoIAlign
9 | 
10 | from .roi_head import RoIHeads
11 | from .transform import GeneralizedRCNNTransform
12 | from .rpn_function import AnchorsGenerator, RPNHead, RegionProposalNetwork
13 | 
14 | 
15 | class FasterRCNNBase(nn.Module):
16 |     """
17 |     Main class for Generalized R-CNN.
18 | 
19 |     Arguments:
20 |         backbone (nn.Module):
21 |         rpn (nn.Module):
22 |         roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
23 |             detections / masks from it.
24 | transform (nn.Module): performs the data transformation from the inputs to feed into 25 | the model 26 | """ 27 | 28 | def __init__(self, backbone, rpn, roi_heads, transform): 29 | super(FasterRCNNBase, self).__init__() 30 | self.transform = transform 31 | self.backbone = backbone 32 | self.rpn = rpn 33 | self.roi_heads = roi_heads 34 | # used only on torchscript mode 35 | self._has_warned = False 36 | 37 | @torch.jit.unused 38 | def eager_outputs(self, losses, detections): 39 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] 40 | if self.training: 41 | return losses 42 | 43 | return detections 44 | 45 | def forward(self, images, targets=None): 46 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 47 | """ 48 | Arguments: 49 | images (list[Tensor]): images to be processed 50 | targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) 51 | 52 | Returns: 53 | result (list[BoxList] or dict[Tensor]): the output from the model. 54 | During training, it returns a dict[Tensor] which contains the losses. 55 | During testing, it returns list[BoxList] contains additional fields 56 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 57 | 58 | """ 59 | if self.training and targets is None: 60 | raise ValueError("In training mode, targets should be passed") 61 | 62 | if self.training: 63 | assert targets is not None 64 | for target in targets: # 进一步判断传入的target的boxes参数是否符合规定 65 | boxes = target["boxes"] 66 | if isinstance(boxes, torch.Tensor): 67 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4: 68 | raise ValueError("Expected target boxes to be a tensor" 69 | "of shape [N, 4], got {:}.".format( 70 | boxes.shape)) 71 | else: 72 | raise ValueError("Expected target boxes to be of type " 73 | "Tensor, got {:}.".format(type(boxes))) 74 | 75 | original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) 76 | for img in images: 77 | val = img.shape[-2:] 78 | assert len(val) == 2 # 防止输入的是个一维向量 79 | original_image_sizes.append((val[0], val[1])) 80 | # original_image_sizes = [img.shape[-2:] for img in images] 81 | 82 | images, targets = self.transform(images, targets) # 对图像进行预处理 83 | 84 | # print(images.tensors.shape) 85 | features = self.backbone(images.tensors) # 将图像输入backbone得到特征图 86 | if isinstance(features, torch.Tensor): # 若只在一层特征层上预测,将feature放入有序字典中,并编号为‘0’ 87 | features = OrderedDict([('0', features)]) # 若在多层特征层上预测,传入的就是一个有序字典 88 | 89 | # 将特征层以及标注target信息传入rpn中 90 | # proposals: List[Tensor], Tensor_shape: [num_proposals, 4], 91 | # 每个proposals是绝对坐标,且为(x1, y1, x2, y2)格式 92 | proposals, proposal_losses = self.rpn(images, features, targets) 93 | 94 | # 将rpn生成的数据以及标注target信息传入fast rcnn后半部分 95 | detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) 96 | 97 | # 对网络的预测结果进行后处理(主要将bboxes还原到原图像尺度上) 98 | detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) 99 | 100 | losses = {} 101 | losses.update(detector_losses) 102 | losses.update(proposal_losses) 103 | 104 | if torch.jit.is_scripting(): 105 | if not self._has_warned: 106 | warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") 107 | self._has_warned = True 108 | return losses, detections 109 | else: 110 | return self.eager_outputs(losses, detections) 111 | 112 | # if self.training: 113 | # return losses 114 | # 115 | # return detections 116 | 117 | 118 | class TwoMLPHead(nn.Module): 119 | 
""" 120 | Standard heads for FPN-based models 121 | 122 | Arguments: 123 | in_channels (int): number of input channels 124 | representation_size (int): size of the intermediate representation 125 | """ 126 | 127 | def __init__(self, in_channels, representation_size): 128 | super(TwoMLPHead, self).__init__() 129 | 130 | self.fc6 = nn.Linear(in_channels, representation_size) 131 | self.fc7 = nn.Linear(representation_size, representation_size) 132 | 133 | def forward(self, x): 134 | x = x.flatten(start_dim=1) 135 | 136 | x = F.relu(self.fc6(x)) 137 | x = F.relu(self.fc7(x)) 138 | 139 | return x 140 | 141 | 142 | class FastRCNNPredictor(nn.Module): 143 | """ 144 | Standard classification + bounding box regression layers 145 | for Fast R-CNN. 146 | 147 | Arguments: 148 | in_channels (int): number of input channels 149 | num_classes (int): number of output classes (including background) 150 | """ 151 | 152 | def __init__(self, in_channels, num_classes): 153 | super(FastRCNNPredictor, self).__init__() 154 | self.cls_score = nn.Linear(in_channels, num_classes) 155 | self.bbox_pred = nn.Linear(in_channels, num_classes * 4) 156 | 157 | def forward(self, x): 158 | if x.dim() == 4: 159 | assert list(x.shape[2:]) == [1, 1] 160 | x = x.flatten(start_dim=1) 161 | scores = self.cls_score(x) 162 | bbox_deltas = self.bbox_pred(x) 163 | 164 | return scores, bbox_deltas 165 | 166 | 167 | class FasterRCNN(FasterRCNNBase): 168 | """ 169 | Implements Faster R-CNN. 170 | 171 | The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each 172 | image, and should be in 0-1 range. Different images can have different sizes. 173 | 174 | The behavior of the model changes depending if it is in training or evaluation mode. 175 | 176 | During training, the model expects both the input tensors, as well as a targets (list of dictionary), 177 | containing: 178 | - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values 179 | between 0 and H and 0 and W 180 | - labels (Int64Tensor[N]): the class label for each ground-truth box 181 | 182 | The model returns a Dict[Tensor] during training, containing the classification and regression 183 | losses for both the RPN and the R-CNN. 184 | 185 | During inference, the model requires only the input tensors, and returns the post-processed 186 | predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as 187 | follows: 188 | - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between 189 | 0 and H and 0 and W 190 | - labels (Int64Tensor[N]): the predicted labels for each image 191 | - scores (Tensor[N]): the scores or each prediction 192 | 193 | Arguments: 194 | backbone (nn.Module): the network used to compute the features for the model. 195 | It should contain a out_channels attribute, which indicates the number of output 196 | channels that each feature map has (and it should be the same for all feature maps). 197 | The backbone should return a single Tensor or and OrderedDict[Tensor]. 198 | num_classes (int): number of output classes of the model (including the background). 199 | If box_predictor is specified, num_classes should be None. 200 | min_size (int): minimum size of the image to be rescaled before feeding it to the backbone 201 | max_size (int): maximum size of the image to be rescaled before feeding it to the backbone 202 | image_mean (Tuple[float, float, float]): mean values used for input normalization. 
203 | They are generally the mean values of the dataset on which the backbone has been trained 204 | on 205 | image_std (Tuple[float, float, float]): std values used for input normalization. 206 | They are generally the std values of the dataset on which the backbone has been trained on 207 | rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature 208 | maps. 209 | rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN 210 | rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training 211 | rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing 212 | rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training 213 | rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing 214 | rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals 215 | rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be 216 | considered as positive during training of the RPN. 217 | rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be 218 | considered as negative during training of the RPN. 219 | rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN 220 | for computing the loss 221 | rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training 222 | of the RPN 223 | rpn_score_thresh (float): during inference, only return proposals with a classification score 224 | greater than rpn_score_thresh 225 | box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 226 | the locations indicated by the bounding boxes 227 | box_head (nn.Module): module that takes the cropped feature maps as input 228 | box_predictor (nn.Module): module that takes the output of box_head and returns the 229 | classification logits and box regression deltas. 230 | box_score_thresh (float): during inference, only return proposals with a classification score 231 | greater than box_score_thresh 232 | box_nms_thresh (float): NMS threshold for the prediction head. Used during inference 233 | box_detections_per_img (int): maximum number of detections per image, for all classes. 
234 | box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be 235 | considered as positive during training of the classification head 236 | box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be 237 | considered as negative during training of the classification head 238 | box_batch_size_per_image (int): number of proposals that are sampled during training of the 239 | classification head 240 | box_positive_fraction (float): proportion of positive proposals in a mini-batch during training 241 | of the classification head 242 | bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the 243 | bounding boxes 244 | 245 | """ 246 | 247 | def __init__(self, backbone, num_classes=None, 248 | # transform parameter 249 | min_size=1024, max_size=1024, # 预处理resize时限制的最小尺寸与最大尺寸 250 | image_mean=None, image_std=None, # 预处理normalize时使用的均值和方差 251 | # RPN parameters 252 | rpn_anchor_generator=None, rpn_head=None, 253 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, # rpn中在nms处理前保留的proposal数(根据score) 254 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, # rpn中在nms处理后保留的proposal数 255 | rpn_nms_thresh=0.7, # rpn中进行nms处理时使用的iou阈值 256 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, # rpn计算损失时,采集正负样本设置的阈值 257 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, # rpn计算损失时采样的样本数,以及正样本占总样本的比例 258 | rpn_score_thresh=0.0, 259 | # Box parameters 260 | box_roi_pool=None, box_head=None, box_predictor=None, 261 | # 移除低目标概率 fast rcnn中进行nms处理的阈值 对预测结果根据score排序取前100个目标 262 | box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, 263 | box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, # fast rcnn计算误差时,采集正负样本设置的阈值 264 | box_batch_size_per_image=512, box_positive_fraction=0.25, # fast rcnn计算误差时采样的样本数,以及正样本占所有样本的比例 265 | bbox_reg_weights=None): 266 | if not hasattr(backbone, "out_channels"): 267 | raise ValueError( 268 | "backbone should contain an attribute out_channels" 269 | "specifying the number of output channels (assumed to be the" 270 | "same for all the levels" 271 | ) 272 | 273 | assert isinstance(rpn_anchor_generator, (AnchorsGenerator, type(None))) 274 | assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) 275 | 276 | if num_classes is not None: 277 | if box_predictor is not None: 278 | raise ValueError("num_classes should be None when box_predictor " 279 | "is specified") 280 | else: 281 | if box_predictor is None: 282 | raise ValueError("num_classes should not be None when box_predictor " 283 | "is not specified") 284 | 285 | # 预测特征层的channels 286 | out_channels = backbone.out_channels 287 | 288 | # 若anchor生成器为空,则自动生成针对resnet50_fpn的anchor生成器 289 | if rpn_anchor_generator is None: 290 | anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) 291 | aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) 292 | rpn_anchor_generator = AnchorsGenerator( 293 | anchor_sizes, aspect_ratios 294 | ) 295 | 296 | # 生成RPN通过滑动窗口预测网络部分 297 | if rpn_head is None: 298 | rpn_head = RPNHead( 299 | out_channels, rpn_anchor_generator.num_anchors_per_location()[0] 300 | ) 301 | 302 | # 默认rpn_pre_nms_top_n_train = 2000, rpn_pre_nms_top_n_test = 1000, 303 | # 默认rpn_post_nms_top_n_train = 2000, rpn_post_nms_top_n_test = 1000, 304 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) 305 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) 306 | 307 | # 定义整个RPN框架 308 | rpn = 
RegionProposalNetwork(
309 |             rpn_anchor_generator, rpn_head,
310 |             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
311 |             rpn_batch_size_per_image, rpn_positive_fraction,
312 |             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
313 |             score_thresh=rpn_score_thresh)
314 | 
315 |         # Multi-scale RoIAlign pooling
316 |         if box_roi_pool is None:
317 |             box_roi_pool = MultiScaleRoIAlign(
318 |                 featmap_names=['0', '1', '2', '3'],  # feature maps on which RoIAlign is performed
319 |                 output_size=[7, 7],
320 |                 sampling_ratio=2)
321 | 
322 |         # The two fully connected layers applied to the flattened RoI features in Fast R-CNN
323 |         if box_head is None:
324 |             resolution = box_roi_pool.output_size[0]  # 7 by default
325 |             representation_size = 1024
326 |             box_head = TwoMLPHead(
327 |                 out_channels * resolution ** 2,
328 |                 representation_size
329 |             )
330 | 
331 |         # Prediction layers applied on top of the box_head output
332 |         if box_predictor is None:
333 |             representation_size = 1024
334 |             box_predictor = FastRCNNPredictor(
335 |                 representation_size,
336 |                 num_classes)
337 | 
338 |         # Combine RoI pooling, box_head and box_predictor
339 |         roi_heads = RoIHeads(
340 |             # box
341 |             box_roi_pool, box_head, box_predictor,
342 |             box_fg_iou_thresh, box_bg_iou_thresh,  # 0.5  0.5
343 |             box_batch_size_per_image, box_positive_fraction,  # 512  0.25
344 |             bbox_reg_weights,
345 |             box_score_thresh, box_nms_thresh, box_detections_per_img)  # 0.05  0.5  100
346 | 
347 |         if image_mean is None:
348 |             image_mean = [0.485, 0.456, 0.406]
349 |         if image_std is None:
350 |             image_std = [0.229, 0.224, 0.225]
351 | 
352 |         # Normalization, resizing and batching of the input data
353 |         transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
354 | 
355 |         super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
356 | 
--------------------------------------------------------------------------------
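A minimal usage sketch for the RGB-D EfficientSam assembled by build_efficient_sam above. The import path, the "tiny" encoder arguments (encoder_patch_embed_dim=192, encoder_num_heads=3) and the point-prompt labels are assumptions for illustration; the repository's train.py and eval.py remain the authoritative entry points.

import torch
from model.build_efficient_sam import build_efficient_sam  # assumed import path

# Hypothetical "tiny" configuration; pass a checkpoint path to load pretrained weights.
sam = build_efficient_sam(encoder_patch_embed_dim=192, encoder_num_heads=3, checkpoint=None).eval()

rgb = torch.rand(1, 3, 480, 640)    # [B, 3, H, W]
depth = torch.rand(1, 1, 480, 640)  # [B, 1, H, W]

# One query prompted by two points (label 1 = foreground, following the usual SAM convention);
# predict_masks pads prompts with fewer than decoder_max_num_input_points (6) points internally.
points = torch.tensor([[[[200.0, 150.0], [320.0, 240.0]]]])  # [B, num_queries, num_pts, 2]
labels = torch.ones(1, 1, 2)                                 # [B, num_queries, num_pts]

with torch.no_grad():
    masks, iou_preds = sam(rgb, depth, points, labels)  # masks: [B, num_queries, K, H, W]

# Keep the candidate mask with the highest predicted IoU for the first query.
best_mask = masks[0, 0, iou_preds[0, 0].argmax()] > 0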
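Likewise, a small sketch of the box encode/decode round trip implemented by det_utils.BoxCoder, which supplies the regression targets used by the RPN and the Fast R-CNN head; the import path and the example boxes are illustrative only.

import torch
from FasterRCNN.det_utils import BoxCoder  # assumed import path

# Fast R-CNN style weights; the RPN would use (1., 1., 1., 1.), as noted in decode_single.
coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))

# One image with two proposals and their matched ground-truth boxes, in (x1, y1, x2, y2) format.
proposals = [torch.tensor([[10.0, 10.0, 110.0, 210.0],
                           [50.0, 40.0, 150.0, 140.0]])]
gt_boxes = [torch.tensor([[12.0, 8.0, 118.0, 208.0],
                          [48.0, 42.0, 152.0, 150.0]])]

# encode() produces (dx, dy, dw, dh) regression targets per proposal;
# decode() applies such parameters back onto the proposals.
targets = coder.encode(gt_boxes, proposals)
decoded = coder.decode(torch.cat(targets, dim=0), proposals)

# The round trip reproduces the ground-truth boxes up to floating-point error.
assert torch.allclose(decoded.squeeze(1), torch.cat(gt_boxes, dim=0), atol=1e-3)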