├── PlaneSAM ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── loss.cpython-39.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── box_ops.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── eval_tools.cpython-39.pyc │ │ ├── make_prompt.cpython-39.pyc │ │ ├── postprocess.cpython-39.pyc │ │ └── visual_tools.cpython-39.pyc │ ├── postprocess.py │ ├── visual_tools.py │ ├── loss.py │ ├── box_ops.py │ ├── train_tools.py │ ├── eval_tools.py │ └── make_prompt.py ├── model │ ├── __pycache__ │ │ ├── mlp.cpython-37.pyc │ │ ├── mlp.cpython-39.pyc │ │ ├── mlp.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── efficient_sam.cpython-37.pyc │ │ ├── efficient_sam.cpython-39.pyc │ │ ├── efficient_sam.cpython-310.pyc │ │ ├── build_efficient_sam.cpython-310.pyc │ │ ├── build_efficient_sam.cpython-37.pyc │ │ ├── build_efficient_sam.cpython-39.pyc │ │ ├── two_way_transformer.cpython-310.pyc │ │ ├── two_way_transformer.cpython-37.pyc │ │ ├── two_way_transformer.cpython-39.pyc │ │ ├── efficient_sam_decoder.cpython-310.pyc │ │ ├── efficient_sam_decoder.cpython-37.pyc │ │ ├── efficient_sam_decoder.cpython-39.pyc │ │ ├── efficient_sam_encoder.cpython-310.pyc │ │ ├── efficient_sam_encoder.cpython-37.pyc │ │ └── efficient_sam_encoder.cpython-39.pyc │ ├── __init__.py │ ├── build_efficient_sam.py │ ├── mlp.py │ ├── two_way_transformer.py │ ├── efficient_sam_encoder.py │ ├── efficient_sam_decoder.py │ └── efficient_sam.py ├── FasterRCNN │ ├── __init__.py │ ├── __pycache__ │ │ ├── boxes.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── boxes.cpython-310.pyc │ │ ├── roi_head.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── det_utils.cpython-310.pyc │ │ ├── det_utils.cpython-39.pyc │ │ ├── image_list.cpython-39.pyc │ │ ├── roi_head.cpython-310.pyc │ │ ├── transform.cpython-310.pyc │ │ ├── transform.cpython-39.pyc │ │ ├── image_list.cpython-310.pyc │ │ ├── rpn_function.cpython-310.pyc │ │ ├── rpn_function.cpython-39.pyc │ │ ├── build_FasterRCNN.cpython-39.pyc │ │ ├── faster_rcnn_framework.cpython-310.pyc │ │ └── faster_rcnn_framework.cpython-39.pyc │ ├── image_list.py │ ├── build_FasterRCNN.py │ ├── boxes.py │ ├── transform.py │ ├── det_utils.py │ └── faster_rcnn_framework.py ├── __pycache__ │ ├── Nyuv2Dataset.cpython-39.pyc │ ├── PlaneDataset.cpython-37.pyc │ ├── PlaneDataset.cpython-39.pyc │ ├── S2D3DSDataset.cpython-39.pyc │ ├── RecordReaderAll.cpython-36.pyc │ └── RecordReaderAll.cpython-39.pyc ├── backbone │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── vgg_model.cpython-310.pyc │ │ ├── vgg_model.cpython-39.pyc │ │ ├── mobilenetv2_model.cpython-39.pyc │ │ ├── mobilenetv2_model.cpython-310.pyc │ │ ├── resnet101_fpn_model.cpython-39.pyc │ │ ├── resnet50_fpn_model.cpython-310.pyc │ │ ├── resnet50_fpn_model.cpython-39.pyc │ │ ├── feature_pyramid_network.cpython-310.pyc │ │ └── feature_pyramid_network.cpython-39.pyc │ ├── __init__.py │ ├── vgg_model.py │ ├── mobilenetv2_model.py │ ├── resnet101_fpn_model.py │ └── feature_pyramid_network.py ├── requirements.txt ├── README.md ├── S2D3DSDataset.py ├── Nyuv2Dataset.py ├── eval.py ├── PlaneDataset.py ├── train.py └── visual.py └── README.md /PlaneSAM/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PlaneSAM/model/__pycache__/mlp.cpython-37.pyc: 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/__init__.py:
--------------------------------------------------------------------------------
1 | from .faster_rcnn_framework import FasterRCNN, FastRCNNPredictor
2 | from .rpn_function import AnchorsGenerator
3 | 
--------------------------------------------------------------------------------
/PlaneSAM/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | from .build_efficient_sam import (
5 |     build_efficient_sam_vitt,
6 |     build_efficient_sam_vits,
7 | )
8 | 
--------------------------------------------------------------------------------
/PlaneSAM/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet101_fpn_model import resnet101_fpn_backbone
2 | from .mobilenetv2_model import MobileNetV2
3 | from .vgg_model import vgg
4 | from .feature_pyramid_network import LastLevelMaxPool, BackboneWithFPN
5 | 
--------------------------------------------------------------------------------
/PlaneSAM/model/build_efficient_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .efficient_sam import build_efficient_sam
8 | 
9 | def build_efficient_sam_vitt():
10 |     return build_efficient_sam(
11 |         encoder_patch_embed_dim=192,
12 |         encoder_num_heads=3,
13 |         checkpoint=None,
14 |     )
15 | 
16 | 
17 | def build_efficient_sam_vits():
18 |     return build_efficient_sam(
19 |         encoder_patch_embed_dim=384,
20 |         encoder_num_heads=6,
21 |         checkpoint=None,
22 |     )
23 | 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/image_list.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from torch import Tensor
3 | 
4 | 
5 | class ImageList(object):
6 |     """
7 |     Structure that holds a list of images (of possibly
8 |     varying sizes) as a single tensor.
9 |     This works by padding the images to the same size,
10 |     and storing in a field the original sizes of each image
11 |     """
12 | 
13 |     def __init__(self, tensors, image_sizes):
14 |         # type: (Tensor, List[Tuple[int, int]]) -> None
15 |         """
16 |         Arguments:
17 |             tensors (tensor): the image data after padding
18 |             image_sizes (list[tuple[int, int]]): the image sizes before padding
19 |         """
20 |         self.tensors = tensors
21 |         self.image_sizes = image_sizes
22 | 
23 |     def to(self, device):
24 |         # type: (Device) -> ImageList # noqa
25 |         cast_tensor = self.tensors.to(device)
26 |         return ImageList(cast_tensor, self.image_sizes)
27 | 
28 | 
--------------------------------------------------------------------------------
/PlaneSAM/model/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 | 
3 | from torch import nn
4 | 
5 | 
6 | # Lightly adapted from
7 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
8 | class MLPBlock(nn.Module):
9 |     def __init__(
10 |         self,
11 |         input_dim: int,
12 |         hidden_dim: int,
13 |         output_dim: int,
14 |         num_layers: int,
15 |         act: Type[nn.Module],
16 |     ) -> None:
17 |         super().__init__()
18 |         self.num_layers = num_layers
19 |         h = [hidden_dim] * (num_layers - 1)
20 |         self.layers = nn.ModuleList(
21 |             nn.Sequential(nn.Linear(n, k), act())
22 |             for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
23 |         )
24 |         self.fc = nn.Linear(hidden_dim, output_dim)
25 | 
26 |     def forward(self, x):
27 |         for layer in self.layers:
28 |             x = layer(x)
29 |         return self.fc(x)
30 | 
--------------------------------------------------------------------------------
/PlaneSAM/FasterRCNN/build_FasterRCNN.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from FasterRCNN import FasterRCNN, FastRCNNPredictor
4 | from backbone import resnet101_fpn_backbone
5 | 
6 | def create_model(num_classes, load_pretrain_weights=False):
7 |     backbone = resnet101_fpn_backbone(pretrain_path="",
8 |                                       norm_layer=torch.nn.BatchNorm2d,
9 |                                       trainable_layers=5)
10 |     # When training on your own dataset, do not change the 91 here; change the num_classes argument passed in instead.
11 |     model = FasterRCNN(backbone=backbone, num_classes=91)
12 | 
13 |     if load_pretrain_weights:
14 |         weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth", map_location='cpu')
15 |         missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
16 |         if len(missing_keys) != 0 or len(unexpected_keys) != 0:
17 |             print("missing_keys: ", missing_keys)
18 |             print("unexpected_keys: ", unexpected_keys)
19 | 
20 |     # get number of input features for the classifier
21 |     in_features = model.roi_heads.box_predictor.cls_score.in_features
22 |     # replace the pre-trained head with a new one
23 |     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
24 | 
25 |     return model
--------------------------------------------------------------------------------
/PlaneSAM/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | addict==2.4.0
3 | certifi==2024.7.4
4 | charset-normalizer==3.3.2
5 | cloudpickle==3.0.0
6 | contourpy==1.2.1
7 | cycler==0.12.1
8 | Cython==3.0.11
9 | filelock==3.15.4
10 | fonttools==4.53.1
11 | fsspec==2024.6.1
12 | fvcore==0.1.5.post20221221
13 | grpcio==1.65.1
14 | huggingface-hub==0.24.5
15 | idna==3.7
16 | imageio==2.34.2
17 | importlib_metadata==8.2.0
18 | importlib_resources==6.4.0
19 | iopath==0.1.10
20 | joblib==1.4.2
21 | kiwisolver==1.4.5
22 | lazy_loader==0.4
23 | Markdown==3.6
24 | MarkupSafe==2.1.5
25 | matplotlib==3.9.1
26 | MultiScaleDeformableAttention==1.0
27 | networkx==3.2.1
28 | numpy==1.26.4
29 | opencv-python==4.10.0.84
30 | packaging==24.1
31 | pillow==10.4.0
32 | platformdirs==4.2.2
33 | portalocker==3.1.1
34 | protobuf==4.25.4
35 | pycocotools==2.0.8
36 | pyparsing==3.1.2
37 | python-dateutil==2.9.0.post0
38 | PyYAML==6.0.2
39 | requests==2.32.3
40 | safetensors==0.4.4
41 | scikit-image==0.24.0
42 | scikit-learn==1.5.1
43 | scipy==1.13.1
44 | six==1.16.0
45 | submitit==1.5.1
46 | tabulate==0.9.0
47 | tensorboard==2.17.0
48 | tensorboard-data-server==0.7.2
49 | termcolor==2.4.0
50 | threadpoolctl==3.5.0
51 | tifffile==2024.7.24
52 | timm==1.0.8
53 | tomli==2.0.1
54 | torch==1.12.1+cu116
55 | torchvision==0.13.1+cu116
56 | tqdm==4.66.4
57 | typing_extensions==4.12.2
58 | urllib3==2.2.2
59 | Werkzeug==3.0.3
60 | yacs==0.1.8
61 | yapf==0.40.2
62 | zipp==3.19.2
63 | 
--------------------------------------------------------------------------------
/PlaneSAM/utils/postprocess.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import utils.box_ops as box_ops
5 | 
6 | class PostProcess(nn.Module):
7 |     """ This module converts the model's output into the format expected by the coco api"""
8 |     @torch.no_grad()
9 |     def forward(self, outputs, target_sizes):
10 |         """ Perform the computation
11 |         Parameters:
12 |             outputs: raw outputs of the model
13 |             target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
14 |                           For evaluation, this must be the original image size (before any data augmentation)
15 |                           For visualization, this should be the image size after data augmentation, but before padding
16 |         """
17 |         out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
18 | 
19 |         assert len(out_logits) == len(target_sizes)
20 |         assert target_sizes.shape[1] == 2
21 | 
22 |         prob = F.softmax(out_logits, -1)
23 |         scores, labels = prob[..., :-1].max(-1)
24 | 
25 |         # convert to [x0, y0, w, h] format
26 |         boxes = box_ops.box_cxcywh_to_xywh(out_bbox)
27 |         # and from relative [0, 1] to absolute [0, height] coordinates
28 |         img_h, img_w = target_sizes.unbind(1)
29 |         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
30 |         boxes = boxes * scale_fct[:, None, :]
31 | 
32 |         results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
33 | 
34 |         return results
--------------------------------------------------------------------------------
/PlaneSAM/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal plane instance segmentation with the Segment Anything Model
2 | 
3 | 
4 | 
5 | ## Getting Started
6 | Build the PyTorch environment:
7 | ```bash
8 | conda create -n PlaneSAM python=3.9.16
9 | conda activate PlaneSAM
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | ## Data Preparation
14 | 
15 | We train and test our network using the same plane dataset as [PlaneTR](https://github.com/IceTTTb/PlaneTR3D).
16 | You can access the dataset from [here](https://pan.baidu.com/s/1pyx-Ou3SLq7XG5NIMqC2cQ?pwd=in3b).
17 | 
18 | ## Training
19 | 
20 | Our training process consists of two steps:
21 | - First, we pretrain on a large-scale RGB-D dataset. The pretrained weights can be obtained from [here](https://pan.baidu.com/s/1NarX09MpkDDsBr7WWI0mzw?pwd=pvkj) and placed in the weights directory.
22 | - Second, load the pretrained weights into the network and run the train.py script. The trained weights can be obtained from [here](https://pan.baidu.com/s/1O4zygzKL13obNMAB2kuxkg?pwd=nrwj).
23 | 
24 | ## Evaluation
25 | 
26 | During the evaluation, we use Faster R-CNN as the plane detector. The trained weights can be obtained from [here](https://pan.baidu.com/s/1uO1pqs2B4R5IPKQgU0fPTg?pwd=26jr) and placed in the weights directory. The unseen test dataset can be obtained from [here](https://pan.baidu.com/s/1ywNjTRCzXfuxb2VHPGTGzg?pwd=9qcm).
27 | To evaluate the plane segmentation capabilities of PlaneSAM, please run the eval.py script.
28 | 
29 | ## Acknowledgements
30 | This code is based on the [EfficientSAM](https://github.com/yformer/EfficientSAM) repository. We would like to acknowledge the authors for their work.
31 | 
32 | ## Additional Note
33 | 
34 | Due to the author's current busy schedule, we apologize for the possibly poor code quality. Optimizations will be made in the future. If you encounter any questions or bugs in the code, feel free to ask.
35 | 
--------------------------------------------------------------------------------
/PlaneSAM/utils/visual_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | 
4 | colors = np.array([
5 |     [255, 0, 0],
6 |     [0, 255, 0],
7 |     [0, 0, 255],
8 |     [255, 255, 0],
9 |     [255, 0, 255],
10 |     [0, 255, 255],
11 |     [255, 128, 0],
12 |     [128, 0, 255],
13 |     [0, 128, 255],
14 |     [255, 0, 128],
15 |     [128, 255, 0],
16 |     [0, 255, 128],
17 |     [255, 128, 128],
18 |     [128, 255, 128],
19 |     [128, 128, 255],
20 |     [255, 255, 128],
21 |     [255, 128, 255],
22 |     [128, 255, 255],
23 |     [192, 192, 192],
24 |     [128, 0, 0],
25 |     [0, 128, 0],
26 |     [0, 0, 128],
27 |     [128, 128, 0],
28 |     [128, 0, 128],
29 |     [0, 128, 128],
30 |     [64, 0, 0],
31 |     [0, 64, 0],
32 |     [0, 0, 64],
33 |     [64, 64, 0],
34 |     [64, 0, 64],
35 |     [0, 64, 64],
36 |     [0, 0, 0]
37 | ], dtype=np.uint8)
38 | 
39 | def map_masks_to_colors(masks):
40 |     """
41 |     :param masks: [num_planes, H, W]
42 |     :return: rgb_image[H, W, 3]
43 |     """
44 |     num_masks = len(masks)
45 |     # each mask index is assigned a color from the fixed palette above
46 | 
47 |     # create an empty RGB image
48 |     height, width = masks[0].shape
49 |     rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
50 | 
51 |     for i, mask in enumerate(masks):
52 |         color = colors[i]
53 |         # mask_rgb = np.zeros((height, width, 3), dtype=np.uint8)
54 |         # mask_rgb[mask > 0] = color
55 |         # rgb_image += mask_rgb
56 |         rgb_image[mask > 0] = color
57 | 
58 |     return rgb_image
59 | 
60 | 
61 | def draw_bounding_boxes(image, boxes):
62 |     """
63 |     Draw bounding boxes on an image (using OpenCV).
64 | 
65 |     :param image: RGB image (ndarray, shape [H, W, 3]).
66 |     :param boxes: bounding boxes (ndarray, shape [B, 4], each row [xmin, ymin, xmax, ymax]).
67 |     :return: the image with the bounding boxes drawn on it.
68 |     """
69 |     image_copy = image.copy()
70 |     for box in boxes:
71 |         xmin, ymin, xmax, ymax = map(int, box)
72 |         cv2.rectangle(image_copy, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=2)  # green box
73 |     return image_copy
--------------------------------------------------------------------------------
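The two helpers above are self-contained, so a minimal sketch of how they might be combined is given below. This snippet is not part of the repository: the masks are synthetic placeholders, and the imports assume the script is run from the `PlaneSAM` directory (the same convention as the repository's own `import utils.box_ops`).

```python
import numpy as np
import cv2
import torch

from utils.visual_tools import map_masks_to_colors, draw_bounding_boxes
from utils.box_ops import masks_to_boxes  # defined in utils/box_ops.py below

# Two toy binary plane masks on a 480x640 canvas (stand-ins for predicted plane instances).
masks = np.zeros((2, 480, 640), dtype=np.uint8)
masks[0, 100:200, 50:300] = 1
masks[1, 250:400, 320:600] = 1

# Colorize the instance masks with the fixed palette and derive xyxy boxes from them.
seg_vis = map_masks_to_colors(masks)                      # [H, W, 3] uint8
boxes = masks_to_boxes(torch.from_numpy(masks)).numpy()   # [N, 4] in xyxy format
overlay = draw_bounding_boxes(seg_vis, boxes)

# OpenCV expects BGR channel order when writing to disk.
cv2.imwrite("plane_vis.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
```

In the actual evaluation flow the masks and boxes would instead come from `eval.py`, which drives the Faster R-CNN detector and the PlaneSAM decoder as described in the README above.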
4 | 5 | You may also learn about our algorithm from the preprint version of our paper on arXiv, titled “PlaneSAM: Multimodal Plane Instance Segmentation Using the Segment Anything Model” (https://arxiv.org/abs/2410.16545 ). 6 | 7 | However, the version we published in the journal Automation in Construction is more formal and provides a more complete description of the algorithm. 8 | 9 | 10 | ## 🔭 Introduction 11 | Abstract: Plane instance segmentation from RGB-D data is critical for BIM-related tasks. However, existing deep-learning methods rely on only RGB bands, overlooking depth information. To address this, PlaneSAM, a Segment-Anything-Model-based network, is proposed. It fully integrates RGB-D bands using a dual-complexity backbone: a simple branch primarily for the D band and a high-capacity branch mainly for RGB bands. This structure facilitates effective D-band learning with limited data, preserves EfficientSAM's RGB feature representations, and enables task-specific fine-tuning. To improve adaptability to RGB-D domains, a self-supervised pretraining strategy is introduced. EfficientSAM’s loss is also optimized for large-plane segmentation. Additionally, plane detection is performed using Faster R-CNN, enabling fully automatic segmentation. State-of-the-art performance is achieved on multiple datasets, with <10% additional overhead compared to EfficientSAM. The proposed dual-complexity backbone shows strong potential for transferring RGB-based foundation models to RGB+X domains in other scenarios, while the pretraining strategy is promising for other data-scarce tasks. 12 | 13 | ## 🔭 Citation 14 | If you find our work useful for your research, please consider citing our paper. 15 | Deng, Z., Yang, Z., Chen, C., Zeng, C., Meng, Y., Yang, B., 2025. Multimodal plane instance segmentation with the Segment Anything Model. Automation in Construction 180, 106541. 16 | 17 |
18 | 
--------------------------------------------------------------------------------
/PlaneSAM/backbone/vgg_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | 
4 | 
5 | class VGG(nn.Module):
6 |     def __init__(self, features, class_num=1000, init_weights=False, weights_path=None):
7 |         super(VGG, self).__init__()
8 |         self.features = features
9 |         self.classifier = nn.Sequential(
10 |             nn.Linear(512*7*7, 4096),
11 |             nn.ReLU(True),
12 |             nn.Dropout(p=0.5),
13 |             nn.Linear(4096, 4096),
14 |             nn.ReLU(True),
15 |             nn.Dropout(p=0.5),
16 |             nn.Linear(4096, class_num)
17 |         )
18 |         if init_weights and weights_path is None:
19 |             self._initialize_weights()
20 | 
21 |         if weights_path is not None:
22 |             self.load_state_dict(torch.load(weights_path))
23 | 
24 |     def forward(self, x):
25 |         # N x 3 x 224 x 224
26 |         x = self.features(x)
27 |         # N x 512 x 7 x 7
28 |         x = torch.flatten(x, start_dim=1)
29 |         # N x 512*7*7
30 |         x = self.classifier(x)
31 |         return x
32 | 
33 |     def _initialize_weights(self):
34 |         for m in self.modules():
35 |             if isinstance(m, nn.Conv2d):
36 |                 # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
37 |                 nn.init.xavier_uniform_(m.weight)
38 |                 if m.bias is not None:
39 |                     nn.init.constant_(m.bias, 0)
40 |             elif isinstance(m, nn.Linear):
41 |                 nn.init.xavier_uniform_(m.weight)
42 |                 # nn.init.normal_(m.weight, 0, 0.01)
43 |                 nn.init.constant_(m.bias, 0)
44 | 
45 | 
46 | def make_features(cfg: list):
47 |     layers = []
48 |     in_channels = 3
49 |     for v in cfg:
50 |         if v == "M":
51 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 |         else:
53 |             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 |             layers += [conv2d, nn.ReLU(True)]
55 |             in_channels = v
56 |     return nn.Sequential(*layers)
57 | 
58 | 
59 | cfgs = {
60 |     'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
61 |     'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
62 |     'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
63 |     'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
64 | }
65 | 
66 | 
67 | def vgg(model_name="vgg16", weights_path=None):
68 |     assert model_name in cfgs, "Warning: model name {} not in cfgs dict!".format(model_name)
69 |     cfg = cfgs[model_name]
70 | 
71 |     model = VGG(make_features(cfg), weights_path=weights_path)
72 |     return model
73 | 
--------------------------------------------------------------------------------
/PlaneSAM/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import utils.box_ops as box_ops
5 | 
6 | 
7 | # Input: [B, H, W], where each pixel value is an integer class id.
8 | # Output: [B, C, H, W] with C classes, one binary mask per class.
9 | def build_target(target, num_classes=2, ignore_index=-100):
10 |     dice_target = target.clone()
11 |     if ignore_index >= 0:
12 |         ignore_mask = torch.eq(target, ignore_index)
13 |         dice_target[ignore_mask] = 0
14 |         # [B, H, W] -> [B, H, W, C]
15 |         dice_target = nn.functional.one_hot(dice_target, num_classes).float()
16 |         dice_target[ignore_mask] = ignore_index
17 |     else:
18 |         dice_target = nn.functional.one_hot(dice_target, num_classes).float()
19 | 
20 |     return dice_target.permute(0, 3, 1, 2)
21 | 
22 | 
23 | # Input: [B, H, W]
24 | def dice_coeff(x, target, ignore_index=-100, epsilon=1e-6):
25 |     d = 0.
26 |     batch_size = x.shape[0]
27 |     for i in range(batch_size):
28 |         x_i = x[i].reshape(-1)
29 |         t_i = target[i].reshape(-1)
30 |         if ignore_index >= 0:
31 |             # keep only the positions whose value is not ignore_index
32 |             roi_mask = torch.ne(t_i, ignore_index)
33 |             x_i = x_i[roi_mask]
34 |             t_i = t_i[roi_mask]
35 |         inter = torch.dot(x_i, t_i)
36 |         sets_sum = torch.sum(x_i) + torch.sum(t_i)
37 |         # if sets_sum is 0, both the prediction and the target are all zeros, so the prediction is perfect and the dice coefficient is 1
38 |         if sets_sum == 0:
39 |             sets_sum = 2 * inter
40 | 
41 |         d += (2 * inter + epsilon) / (sets_sum + epsilon)
42 | 
43 |     return d / batch_size
44 | 
45 | 
46 | def multiclass_dice_coeff(x, target, ignore_index=-100, epsilon=1e-6):
47 |     dice = 0.
48 |     for channel in range(x.shape[1]):
49 |         dice += dice_coeff(x[:, channel, ...], target[:, channel, ...], ignore_index, epsilon)
50 |     return dice / x.shape[1]  # average the Dice coefficient over channels
51 | 
52 | 
53 | def dice_loss(x, target, multiclass=False, ignore_index=-100):
54 |     x = torch.sigmoid(x)
55 |     fn = multiclass_dice_coeff if multiclass else dice_coeff
56 |     return 1 - fn(x, target, ignore_index=ignore_index)
57 | 
58 | 
59 | def criterion(x_mask, x_iou, target, alpha=1, gamma=2):
60 |     """
61 |     :param x_mask: [B, H, W] mask logits predicted by SAM
62 |     :param x_iou: [B,] predicted IoU for each mask
63 |     :param target: [B, H, W] ground-truth masks
64 |     :param alpha: focal loss parameter
65 |     :param gamma: focal loss parameter
66 |     :return: the combined loss
67 |     """
68 |     batch_size, h, w = x_mask.shape
69 |     binary_mask = (x_mask > 0).float()
70 | 
71 |     ce_loss = F.binary_cross_entropy_with_logits(x_mask, target, reduction='mean')
72 |     pt = torch.exp(-ce_loss)
73 |     focal_loss = alpha * (1 - pt) ** gamma * ce_loss
74 | 
75 |     Dice_loss = dice_loss(x_mask, target)
76 | 
77 |     binary_mask = binary_mask.reshape(batch_size, -1)
78 |     target = target.reshape(batch_size, -1)
79 |     intersection = torch.sum(torch.logical_and(binary_mask, target).float(), dim=1)
80 |     union = torch.sum(torch.logical_or(binary_mask, target).float(), dim=1)
81 |     iou = intersection / (union + 1e-6)
82 |     binary_mask = binary_mask.reshape(batch_size, h, w)
83 |     target = target.reshape(batch_size, h, w)
84 | 
85 |     mse_Loss = F.mse_loss(x_iou, iou, reduction='mean')
86 | 
87 |     return 20 * focal_loss + Dice_loss
--------------------------------------------------------------------------------
/PlaneSAM/utils/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | def box_cxcywh_to_xywh(x): 24 | x_c, y_c, w, h = x.unbind(-1) 25 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 26 | w, h] 27 | return torch.stack(b, dim=-1) 28 | 29 | 30 | def box_xyxy_to_xywh(x): 31 | x0, y0, x1, y1 = x.unbind(-1) 32 | b = [x0, y0, 33 | (x1 - x0), (y1 - y0)] 34 | return torch.stack(b, dim=-1) 35 | 36 | 37 | # modified from torchvision to also return the union 38 | def box_iou(boxes1, boxes2): 39 | area1 = box_area(boxes1) 40 | area2 = box_area(boxes2) 41 | 42 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 43 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 44 | 45 | wh = (rb - lt).clamp(min=0) # [N,M,2] 46 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 47 | 48 | union = area1[:, None] + area2 - inter 49 | 50 | iou = inter / union 51 | return iou, union 52 | 53 | 54 | def generalized_box_iou(boxes1, boxes2): 55 | """ 56 | Generalized IoU from https://giou.stanford.edu/ 57 | 58 | The boxes should be in [x0, y0, x1, y1] format 59 | 60 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 61 | and M = len(boxes2) 62 | """ 63 | # degenerate boxes gives inf / nan results 64 | # so do an early check 65 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 66 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 67 | iou, union = box_iou(boxes1, boxes2) 68 | 69 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 70 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 71 | 72 | wh = (rb - lt).clamp(min=0) # [N,M,2] 73 | area = wh[:, :, 0] * wh[:, :, 1] 74 | 75 | return iou - (area - union) / area 76 | 77 | 78 | def masks_to_boxes(masks): 79 | """Compute the bounding boxes around the provided masks 80 | 81 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
82 | 83 | Returns a [N, 4] tensors, with the boxes in xyxy format 84 | """ 85 | if masks.numel() == 0: 86 | return torch.zeros((0, 4), device=masks.device) 87 | 88 | h, w = masks.shape[-2:] 89 | 90 | y = torch.arange(0, h, dtype=torch.float, device=masks.device) 91 | x = torch.arange(0, w, dtype=torch.float, device=masks.device) 92 | y, x = torch.meshgrid(y, x) 93 | 94 | x_mask = (masks * x.unsqueeze(0)) 95 | x_max = x_mask.flatten(1).max(-1)[0] 96 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 97 | 98 | y_mask = (masks * y.unsqueeze(0)) 99 | y_max = y_mask.flatten(1).max(-1)[0] 100 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 101 | 102 | return torch.stack([x_min, y_min, x_max, y_max], 1) 103 | -------------------------------------------------------------------------------- /PlaneSAM/utils/train_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from sklearn.metrics import adjusted_rand_score 5 | 6 | 7 | def print_log(global_step, epoch, local_count, count_inter, dataset_size, loss, time_inter): 8 | print('Step: {:>5} Train Epoch: {:>3} [{:>4}/{:>4} ({:3.1f}%)] ' 9 | 'Loss: {:.6f} [{:.2f}s every {:>4} data]'.format( 10 | global_step, epoch, local_count, dataset_size, 11 | 100. * local_count / dataset_size, loss.data, time_inter, count_inter)) 12 | 13 | 14 | def save_ckpt(ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train): 15 | # usually this happens only on the start of a epoch 16 | epoch_float = epoch + (local_count / num_train) 17 | state = { 18 | 'global_step': global_step, 19 | 'epoch': epoch_float, 20 | 'state_dict': model.state_dict(), 21 | 'optimizer': optimizer.state_dict(), 22 | } 23 | ckpt_model_filename = "ckpt_epoch_{:0.2f}.pth".format(epoch_float) 24 | path = os.path.join(ckpt_dir, ckpt_model_filename) 25 | torch.save(state, path) 26 | print('{:>2} has been successfully saved'.format(path)) 27 | 28 | 29 | def load_ckpt(model, optimizer, model_file, device, NUM_GPUS): 30 | if os.path.isfile(model_file): 31 | print("=> loading checkpoint '{}'".format(model_file)) 32 | if device.type == 'cuda': 33 | checkpoint = torch.load(model_file) 34 | else: 35 | checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage) 36 | model_dict = checkpoint['state_dict'] 37 | model_dict_ = {} 38 | if NUM_GPUS > 1: 39 | model.load_state_dict(model_dict) 40 | else: 41 | for k, v in model_dict.items(): 42 | k_ = k.replace('module.', '') 43 | model_dict_[k_] = v 44 | model.load_state_dict(model_dict_) 45 | if optimizer: 46 | optimizer.load_state_dict(checkpoint['optimizer']) 47 | else: 48 | print("=> no checkpoint found at '{}'".format(model_file)) 49 | exit(0) 50 | 51 | 52 | # 计算指标 53 | def compute_iou(x, target): 54 | """ 55 | :param x: [B, H, W] 56 | :param target: [B, H, W] 57 | :return: iou 58 | """ 59 | b, h, w = x.shape 60 | x = x.reshape(b, -1) 61 | target = target.reshape(b, -1) 62 | intersection = torch.sum(torch.logical_and(x, target).float(), dim=1) 63 | union = torch.sum(torch.logical_or(x, target).float(), dim=1) 64 | # [B] 65 | iou = intersection / (union + 1e-9) 66 | return iou 67 | 68 | 69 | def compute_iou_sc(x, target): 70 | """ 71 | :param x: [B, H, W] 72 | :param target: [B, H, W] 73 | :return: iou 74 | """ 75 | b, h, w = x.shape 76 | x = x.reshape(b, -1) 77 | target = target.reshape(b, -1) 78 | intersection = np.sum(np.logical_and(x, target).astype(np.float32), axis=1) 79 | union = 
np.sum(np.logical_or(x, target).astype(np.float32), axis=1) 80 | iou = np.max(intersection / (union + 1e-9)) 81 | return iou 82 | 83 | 84 | def compute_acc(x, target): 85 | """ 86 | :param x: [B, H, W] 87 | :param target: [B, H, W] 88 | :return: acc 89 | """ 90 | b, h, w = x.shape 91 | x = x.reshape(b, -1) 92 | target = target.reshape(b, -1) 93 | 94 | acc = torch.mean((x == target).float(), dim=1) 95 | return acc 96 | 97 | 98 | def compute_RI(x, target): 99 | """ 100 | :param x: [B, H, W] 101 | :param target: [B, H, W] 102 | :return: RI 103 | """ 104 | num_planes, h, w = x.shape 105 | pred_map = np.zeros((h, w), dtype=np.uint8) 106 | gt_map = np.zeros((h, w), dtype=np.uint8) 107 | 108 | # [H, W] 109 | for i in range(num_planes): 110 | pred_map[x[i] == 1] = i + 1 111 | gt_map[target[i] == 1] = i + 1 112 | 113 | pred_map = pred_map.flatten() 114 | gt_map = gt_map.flatten() 115 | RI = adjusted_rand_score(pred_map, gt_map) 116 | 117 | return RI 118 | 119 | 120 | def compute_SC(x, target): 121 | """ 122 | :param x: [B, H, W] 123 | :param target: [B, H, W] 124 | :return: RI 125 | """ 126 | iou_1 = [] 127 | iou_2 = [] 128 | num_planes, h, w = x.shape 129 | for per_plane in x: 130 | plane = np.repeat(per_plane[np.newaxis, ...], num_planes, axis=0) 131 | iou_1.append(compute_iou_sc(plane, target)) 132 | iou_1 = sum(iou_1) / len(iou_1) 133 | 134 | for per_plane in target: 135 | plane = np.repeat(per_plane[np.newaxis, ...], num_planes, axis=0) 136 | iou_2.append(compute_iou_sc(plane, x)) 137 | iou_2 = sum(iou_2) / len(iou_2) 138 | 139 | return (iou_1 + iou_2) / 2 -------------------------------------------------------------------------------- /PlaneSAM/utils/eval_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import adjusted_rand_score, mutual_info_score 4 | from torchvision.ops import box_iou 5 | 6 | 7 | def compute_iou(x, target): 8 | """ 9 | :param x: [B, H, W] 10 | :param target: [B, H, W] 11 | :return: iou 12 | """ 13 | b, h, w = x.shape 14 | x = x.reshape(b, -1) 15 | target = target.reshape(b, -1) 16 | intersection = torch.sum(torch.logical_and(x, target).float(), dim=1) 17 | union = torch.sum(torch.logical_or(x, target).float(), dim=1) 18 | # [B] 19 | iou = intersection / (union + 1e-9) 20 | return iou 21 | 22 | 23 | def compute_iou_sc(x, target): 24 | """ 25 | :param x: [B, H, W] 26 | :param target: [B, H, W] 27 | :return: iou 28 | """ 29 | b, h, w = x.shape 30 | x = x.reshape(b, -1) 31 | target = target.reshape(b, -1) 32 | intersection = np.sum(np.logical_and(x, target).astype(np.float32), axis=1) 33 | union = np.sum(np.logical_or(x, target).astype(np.float32), axis=1) 34 | iou = np.max(intersection / (union + 1e-9)) 35 | return iou 36 | 37 | 38 | def compute_acc(x, target): 39 | """ 40 | :param x: [B, H, W] 41 | :param target: [B, H, W] 42 | :return: acc 43 | """ 44 | b, h, w = x.shape 45 | x = x.reshape(b, -1) 46 | target = target.reshape(b, -1) 47 | 48 | acc = torch.mean((x == target).float(), dim=1) 49 | return acc 50 | 51 | 52 | def masks_to_map(masks): 53 | """ 54 | :param masks: [B, H, W] 55 | :return: [H, W] 56 | """ 57 | num_planes, h, w = masks.shape 58 | semantic_map = np.zeros((h, w), dtype=np.uint8) 59 | # [H, W] 60 | for i in range(num_planes): 61 | semantic_map[masks[i] == 1] = i + 1 62 | 63 | return semantic_map 64 | 65 | 66 | def match_boxes_gt(pred_boxes, gt_boxes): 67 | # pred_boxes has been sorted by score 68 | num_gts = gt_boxes.shape[0] 69 | num_preds = 
pred_boxes.shape[0] 70 | device = gt_boxes.device 71 | batched_points = torch.full((num_gts, 2, 2), fill_value=-1, dtype=torch.float32, device=device) 72 | batched_labels = torch.full((num_gts, 2), fill_value=-1, dtype=torch.float32, device=device) 73 | iou_matrix = box_iou(pred_boxes, gt_boxes) 74 | 75 | _, best_match_indices = iou_matrix.max(dim=1) 76 | unique_masks = torch.zeros(num_preds, dtype=torch.bool, device=device) 77 | tmp = [] 78 | for i, t in enumerate(best_match_indices): 79 | if t not in tmp: 80 | tmp.append(t) 81 | unique_masks[i] = True 82 | 83 | pred_boxes = pred_boxes[unique_masks] 84 | best_match_indices = best_match_indices[unique_masks] 85 | 86 | batched_points[best_match_indices, 0] = pred_boxes[:, :2] 87 | batched_points[best_match_indices, 1] = pred_boxes[:, 2:] 88 | batched_labels[best_match_indices, 0] = 2 89 | batched_labels[best_match_indices, 1] = 3 90 | 91 | return batched_points.unsqueeze(1), batched_labels.unsqueeze(1) 92 | 93 | 94 | def evaluateMasks(predMasks, gtMasks): 95 | """ 96 | :param predMasks: [N, H, W] 97 | :param gtMasks: [N, H, W] 98 | :return: 99 | """ 100 | valid_mask = (gtMasks.max(0)[0]).unsqueeze(0) 101 | 102 | gtMasks = torch.cat([gtMasks, torch.clamp(1 - gtMasks.sum(0, keepdim=True), min=0)], dim=0) # M+1, H, W 103 | predMasks = torch.cat([predMasks, torch.clamp(1 - predMasks.sum(0, keepdim=True), min=0)], dim=0) # N+1, H, W 104 | 105 | intersection = (gtMasks.unsqueeze(1) * predMasks * valid_mask).sum(-1).sum(-1).float() 106 | union = (torch.max(gtMasks.unsqueeze(1), predMasks) * valid_mask).sum(-1).sum(-1).float() 107 | 108 | N = intersection.sum() 109 | 110 | RI = 1 - ((intersection.sum(0).pow(2).sum() + intersection.sum(1).pow(2).sum()) / 2 - intersection.pow(2).sum()) / ( 111 | N * (N - 1) / 2) 112 | joint = intersection / N 113 | marginal_2 = joint.sum(0) 114 | marginal_1 = joint.sum(1) 115 | H_1 = (-marginal_1 * torch.log2(marginal_1 + (marginal_1 == 0).float())).sum() 116 | H_2 = (-marginal_2 * torch.log2(marginal_2 + (marginal_2 == 0).float())).sum() 117 | 118 | B = (marginal_1.unsqueeze(-1) * marginal_2) 119 | log2_quotient = torch.log2(torch.clamp(joint, 1e-8) / torch.clamp(B, 1e-8)) * (torch.min(joint, B) > 1e-8).float() 120 | MI = (joint * log2_quotient).sum() 121 | voi = H_1 + H_2 - 2 * MI 122 | 123 | IOU = intersection / torch.clamp(union, min=1) 124 | SC = ((IOU.max(-1)[0] * torch.clamp((gtMasks * valid_mask).sum(-1).sum(-1), min=1e-4)).sum() / N + ( 125 | IOU.max(0)[0] * torch.clamp((predMasks * valid_mask).sum(-1).sum(-1), min=1e-4)).sum() / N) / 2 126 | info = [RI.item(), voi.item(), SC.item()] 127 | 128 | return info -------------------------------------------------------------------------------- /PlaneSAM/backbone/mobilenetv2_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision.ops import misc 4 | 5 | 6 | def _make_divisible(ch, divisor=8, min_ch=None): 7 | """ 8 | This function is taken from the original tf repo. 9 | It ensures that all layers have a channel number that is divisible by 8 10 | It can be seen here: 11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 12 | """ 13 | if min_ch is None: 14 | min_ch = divisor 15 | new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor) 16 | # Make sure that round down does not go down by more than 10%. 
17 | if new_ch < 0.9 * ch: 18 | new_ch += divisor 19 | return new_ch 20 | 21 | 22 | class ConvBNReLU(nn.Sequential): 23 | def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1, norm_layer=None): 24 | padding = (kernel_size - 1) // 2 25 | if norm_layer is None: 26 | norm_layer = nn.BatchNorm2d 27 | super(ConvBNReLU, self).__init__( 28 | nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False), 29 | norm_layer(out_channel), 30 | nn.ReLU6(inplace=True) 31 | ) 32 | 33 | 34 | class InvertedResidual(nn.Module): 35 | def __init__(self, in_channel, out_channel, stride, expand_ratio, norm_layer=None): 36 | super(InvertedResidual, self).__init__() 37 | hidden_channel = in_channel * expand_ratio 38 | self.use_shortcut = stride == 1 and in_channel == out_channel 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | 42 | layers = [] 43 | if expand_ratio != 1: 44 | # 1x1 pointwise conv 45 | layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1, norm_layer=norm_layer)) 46 | layers.extend([ 47 | # 3x3 depthwise conv 48 | ConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel, norm_layer=norm_layer), 49 | # 1x1 pointwise conv(linear) 50 | nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False), 51 | norm_layer(out_channel), 52 | ]) 53 | 54 | self.conv = nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | if self.use_shortcut: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class MobileNetV2(nn.Module): 64 | def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8, weights_path=None, norm_layer=None): 65 | super(MobileNetV2, self).__init__() 66 | block = InvertedResidual 67 | input_channel = _make_divisible(32 * alpha, round_nearest) 68 | last_channel = _make_divisible(1280 * alpha, round_nearest) 69 | 70 | if norm_layer is None: 71 | norm_layer = nn.BatchNorm2d 72 | 73 | inverted_residual_setting = [ 74 | # t, c, n, s 75 | [1, 16, 1, 1], 76 | [6, 24, 2, 2], 77 | [6, 32, 3, 2], 78 | [6, 64, 4, 2], 79 | [6, 96, 3, 1], 80 | [6, 160, 3, 2], 81 | [6, 320, 1, 1], 82 | ] 83 | 84 | features = [] 85 | # conv1 layer 86 | features.append(ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)) 87 | # building inverted residual residual blockes 88 | for t, c, n, s in inverted_residual_setting: 89 | output_channel = _make_divisible(c * alpha, round_nearest) 90 | for i in range(n): 91 | stride = s if i == 0 else 1 92 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 93 | input_channel = output_channel 94 | # building last several layers 95 | features.append(ConvBNReLU(input_channel, last_channel, 1, norm_layer=norm_layer)) 96 | # combine feature layers 97 | self.features = nn.Sequential(*features) 98 | 99 | # building classifier 100 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 101 | self.classifier = nn.Sequential( 102 | nn.Dropout(0.2), 103 | nn.Linear(last_channel, num_classes) 104 | ) 105 | 106 | if weights_path is None: 107 | # weight initialization 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 111 | if m.bias is not None: 112 | nn.init.zeros_(m.bias) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | nn.init.ones_(m.weight) 115 | nn.init.zeros_(m.bias) 116 | elif isinstance(m, nn.Linear): 117 | nn.init.normal_(m.weight, 0, 0.01) 118 | nn.init.zeros_(m.bias) 119 | else: 120 | self.load_state_dict(torch.load(weights_path)) 
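    # In inverted_residual_setting above, each row is (t, c, n, s):
    # t = expansion factor, c = output channels (scaled by alpha), n = number
    # of repeated blocks, s = stride of the first block in the group.
    # A minimal backbone sketch (hypothetical usage, not part of this repo's
    # API; assumes a torchvision-style FasterRCNN that reads backbone.out_channels):
    #
    #   backbone = MobileNetV2(alpha=1.0).features
    #   backbone.out_channels = 1280   # last_channel for alpha=1.0
    #   detector = FasterRCNN(backbone, num_classes=2)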
121 | 122 | def forward(self, x): 123 | x = self.features(x) 124 | x = self.avgpool(x) 125 | x = torch.flatten(x, 1) 126 | x = self.classifier(x) 127 | return x 128 | -------------------------------------------------------------------------------- /PlaneSAM/S2D3DSDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import random 5 | import numpy as np 6 | from torchvision import transforms 7 | import skimage 8 | from pycocotools.coco import COCO 9 | from PIL import Image 10 | 11 | class S2d3dsDataset(Dataset): 12 | def __init__(self, datafolder, subset='train', transform=None): 13 | self.datafolder = datafolder 14 | self.subset = subset 15 | self.transform = transform 16 | assert subset in ['train', 'val'] 17 | if subset == 'train': 18 | self.coco = COCO(annotation_file=os.path.join(datafolder, 's2d3ds_train.json')) 19 | else: 20 | self.coco = COCO(annotation_file=os.path.join(datafolder, 's2d3ds_val.json')) 21 | self.image_ids = self.coco.getImgIds() 22 | 23 | def __len__(self): 24 | return len(self.image_ids) 25 | 26 | ''' 27 | return: 28 | image: [3, H, W] 29 | depth: [H, W] 30 | segmentation: [H, W] 31 | instance: [num_planes, H, W] 32 | ''' 33 | 34 | def __getitem__(self, index): 35 | image_id = self.image_ids[index] 36 | if self.subset == 'train': 37 | image_path = os.path.join(self.datafolder, 'images', self.coco.loadImgs(image_id)[0]['file_name']) 38 | else: 39 | image_path = os.path.join(self.datafolder, 'images_val', self.coco.loadImgs(image_id)[0]['file_name']) 40 | depth_path = image_path.replace('images', 'depths').replace('rgb', 'depth').replace('.jpg', '.png') 41 | image = np.array(Image.open(image_path), dtype=np.uint8) 42 | depth = np.array(Image.open(depth_path), dtype=np.float32) / 1000.0 43 | 44 | segmentation = np.zeros(image.shape[:2], dtype=np.uint8) 45 | annotation_id = self.coco.getAnnIds(imgIds=image_id) 46 | annotation = self.coco.loadAnns(annotation_id) 47 | for idx, i in enumerate(annotation): 48 | segmentation[self.coco.annToMask(i) > 0] = idx + 1 49 | idx += 1 50 | 51 | sample = {} 52 | if self.transform: 53 | sample = self.transform({ 54 | 'image': image, 55 | 'depth': depth, 56 | 'segmentation': segmentation 57 | }) 58 | image = sample['image'] 59 | depth = sample['depth'] 60 | segmentation = sample['segmentation'] 61 | 62 | mask = [] 63 | unique_idx = torch.unique(segmentation) 64 | unique_idx = [x for x in unique_idx if x] 65 | for i in unique_idx: 66 | mask.append((segmentation == i).float()) 67 | mask = torch.stack(mask) 68 | bbox = self.masks_to_bboxes(mask) 69 | num_planes = len(unique_idx) 70 | 71 | masks = torch.zeros(30, image.shape[1], image.shape[2], dtype=torch.float32) 72 | masks[:num_planes] = mask 73 | 74 | sample.update({ 75 | 'instance': masks, 76 | 'num_planes': num_planes, 77 | "data_path": image_path 78 | }) 79 | 80 | return sample 81 | 82 | def masks_to_bboxes(self, masks): 83 | """ 84 | 从掩码张量中计算边界框的左上和右下坐标 85 | 参数: 86 | masks: 形状为 [B, H, W] 的二进制掩码张量 87 | 返回值: 88 | bounding_boxes: 形状为 [B, 4] 的边界框坐标张量,包含左上和右下坐标 89 | """ 90 | batch_size, height, width = masks.size() 91 | device = masks 92 | bounding_boxes = torch.zeros((batch_size, 4), dtype=torch.float32) 93 | 94 | for b in range(batch_size): 95 | mask = masks[b] 96 | 97 | # 找到掩码的非零元素索引 98 | nonzero_indices = torch.nonzero(mask) 99 | 100 | if nonzero_indices.size(0) == 0: 101 | # 如果掩码中没有非零元素,则边界框坐标为零 102 | assert "no mask!" 
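                # Note: `assert` on a non-empty string literal is always truthy,
                # so this branch never raises; an empty mask silently keeps the
                # all-zero box initialised above.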
103 | else: 104 | # 计算边界框的左上和右下坐标 105 | ymin = torch.min(nonzero_indices[:, 0]) 106 | xmin = torch.min(nonzero_indices[:, 1]) 107 | ymax = torch.max(nonzero_indices[:, 0]) 108 | xmax = torch.max(nonzero_indices[:, 1]) 109 | bounding_boxes[b] = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32) 110 | 111 | return bounding_boxes 112 | 113 | 114 | class ToTensor(object): 115 | def __call__(self, sample): 116 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 117 | # [H, W, C] -> [C, H, W], 像素值归一化到0-1之间 118 | image = transforms.ToTensor()(image) 119 | # [1, H, W] 120 | depth = transforms.ToTensor()(depth) 121 | return { 122 | 'image': image, 123 | 'depth': depth, 124 | 'segmentation': torch.from_numpy(segmentation.astype(np.int16)).float() 125 | } 126 | 127 | class RandomFlip(object): 128 | def __call__(self, sample): 129 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 130 | if random.random() > 0.5: 131 | image = np.fliplr(image).copy() 132 | depth = np.fliplr(depth).copy() 133 | segmentation = np.fliplr(segmentation).copy() 134 | 135 | return { 136 | 'image': image, 137 | 'depth': depth, 138 | 'segmentation': segmentation 139 | } -------------------------------------------------------------------------------- /PlaneSAM/Nyuv2Dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import random 5 | import numpy as np 6 | from torchvision import transforms 7 | import skimage 8 | 9 | class Nyuv2Dataset(Dataset): 10 | def __init__(self, datafolder, subset='train', transform=None): 11 | self.datafolder = datafolder 12 | self.subset = subset 13 | self.transform = transform 14 | assert subset in ['train', 'val'] 15 | self.data_list = [os.path.join(datafolder, subset, x) for x in os.listdir(os.path.join(datafolder, subset))] 16 | 17 | def __len__(self): 18 | return len(self.data_list) 19 | 20 | ''' 21 | return: 22 | image: [3, H, W] 23 | depth: [H, W] 24 | segmentation: [H, W] 25 | instance: [num_planes, H, W] 26 | ''' 27 | def __getitem__(self, index): 28 | data_path = self.data_list[index] 29 | data = np.load(data_path, allow_pickle=True) 30 | data = np.load(self.data_list[index], allow_pickle=True) 31 | image = data[:, :, :3].astype(np.uint8) 32 | depth = data[:, :, 3] 33 | segmentation = data[:, :, 4].astype(np.uint8) 34 | 35 | sample = {} 36 | if self.transform: 37 | sample = self.transform({ 38 | 'image': image, 39 | 'depth': depth, 40 | 'segmentation': segmentation 41 | }) 42 | image = sample['image'] 43 | depth = sample['depth'] 44 | segmentation = sample['segmentation'] 45 | 46 | mask = [] 47 | unique_idx = torch.unique(segmentation) 48 | unique_idx = [x for x in unique_idx if x] 49 | for i in unique_idx: 50 | mask.append((segmentation == i).float()) 51 | mask = torch.stack(mask) 52 | bbox = self.masks_to_bboxes(mask) 53 | num_planes = len(unique_idx) 54 | 55 | masks = torch.zeros(30, image.shape[1], image.shape[2], dtype=torch.float32) 56 | masks[:num_planes] = mask 57 | 58 | sample.update({ 59 | 'instance': masks, 60 | 'num_planes': num_planes, 61 | 'data_path': data_path 62 | }) 63 | 64 | return sample 65 | 66 | def masks_to_bboxes(self, masks): 67 | """ 68 | 从掩码张量中计算边界框的左上和右下坐标 69 | 参数: 70 | masks: 形状为 [B, H, W] 的二进制掩码张量 71 | 返回值: 72 | bounding_boxes: 形状为 [B, 4] 的边界框坐标张量,包含左上和右下坐标 73 | """ 74 | batch_size, height, width = masks.size() 75 | device = masks 76 | bounding_boxes = 
torch.zeros((batch_size, 4), dtype=torch.float32) 77 | 78 | for b in range(batch_size): 79 | mask = masks[b] 80 | 81 | # 找到掩码的非零元素索引 82 | nonzero_indices = torch.nonzero(mask) 83 | 84 | if nonzero_indices.size(0) == 0: 85 | # 如果掩码中没有非零元素,则边界框坐标为零 86 | assert "no mask!" 87 | else: 88 | # 计算边界框的左上和右下坐标 89 | ymin = torch.min(nonzero_indices[:, 0]) 90 | xmin = torch.min(nonzero_indices[:, 1]) 91 | ymax = torch.max(nonzero_indices[:, 0]) 92 | xmax = torch.max(nonzero_indices[:, 1]) 93 | bounding_boxes[b] = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32) 94 | 95 | return bounding_boxes 96 | 97 | 98 | class ToTensor(object): 99 | def __call__(self, sample): 100 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 101 | # [H, W, C] -> [C, H, W], 像素值归一化到0-1之间 102 | image = transforms.ToTensor()(image) 103 | # [1, H, W] 104 | depth = transforms.ToTensor()(depth) 105 | return { 106 | 'image': image, 107 | 'depth': depth, 108 | 'segmentation': torch.from_numpy(segmentation.astype(np.int16)).float() 109 | } 110 | 111 | class RandomFlip(object): 112 | def __call__(self, sample): 113 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 114 | if random.random() > 0.5: 115 | image = np.fliplr(image).copy() 116 | depth = np.fliplr(depth).copy() 117 | segmentation = np.fliplr(segmentation).copy() 118 | 119 | return { 120 | 'image': image, 121 | 'depth': depth, 122 | 'segmentation': segmentation 123 | } 124 | 125 | # 随机缩放 126 | class RandomScale(object): 127 | def __init__(self, scale): 128 | self.scale_low = min(scale) 129 | self.scale_high = max(scale) 130 | 131 | def __call__(self, sample): 132 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 133 | 134 | target_scale = random.uniform(self.scale_low, self.scale_high) 135 | # (H, W, C) 136 | target_height = int(round(target_scale * image.shape[0])) 137 | target_width = int(round(target_scale * image.shape[1])) 138 | # Bi-linear 139 | image = skimage.transform.resize(image, (target_height, target_width), 140 | order=1, mode='reflect', preserve_range=True).astype(np.uint8) 141 | # Nearest-neighbor 142 | depth = skimage.transform.resize(depth, (target_height, target_width), 143 | order=0, mode='reflect', preserve_range=True).astype(np.uint8) 144 | segmentation = skimage.transform.resize(segmentation, (target_height, target_width), 145 | order=0, mode='reflect', preserve_range=True) 146 | 147 | return {'image': image, 'depth': depth, 'segmentation': segmentation} 148 | 149 | # 随机裁剪 150 | class RandomCrop(object): 151 | def __init__(self, th, tw): 152 | self.th = th 153 | self.tw = tw 154 | 155 | def __call__(self, sample): 156 | image, depth, segmentation = sample['image'], sample['depth'], sample['segmentation'] 157 | h = image.shape[0] 158 | w = image.shape[1] 159 | i = random.randint(0, h - self.th) 160 | j = random.randint(0, w - self.tw) 161 | 162 | return {'image': image[i:i + image_h, j:j + image_w, :], 163 | 'depth': depth[i:i + image_h, j:j + image_w], 164 | 'segmentation': segmentation[i:i + image_h, j:j + image_w]} -------------------------------------------------------------------------------- /PlaneSAM/utils/make_prompt.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | def make_point_pt(masks, num_samples=1): 5 | """ 6 | masks: [B, H, W] 7 | 8 | return: 9 | batched_points: [B, num_pt, per_pt_points, 2] 10 | batched_label: [B, num_pt, per_pt_points] 11 | 
""" 12 | device = masks.device 13 | batched_points = [] 14 | batched_labels = [] 15 | 16 | for mask in masks: 17 | # instance_ids, counts = torch.unique(mask, return_counts=True) 18 | index = torch.nonzero(mask, as_tuple=True) 19 | num_points = len(index[0]) 20 | if num_points >= num_samples: 21 | # diff = num_samples - num_points if num_samples > num_points else 0 22 | # num_samples = num_points if num_samples > num_points else num_samples 23 | select_idx = torch.tensor(random.sample(range(num_points), num_samples)) 24 | # [num_samples, 2] 25 | point_coords = torch.zeros((num_samples, 2)) 26 | # h -> y and w -> x 27 | point_coords[..., 1], point_coords[..., 0] = index[0][select_idx], index[1][select_idx] 28 | point_coords = point_coords[None, ...] 29 | # 前景是1,背景为0 30 | point_labels = torch.ones(1, num_samples) 31 | 32 | # 如果diff存在,就要确保batch的格式一致 33 | # if diff: 34 | # diff_point_coords = torch.full((1, diff, 2), fill_value=-1) 35 | # diff_point_labels = torch.full((1, diff), fill_value=-1) 36 | # point_coords = torch.cat((point_coords, diff_point_coords), dim=1) 37 | # point_labels = torch.cat((point_labels, diff_point_labels), dim=1) 38 | 39 | # 没有实例 40 | else: 41 | # 点提示全部用-1填充 42 | point_coords = torch.full((1, num_samples, 2), fill_value=-1) 43 | point_labels = torch.full((1, num_samples), fill_value=-1) 44 | 45 | batched_points.append(point_coords) 46 | batched_labels.append(point_labels) 47 | 48 | batched_points = torch.stack(batched_points, dim=0).to(device) 49 | batched_labels = torch.stack(batched_labels, dim=0).to(device) 50 | 51 | return batched_points, batched_labels 52 | 53 | 54 | def make_box_pt(batched_masks, noise_ratio=0.): 55 | """ 56 | batched_mask: [B, H, W] 57 | 58 | return: 59 | 左上点的标签是2,右下点的标签是3 60 | batched_points: [B, num_pt, per_pt_points, 2] 61 | batched_label: [B, num_pt, per_pt_points] 62 | """ 63 | device = batched_masks.device 64 | 65 | batched_points = [] 66 | batched_labels = [] 67 | 68 | for per_mask in batched_masks: 69 | indices = torch.nonzero(per_mask) 70 | h, w = per_mask.shape 71 | 72 | # 存在前景 73 | if indices.numel(): 74 | # 计算边界框的左下和右上点的坐标 75 | # bbox_min:[x_min, y_min] 76 | bbox_min = torch.min(indices, dim=0).values.flip(0) 77 | bbox_max = torch.max(indices, dim=0).values.flip(0) 78 | 79 | # 加10%边界框长度的噪声 80 | bbox_w, bbox_h = bbox_max - bbox_min 81 | noise_x1 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_w) 82 | noise_y1 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_h) 83 | noise_x2 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_w) 84 | noise_y2 = int(random.uniform(-noise_ratio, noise_ratio) * bbox_h) 85 | bbox_min[0] = bbox_min[0] + noise_x1 86 | bbox_min[1] = bbox_min[1] + noise_y1 87 | bbox_max[0] = bbox_max[0] + noise_x2 88 | bbox_max[1] = bbox_max[1] + noise_y2 89 | bbox_min[0] = torch.clip(bbox_min[0], 0, w) 90 | bbox_min[1] = torch.clip(bbox_min[1], 0, h) 91 | bbox_max[0] = torch.clip(bbox_max[0], 0, w) 92 | bbox_max[1] = torch.clip(bbox_max[1], 0, h) 93 | 94 | # [num_pt, 2, 2] 95 | per_box_pt = torch.cat((bbox_min[None, None, :], bbox_max[None, None, :]), dim=1) 96 | 97 | # 左上对应最小点,右下对应最大点 98 | bottomright_label = torch.full((1, 1), fill_value=3).to(device) 99 | topleft_label = torch.full((1, 1), fill_value=2).to(device) 100 | # [1, 2, 1] 101 | per_box_label = torch.cat((topleft_label, bottomright_label), dim=1) 102 | 103 | batched_points.append(per_box_pt) 104 | batched_labels.append(per_box_label) 105 | else: 106 | batched_points.append(torch.full((1, 2, 2), fill_value=-1).to(device)) 107 | 
batched_labels.append(torch.full((1, 2), fill_value=-1).to(device)) 108 | 109 | batched_points = torch.stack(batched_points, dim=0) 110 | batched_labels = torch.stack(batched_labels, dim=0) 111 | 112 | return batched_points, batched_labels 113 | 114 | 115 | def preprocess(masks, num_points=1, box=False, noise_ratio=0.): 116 | """ 117 | masks: [B, H, W] 118 | 119 | return: 120 | point_prompts: [B, num_pts, num_points, 2] 121 | prompt_labels: [B, num_pts, num_points] 122 | """ 123 | assert 0 <= num_points + 2 * box <= 6, "num_points shouldn't be greater than 6 or less than 0!" 124 | 125 | device = masks.device 126 | 127 | batched_points = None 128 | batched_pt_labels = None 129 | batched_boxes = None 130 | batched_box_labels = None 131 | 132 | if num_points: 133 | batched_points, batched_pt_labels = make_point_pt(masks, num_samples=num_points) 134 | if box: 135 | batched_boxes, batched_box_labels = make_box_pt(masks, noise_ratio) 136 | 137 | # 分配到指定设备上 138 | if batched_points is not None: 139 | batched_points = batched_points.to(device) 140 | batched_pt_labels = batched_pt_labels.to(device) 141 | if batched_boxes is not None: 142 | batched_boxes = batched_boxes.to(device) 143 | batched_box_labels = batched_box_labels.to(device) 144 | 145 | # 分情况返回值 146 | if batched_points is not None and batched_boxes is not None: 147 | # [B, num_pts, num_points, 2] 148 | # [B, num_pts, num_points] 149 | return torch.cat((batched_boxes, batched_points), dim=2), torch.cat( 150 | (batched_box_labels, batched_pt_labels), dim=2) 151 | elif batched_points is not None: 152 | return batched_points, batched_pt_labels 153 | elif batched_boxes is not None: 154 | return batched_boxes, batched_box_labels -------------------------------------------------------------------------------- /PlaneSAM/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim 9 | import torchvision.transforms as transforms 10 | from model import build_efficient_sam_vitt 11 | from FasterRCNN.build_FasterRCNN import create_model 12 | from PlaneDataset import PlaneDataset 13 | from Nyuv2Dataset import Nyuv2Dataset, ToTensor 14 | from S2D3DSDataset import S2d3dsDataset, ToTensor 15 | from torch.utils.data import DataLoader 16 | from utils.make_prompt import preprocess 17 | from utils.train_tools import load_ckpt 18 | from utils.eval_tools import evaluateMasks, match_boxes_gt 19 | from utils.box_ops import masks_to_boxes 20 | 21 | parser = argparse.ArgumentParser(description='Segment Any Planes') 22 | parser.add_argument('--data-dir', default='ScanNet', metavar='DIR', 23 | help='path to dataset') 24 | parser.add_argument('--num-workers', default=8, type=int, metavar='N', 25 | help='number of data loading workers (default: 8)') 26 | parser.add_argument('-o', '--output', default='result', metavar='DIR', 27 | help='path to output') 28 | parser.add_argument('--cuda', action='store_true', default=True, 29 | help='enables CUDA training') 30 | parser.add_argument('--last-ckpt', default='model/PlaneSAM.pth', type=str, metavar='PATH', 31 | help='path to latest checkpoint (default: none)') 32 | parser.add_argument('--detector-ckpt', default='weights/FasterRCNN.pth', type=str, metavar='PATH', 33 | help='path to detector checkpoint (default: none)') 34 | 35 | args = parser.parse_args() 36 | device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu") 37 | 38 | 39 | 
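# Prompt format used below (a minimal illustration; shapes follow utils/make_prompt.py):
# a box prompt is encoded as two labelled points, top-left with label 2 and
# bottom-right with label 3, padded with -1 where no box is available, e.g.
#
#   batch_points, batch_labels = preprocess(gt_masks, num_points=0, box=True)
#   # batch_points: [B, num_prompts, 2, 2]   (x, y) corners of each box
#   # batch_labels: [B, num_prompts, 2]      2 = top-left, 3 = bottom-right, -1 = padding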
def inference(): 40 | model = build_efficient_sam_vitt() 41 | detector = create_model(num_classes=2, load_pretrain_weights=False) 42 | detector.load_state_dict(torch.load(args.detector_ckpt)['model']) 43 | 44 | # if torch.cuda.device_count() > 1: 45 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 46 | # model = nn.DataParallel(model) 47 | # detector = nn.DataParallel(detector) 48 | 49 | model.to(device) 50 | detector.to(device) 51 | 52 | # eval scannet 53 | val_data = PlaneDataset(subset="val", 54 | transform=transforms.Compose([ 55 | transforms.ToTensor()]), 56 | root_dir=args.data_dir 57 | ) 58 | val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 59 | pin_memory=False) 60 | 61 | # eval mp3d and synthetic 62 | # val_data = Nyuv2Dataset(subset="val", 63 | # transform=transforms.Compose([ 64 | # ToTensor()]), 65 | # datafolder=args.data_dir 66 | # ) 67 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 68 | # pin_memory=False) 69 | 70 | # eval s2d3ds 71 | # val_data = S2d3dsDataset(subset="val", 72 | # transform=transforms.Compose([ 73 | # ToTensor()]), 74 | # datafolder=args.data_dir 75 | # ) 76 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 77 | # pin_memory=False) 78 | 79 | if args.last_ckpt: 80 | load_ckpt(model, None, args.last_ckpt, device, NUM_GPUS=1) 81 | else: 82 | print('no ckpt!') 83 | 84 | mRI = .0 85 | mSC = .0 86 | mVoI = .0 87 | num = 0 88 | total_time = 0 89 | model.eval() 90 | detector.eval() 91 | with torch.no_grad(): 92 | for batch_idx, sample in enumerate(val_loader): 93 | num_planes = sample['num_planes'][0] 94 | image = sample['image'].to(device) 95 | target = sample['instance'].to(device) 96 | target = target[:, :num_planes, :, :].permute(1, 0, 2, 3) # [num_planes, B, H, W] 97 | gt_boxes = masks_to_boxes(target) 98 | depth = sample['depth'].to(device) 99 | 100 | # use gt box 101 | # batch_points, batch_labels = preprocess(target.flatten(0, 1), num_points=0, box=True) 102 | 103 | # 预测边界框 to prompt 104 | start_time = time.time() 105 | outputs = detector(image)[0] 106 | boxes, scores = outputs['boxes'], outputs['scores'] 107 | batch_points, batch_labels = match_boxes_gt(boxes, gt_boxes) 108 | 109 | # union prompt 110 | pred_masks, pred_ious = model(image, depth, batch_points, batch_labels) 111 | end_time = time.time() 112 | total_time += end_time - start_time 113 | pred_masks = pred_masks.permute(1, 0, 2, 3, 4) 114 | pred_ious = pred_ious.permute(1, 0, 2) 115 | 116 | # # use preprocess 117 | # batch_points, batch_labels = preprocess(target.flatten(0, 1), num_points=0, box=True) 118 | 119 | # 二值掩码 120 | pred_masks = (pred_masks > 0.).float() 121 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 122 | # [3, B, H, W] 123 | # [3, B] 124 | pred_masks = pred_masks.view(b, -1, h, w).permute(1, 0, 2, 3) 125 | pred_ious = pred_ious.view(b, -1).permute(1, 0) 126 | best_id = torch.argmax(pred_ious, dim=0) 127 | # [num_preds, H, W] 128 | best_mask = pred_masks[best_id, torch.arange(b)] 129 | target = target.squeeze(1) 130 | 131 | # 计算指标 132 | RI, VoI, SC = evaluateMasks(best_mask, target) 133 | mRI += RI 134 | mSC += SC 135 | mVoI += VoI 136 | num += 1 137 | 138 | print(f"iter: {batch_idx} mRI: {RI:.4f} mSC: {SC:.4f} VoI: {VoI:.4f}") 139 | 140 | mRI = mRI / num 141 | mSC = mSC / num 142 | mVoI = mVoI / num 143 | print(f"mRI: {mRI:.3f} mSC: {mSC:.3f} VoI: {mVoI:.3f}") 144 | print("img/s:", len(val_loader) / total_time) 145 | 146 
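    # The metrics above are computed by utils/eval_tools.evaluateMasks:
    #   RI  - (pairwise) Rand index agreement between predicted and GT planes,
    #   VoI - Variation of Information, H(gt) + H(pred) - 2 * MI (lower is better),
    #   SC  - Segmentation Covering, the area-weighted symmetric best-IoU overlap.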
| 147 | if __name__ == '__main__': 148 | if not os.path.exists(args.output): 149 | os.mkdir(args.output) 150 | inference() 151 | -------------------------------------------------------------------------------- /PlaneSAM/FasterRCNN/boxes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | from torch import Tensor 4 | import torchvision 5 | 6 | 7 | def nms(boxes, scores, iou_threshold): 8 | # type: (Tensor, Tensor, float) -> Tensor 9 | """ 10 | Performs non-maximum suppression (NMS) on the boxes according 11 | to their intersection-over-union (IoU). 12 | 13 | NMS iteratively removes lower scoring boxes which have an 14 | IoU greater than iou_threshold with another (higher scoring) 15 | box. 16 | 17 | Parameters 18 | ---------- 19 | boxes : Tensor[N, 4]) 20 | boxes to perform NMS on. They 21 | are expected to be in (x1, y1, x2, y2) format 22 | scores : Tensor[N] 23 | scores for each one of the boxes 24 | iou_threshold : float 25 | discards all overlapping 26 | boxes with IoU > iou_threshold 27 | 28 | Returns 29 | ------- 30 | keep : Tensor 31 | int64 tensor with the indices 32 | of the elements that have been kept 33 | by NMS, sorted in decreasing order of scores 34 | """ 35 | return torch.ops.torchvision.nms(boxes, scores, iou_threshold) 36 | 37 | 38 | def batched_nms(boxes, scores, idxs, iou_threshold): 39 | # type: (Tensor, Tensor, Tensor, float) -> Tensor 40 | """ 41 | Performs non-maximum suppression in a batched fashion. 42 | 43 | Each index value correspond to a category, and NMS 44 | will not be applied between elements of different categories. 45 | 46 | Parameters 47 | ---------- 48 | boxes : Tensor[N, 4] 49 | boxes where NMS will be performed. They 50 | are expected to be in (x1, y1, x2, y2) format 51 | scores : Tensor[N] 52 | scores for each one of the boxes 53 | idxs : Tensor[N] 54 | indices of the categories for each one of the boxes. 55 | iou_threshold : float 56 | discards all overlapping boxes 57 | with IoU < iou_threshold 58 | 59 | Returns 60 | ------- 61 | keep : Tensor 62 | int64 tensor with the indices of 63 | the elements that have been kept by NMS, sorted 64 | in decreasing order of scores 65 | """ 66 | if boxes.numel() == 0: 67 | return torch.empty((0,), dtype=torch.int64, device=boxes.device) 68 | 69 | # strategy: in order to perform NMS independently per class. 70 | # we add an offset to all the boxes. The offset is dependent 71 | # only on the class idx, and is large enough so that boxes 72 | # from different classes do not overlap 73 | # 获取所有boxes中最大的坐标值(xmin, ymin, xmax, ymax) 74 | max_coordinate = boxes.max() 75 | 76 | # to(): Performs Tensor dtype and/or device conversion 77 | # 为每一个类别/每一层生成一个很大的偏移量 78 | # 这里的to只是让生成tensor的dytpe和device与boxes保持一致 79 | offsets = idxs.to(boxes) * (max_coordinate + 1) 80 | # boxes加上对应层的偏移量后,保证不同类别/层之间boxes不会有重合的现象 81 | boxes_for_nms = boxes + offsets[:, None] 82 | keep = nms(boxes_for_nms, scores, iou_threshold) 83 | return keep 84 | 85 | 86 | def remove_small_boxes(boxes, min_size): 87 | # type: (Tensor, float) -> Tensor 88 | """ 89 | Remove boxes which contains at least one side smaller than min_size. 
90 | 移除宽高小于指定阈值的索引 91 | Arguments: 92 | boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format 93 | min_size (float): minimum size 94 | 95 | Returns: 96 | keep (Tensor[K]): indices of the boxes that have both sides 97 | larger than min_size 98 | """ 99 | ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] # 预测boxes的宽和高 100 | # keep = (ws >= min_size) & (hs >= min_size) # 当满足宽,高都大于给定阈值时为True 101 | keep = torch.logical_and(torch.ge(ws, min_size), torch.ge(hs, min_size)) 102 | # nonzero(): Returns a tensor containing the indices of all non-zero elements of input 103 | # keep = keep.nonzero().squeeze(1) 104 | keep = torch.where(keep)[0] 105 | return keep 106 | 107 | 108 | def clip_boxes_to_image(boxes, size): 109 | # type: (Tensor, Tuple[int, int]) -> Tensor 110 | """ 111 | Clip boxes so that they lie inside an image of size `size`. 112 | 裁剪预测的boxes信息,将越界的坐标调整到图片边界上 113 | 114 | Arguments: 115 | boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format 116 | size (Tuple[height, width]): size of the image 117 | 118 | Returns: 119 | clipped_boxes (Tensor[N, 4]) 120 | """ 121 | dim = boxes.dim() 122 | boxes_x = boxes[..., 0::2] # x1, x2 123 | boxes_y = boxes[..., 1::2] # y1, y2 124 | height, width = size 125 | 126 | if torchvision._is_tracing(): 127 | boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device)) 128 | boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device)) 129 | boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device)) 130 | boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device)) 131 | else: 132 | boxes_x = boxes_x.clamp(min=0, max=width) # 限制x坐标范围在[0,width]之间 133 | boxes_y = boxes_y.clamp(min=0, max=height) # 限制y坐标范围在[0,height]之间 134 | 135 | clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim) 136 | return clipped_boxes.reshape(boxes.shape) 137 | 138 | 139 | def box_area(boxes): 140 | """ 141 | Computes the area of a set of bounding boxes, which are specified by its 142 | (x1, y1, x2, y2) coordinates. 143 | 144 | Arguments: 145 | boxes (Tensor[N, 4]): boxes for which the area will be computed. They 146 | are expected to be in (x1, y1, x2, y2) format 147 | 148 | Returns: 149 | area (Tensor[N]): area for each box 150 | """ 151 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 152 | 153 | 154 | def box_iou(boxes1, boxes2): 155 | """ 156 | Return intersection-over-union (Jaccard index) of boxes. 157 | 158 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
159 | 160 | Arguments: 161 | boxes1 (Tensor[N, 4]) 162 | boxes2 (Tensor[M, 4]) 163 | 164 | Returns: 165 | iou (Tensor[N, M]): the NxM matrix containing the pairwise 166 | IoU values for every element in boxes1 and boxes2 167 | """ 168 | area1 = box_area(boxes1) 169 | area2 = box_area(boxes2) 170 | 171 | # When the shapes do not match, 172 | # the shape of the returned output tensor follows the broadcasting rules 173 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # left-top [N,M,2] 174 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # right-bottom [N,M,2] 175 | 176 | wh = (rb - lt).clamp(min=0) # [N,M,2] 177 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 178 | 179 | iou = inter / (area1[:, None] + area2 - inter) 180 | return iou 181 | 182 | -------------------------------------------------------------------------------- /PlaneSAM/PlaneDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | from PIL import Image 7 | from random import randint 8 | 9 | 10 | 11 | class PlaneDataset(Dataset): 12 | def __init__(self, subset='train', transform=None, root_dir=None): 13 | assert subset in ['train', 'val'] 14 | self.subset = subset 15 | self.transform = transform 16 | self.root_dir = os.path.join(root_dir, subset) 17 | self.txt_file = os.path.join(root_dir, subset + '.txt') 18 | 19 | self.data_list = [line.strip() for line in open(self.txt_file, 'r').readlines()] 20 | self.precompute_K_inv_dot_xy_1() 21 | 22 | def get_plane_parameters(self, plane, plane_nums, segmentation): 23 | valid_region = segmentation != 20 24 | 25 | plane = plane[:plane_nums] 26 | 27 | tmp = plane[:, 1].copy() 28 | plane[:, 1] = -plane[:, 2] 29 | plane[:, 2] = tmp 30 | 31 | # convert plane from n * d to n / d 32 | plane_d = np.linalg.norm(plane, axis=1) 33 | # normalize 34 | plane /= plane_d.reshape(-1, 1) 35 | # n / d 36 | plane /= plane_d.reshape(-1, 1) 37 | 38 | h, w = segmentation.shape 39 | plane_parameters = np.ones((3, h, w)) 40 | for i in range(h): 41 | for j in range(w): 42 | d = segmentation[i, j] 43 | if d >= 20: continue 44 | plane_parameters[:, i, j] = plane[d, :] 45 | 46 | # plane_instance parameter, padding zero to fix size 47 | plane_instance_parameter = np.concatenate((plane, np.zeros((20 - plane.shape[0], 3))), axis=0) 48 | return plane_parameters, valid_region, plane_instance_parameter 49 | 50 | def precompute_K_inv_dot_xy_1(self, h=192, w=256): 51 | focal_length = 517.97 52 | offset_x = 320 53 | offset_y = 240 54 | 55 | K = [[focal_length, 0, offset_x], 56 | [0, focal_length, offset_y], 57 | [0, 0, 1]] 58 | 59 | K_inv = np.linalg.inv(np.array(K)) 60 | self.K_inv = K_inv 61 | 62 | K_inv_dot_xy_1 = np.zeros((3, h, w)) 63 | for y in range(h): 64 | for x in range(w): 65 | yy = float(y) / h * 480 66 | xx = float(x) / w * 640 67 | 68 | ray = np.dot(self.K_inv, 69 | np.array([xx, yy, 1]).reshape(3, 1)) 70 | K_inv_dot_xy_1[:, y, x] = ray[:, 0] 71 | 72 | # precompute to speed up processing 73 | self.K_inv_dot_xy_1 = K_inv_dot_xy_1 74 | 75 | def plane2depth(self, plane_parameters, num_planes, segmentation, gt_depth, h=192, w=256): 76 | 77 | depth_map = 1. 
/ np.sum(self.K_inv_dot_xy_1.reshape(3, -1) * plane_parameters.reshape(3, -1), axis=0) 78 | depth_map = depth_map.reshape(h, w) 79 | 80 | # replace non planer region depth using sensor depth map 81 | # 做了一个深度修复,并把非平面区域的深度设为0 82 | depth_map[segmentation == 20] = gt_depth[segmentation == 20] 83 | return depth_map 84 | 85 | def __getitem__(self, index): 86 | if self.subset == 'train': 87 | data_path = self.data_list[index] 88 | else: 89 | data_path = str(index) + '.npz' 90 | data_path = os.path.join(self.root_dir, data_path) 91 | data = np.load(data_path) 92 | 93 | image = data['image'] 94 | image_path = data['image_path'] 95 | info = data['info'] 96 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 97 | image = Image.fromarray(image) 98 | 99 | if self.transform is not None: 100 | image = self.transform(image) 101 | 102 | plane = data['plane'] 103 | num_planes = data['num_planes'][0] 104 | 105 | gt_segmentation = data['segmentation'] 106 | gt_segmentation = gt_segmentation.reshape((192, 256)) 107 | segmentation = np.zeros([21, 192, 256], dtype=np.uint8) 108 | 109 | _, h, w = segmentation.shape 110 | for i in range(num_planes + 1): 111 | # deal with backgroud 112 | if i == num_planes: 113 | seg = gt_segmentation == 20 114 | else: 115 | seg = gt_segmentation == i 116 | 117 | segmentation[i, :, :] = seg.reshape(h, w) 118 | 119 | # surface plane parameters 120 | plane_parameters, valid_region, plane_instance_parameter = \ 121 | self.get_plane_parameters(plane, num_planes, gt_segmentation) 122 | 123 | # since some depth is missing, we use plane to recover those depth following PlaneNet 124 | gt_depth = data['depth'].reshape(192, 256) 125 | depth = self.plane2depth(plane_parameters, num_planes, gt_segmentation, gt_depth).reshape(192, 256) 126 | 127 | # Depth图像需要归一化 128 | if self.transform is not None: 129 | depth = self.transform(depth) 130 | 131 | sample = { 132 | 'image': image, 133 | 'num_planes': num_planes, 134 | 'instance': torch.FloatTensor(segmentation), 135 | # one for planar and zero for non-planar 136 | 'semantic': 1 - torch.FloatTensor(segmentation[num_planes, :, :]).unsqueeze(0), 137 | 'gt_seg': torch.LongTensor(gt_segmentation), 138 | 'depth': depth.to(torch.float32), 139 | 'plane_parameters': torch.FloatTensor(plane_parameters), 140 | 'valid_region': torch.ByteTensor(valid_region.astype(np.uint8)).unsqueeze(0), 141 | 'plane_instance_parameter': torch.FloatTensor(plane_instance_parameter), 142 | 'data_path': data_path 143 | } 144 | 145 | return sample 146 | 147 | def __len__(self): 148 | return len(self.data_list) 149 | 150 | def masks_to_bboxes(self, masks): 151 | """ 152 | 从掩码张量中计算边界框的左上和右下坐标 153 | 参数: 154 | masks: 形状为 [N, H, W] 的二进制掩码张量 155 | 返回值: 156 | bounding_boxes: 形状为 [B, 4] 的边界框坐标张量,包含左上和右下坐标 157 | """ 158 | batch_size, h, w = masks.size() 159 | bounding_boxes = torch.zeros((batch_size, 4), dtype=torch.float32) 160 | 161 | for b in range(batch_size): 162 | mask = masks[b] 163 | 164 | # 找到掩码的非零元素索引 165 | nonzero_indices = torch.nonzero(mask) 166 | 167 | if nonzero_indices.size(0) == 0: 168 | # 如果掩码中没有非零元素,则边界框坐标为零 169 | assert "no mask!" 
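            # Unlike the pixel-space (x1, y1, x2, y2) boxes returned by the
            # dataset classes earlier in this repo, the branch below normalises
            # by (h, w) and returns boxes in (cx, cy, w, h) format.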
170 | else: 171 | # 计算边界框的左上和右下坐标 172 | ymin = torch.min(nonzero_indices[:, 0]) / h 173 | xmin = torch.min(nonzero_indices[:, 1]) / w 174 | ymax = torch.max(nonzero_indices[:, 0]) / h 175 | xmax = torch.max(nonzero_indices[:, 1]) / w 176 | bounding_boxes[b] = torch.tensor([(xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin], dtype=torch.float32) 177 | 178 | return bounding_boxes -------------------------------------------------------------------------------- /PlaneSAM/backbone/resnet101_fpn_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | from .feature_pyramid_network import BackboneWithFPN, LastLevelMaxPool 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None): 14 | super().__init__() 15 | if norm_layer is None: 16 | norm_layer = nn.BatchNorm2d 17 | 18 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, 19 | kernel_size=1, stride=1, bias=False) # squeeze channels 20 | self.bn1 = norm_layer(out_channel) 21 | # ----------------------------------------- 22 | self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, 23 | kernel_size=3, stride=stride, bias=False, padding=1) 24 | self.bn2 = norm_layer(out_channel) 25 | # ----------------------------------------- 26 | self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion, 27 | kernel_size=1, stride=1, bias=False) # unsqueeze channels 28 | self.bn3 = norm_layer(out_channel * self.expansion) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | 32 | def forward(self, x): 33 | identity = x 34 | if self.downsample is not None: 35 | identity = self.downsample(x) 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class ResNet(nn.Module): 55 | 56 | def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None): 57 | super().__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | self._norm_layer = norm_layer 61 | 62 | self.include_top = include_top 63 | self.in_channel = 64 64 | 65 | self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, 66 | padding=3, bias=False) 67 | self.bn1 = norm_layer(self.in_channel) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 70 | self.layer1 = self._make_layer(block, 64, blocks_num[0]) 71 | self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2) 72 | self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2) 73 | self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2) 74 | if self.include_top: 75 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1) 76 | self.fc = nn.Linear(512 * block.expansion, num_classes) 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 81 | 82 | def _make_layer(self, block, channel, block_num, stride=1): 83 | norm_layer = self._norm_layer 84 | downsample = None 85 | if stride != 1 or self.in_channel != channel * 
block.expansion: 86 | downsample = nn.Sequential( 87 | nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False), 88 | norm_layer(channel * block.expansion)) 89 | 90 | layers = [] 91 | layers.append(block(self.in_channel, channel, downsample=downsample, 92 | stride=stride, norm_layer=norm_layer)) 93 | self.in_channel = channel * block.expansion 94 | 95 | for _ in range(1, block_num): 96 | layers.append(block(self.in_channel, channel, norm_layer=norm_layer)) 97 | 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | x = self.conv1(x) 102 | x = self.bn1(x) 103 | x = self.relu(x) 104 | x = self.maxpool(x) 105 | 106 | x = self.layer1(x) 107 | x = self.layer2(x) 108 | x = self.layer3(x) 109 | x = self.layer4(x) 110 | 111 | if self.include_top: 112 | x = self.avgpool(x) 113 | x = torch.flatten(x, 1) 114 | x = self.fc(x) 115 | 116 | return x 117 | 118 | 119 | def overwrite_eps(model, eps): 120 | """ 121 | This method overwrites the default eps values of all the 122 | FrozenBatchNorm2d layers of the model with the provided value. 123 | This is necessary to address the BC-breaking change introduced 124 | by the bug-fix at pytorch/vision#2933. The overwrite is applied 125 | only when the pretrained weights are loaded to maintain compatibility 126 | with previous versions. 127 | 128 | Args: 129 | model (nn.Module): The model on which we perform the overwrite. 130 | eps (float): The new value of eps. 131 | """ 132 | for module in model.modules(): 133 | if isinstance(module, FrozenBatchNorm2d): 134 | module.eps = eps 135 | 136 | 137 | def resnet101_fpn_backbone(pretrain_path="", 138 | norm_layer=FrozenBatchNorm2d, # FrozenBatchNorm2d的功能与BatchNorm2d类似,但参数无法更新 139 | trainable_layers=3, 140 | returned_layers=None, 141 | extra_blocks=None): 142 | """ 143 | 搭建resnet50_fpn——backbone 144 | Args: 145 | pretrain_path: resnet50的预训练权重,如果不使用就默认为空 146 | norm_layer: 官方默认的是FrozenBatchNorm2d,即不会更新参数的bn层(因为如果batch_size设置的很小会导致效果更差,还不如不用bn层) 147 | 如果自己的GPU显存很大可以设置很大的batch_size,那么自己可以传入正常的BatchNorm2d层 148 | (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) 149 | trainable_layers: 指定训练哪些层结构 150 | returned_layers: 指定哪些层的输出需要返回 151 | extra_blocks: 在输出的特征层基础上额外添加的层结构 152 | 153 | Returns: 154 | 155 | """ 156 | resnet_backbone = ResNet(Bottleneck, [3, 4, 23, 3], 157 | include_top=False, 158 | norm_layer=norm_layer) 159 | 160 | if isinstance(norm_layer, FrozenBatchNorm2d): 161 | overwrite_eps(resnet_backbone, 0.0) 162 | 163 | if pretrain_path != "": 164 | assert os.path.exists(pretrain_path), "{} is not exist.".format(pretrain_path) 165 | # 载入预训练权重 166 | print(resnet_backbone.load_state_dict(torch.load(pretrain_path), strict=False)) 167 | 168 | # select layers that wont be frozen 169 | assert 0 <= trainable_layers <= 5 170 | layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] 171 | 172 | # 如果要训练所有层结构的话,不要忘了conv1后还有一个bn1 173 | if trainable_layers == 5: 174 | layers_to_train.append("bn1") 175 | 176 | # freeze layers 177 | for name, parameter in resnet_backbone.named_parameters(): 178 | # 只训练不在layers_to_train列表中的层结构 179 | if all([not name.startswith(layer) for layer in layers_to_train]): 180 | parameter.requires_grad_(False) 181 | 182 | if extra_blocks is None: 183 | extra_blocks = LastLevelMaxPool() 184 | 185 | if returned_layers is None: 186 | returned_layers = [1, 2, 3, 4] 187 | # 返回的特征层个数肯定大于0小于5 188 | assert min(returned_layers) > 0 and max(returned_layers) < 5 189 | 190 | # return_layers = {'layer1': '0', 'layer2': 
'1', 'layer3': '2', 'layer4': '3'} 191 | return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} 192 | 193 | # in_channel 为layer4的输出特征矩阵channel = 2048 194 | in_channels_stage2 = resnet_backbone.in_channel // 8 # 256 195 | # 记录resnet50提供给fpn的每个特征层channel 196 | in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] 197 | # 通过fpn后得到的每个特征层的channel 198 | out_channels = 256 199 | return BackboneWithFPN(resnet_backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) 200 | -------------------------------------------------------------------------------- /PlaneSAM/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from PlaneDataset import PlaneDataset 9 | from model import build_efficient_sam_vitt 10 | from utils.train_tools import save_ckpt, load_ckpt 11 | from utils.eval_tools import compute_iou 12 | from utils.loss import criterion 13 | from utils.make_prompt import preprocess 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | from random import randint 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Segment Any Planes') 19 | parser.add_argument('--data-dir', default='ScanNet', metavar='DIR', 20 | help='path to dataset-D') 21 | parser.add_argument('--cuda', action='store_true', default=True, 22 | help='enables CUDA training') 23 | parser.add_argument('--num-workers', default=8, type=int, metavar='N', 24 | help='number of data loading workers (default: 8)') 25 | parser.add_argument('--epochs', default=15, type=int, metavar='N', 26 | help='number of total epochs to run (default: 150)') 27 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 28 | help='manual epoch number (useful on restarts)') 29 | parser.add_argument('-b', '--batch-size', default=8, type=int, 30 | metavar='N', help='mini-batch size (default: 10)') 31 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 32 | metavar='LR', help='initial learning rate') 33 | parser.add_argument('--weight-decay', '--wd', default=0.01, type=float, 34 | metavar='W', help='weight decay') 35 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 36 | help='momentum') 37 | parser.add_argument('--save-epoch-freq', '-s', default=1, type=int, 38 | metavar='N', help='save epoch frequency (default: 5)') 39 | parser.add_argument('--last-ckpt', default='weights/pre.pth', type=str, metavar='PATH') 40 | parser.add_argument('--lr-decay-rate', default=0.8, type=float, 41 | help='decay rate of learning rate (default: 0.8)') 42 | parser.add_argument('--lr-epoch-per-decay', default=10, type=int, 43 | help='epoch of per decay of learning rate (default: 10)') 44 | parser.add_argument('--ckpt-dir', default='weights', metavar='DIR', 45 | help='path to save checkpoints') 46 | 47 | args = parser.parse_args() 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | 50 | image_w = 256 51 | image_h = 192 52 | 53 | def train(): 54 | train_data = PlaneDataset(subset="train", 55 | transform=transforms.Compose([ 56 | transforms.ToTensor()]), 57 | root_dir=args.data_dir 58 | ) 59 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, 60 | pin_memory=False) 61 | val_data = PlaneDataset(subset="val", 62 | 
transform=transforms.Compose([ 63 | transforms.ToTensor()]), 64 | root_dir=args.data_dir 65 | ) 66 | val_loader = DataLoader(val_data, batch_size=1, shuffle=True, num_workers=args.num_workers, 67 | pin_memory=False) 68 | 69 | num_train = len(train_data) 70 | 71 | model = build_efficient_sam_vitt() 72 | 73 | for p in model.prompt_encoder.parameters(): 74 | p.requires_grad = False 75 | 76 | if torch.cuda.device_count() >= 1: 77 | print("Let's use", torch.cuda.device_count(), "GPUs!") 78 | model = nn.DataParallel(model) 79 | 80 | model.to(device) 81 | 82 | optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr, 83 | weight_decay=args.weight_decay) 84 | 85 | global_step = 0 86 | 87 | if args.last_ckpt: 88 | global_step, _ = load_ckpt(model, None, args.last_ckpt, device, NUM_GPUS=2) 89 | 90 | # 余弦学习率调度 91 | scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0) 92 | 93 | for _ in range(args.start_epoch): 94 | optimizer.step() 95 | scheduler.step() 96 | 97 | for epoch in range(int(args.start_epoch), args.epochs): 98 | 99 | local_count = 0 100 | if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch: 101 | save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch, 102 | local_count, num_train) 103 | 104 | # 训练 105 | model.train() 106 | num_batch = 0 107 | mean_loss = 0 108 | for sample in tqdm(train_loader, desc=f"Train Epoch [{epoch + 1}/{args.epochs}]"): 109 | image = sample['image'].to(device) # [B, C, H, W] 110 | num_planes = sample['num_planes'] 111 | sel_planes = [randint(0, x - 1) for x in num_planes] 112 | target = sample['instance'].to(device) # [B, 21, H, W],第0-num_planes-1为平面掩码,第num_planes为背景(非平面) 113 | target = target[torch.arange(len(num_planes)), sel_planes, ...] 114 | depth = sample['depth'].to(device) # [B, 1, H, W] 115 | 116 | optimizer.zero_grad() 117 | 118 | input_points, input_labels = preprocess(target, num_points=0, box=True) 119 | 120 | # [B, num_prompt, per_prompt_mask, H, W] 121 | pred_masks, pred_ious = model(image, depth, input_points, input_labels) 122 | # 训练只给一个提示 123 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 124 | # 125 | pred_masks = pred_masks.view(b, -1, h, w) 126 | pred_ious = pred_ious.view(b, -1) 127 | # 遍历3个掩码 128 | pred_masks = pred_masks.permute(1, 0, 2, 3) 129 | pred_ious = pred_ious.permute(1, 0) 130 | 131 | loss = [] 132 | for mask, iou in zip(pred_masks, pred_ious): 133 | loss.append(criterion(mask, iou, target)) 134 | loss = min(loss) 135 | 136 | num_batch += 1 137 | mean_loss += loss 138 | 139 | loss.backward() 140 | optimizer.step() 141 | 142 | local_count += image.data.shape[0] 143 | global_step += 1 144 | 145 | scheduler.step() 146 | 147 | mean_loss /= num_batch 148 | print('Epoch: {} mean_loss: {}'.format(epoch + 1, mean_loss)) 149 | 150 | # 评估(use box to test) 151 | model.eval() 152 | totoal_iou = .0 153 | totoal_num = 0 154 | with torch.no_grad(): 155 | for sample in tqdm(val_loader, desc=f"Eval Epoch[{epoch + 1}/{args.epochs}]"): 156 | image = sample['image'].to(device) 157 | num_planes = sample['num_planes'][0] 158 | target = sample['instance'].to(device) 159 | target = target[:, randint(0, num_planes - 1), ...] 
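                # Validation mirrors training: a single randomly chosen plane
                # instance per image is scored, prompted by its (noisy) GT box.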
160 | depth = sample['depth'].to(device) 161 | 162 | # 预处理 163 | input_points, input_labels = preprocess(target, num_points=0, box=True, noise_ratio=0.1) 164 | 165 | # [B, num_prompt, per_prompt_mask, H, W] 166 | pred_masks, pred_ious = model(image, depth, input_points, input_labels) 167 | # 二值掩码 168 | pred_masks = (pred_masks > 0.).float() 169 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 170 | # [num_masks, B, H, W] 171 | pred_masks = pred_masks.view(b, -1, h, w).permute(1, 0, 2, 3) 172 | pred_ious = pred_ious.view(b, -1).permute(1, 0) 173 | 174 | # 计算iou 175 | best_id = torch.argmax(pred_ious, dim=0) 176 | best_mask = pred_masks[best_id, torch.arange(b)] 177 | best_iou = compute_iou(best_mask, target) 178 | totoal_num += len(best_iou) 179 | totoal_iou += sum(best_iou) 180 | 181 | mIou = float(totoal_iou / totoal_num) 182 | print('Epoch: {} mIou: {:.2}\n'.format(epoch + 1, mIou)) 183 | with open('log/logger.txt', 'a') as f: 184 | f.write('Epoch: {} mIou: {:.4} mean_loss: {}\n'.format(epoch + 1, mIou, mean_loss)) 185 | 186 | save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs, 187 | 0, num_train) 188 | 189 | print("Training completed ") 190 | 191 | 192 | if __name__ == '__main__': 193 | 194 | if not os.path.exists(args.ckpt_dir): 195 | os.mkdir(args.ckpt_dir) 196 | 197 | train() 198 | -------------------------------------------------------------------------------- /PlaneSAM/visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tifffile 4 | import time 5 | import numpy as np 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | from efficient_sam import build_efficient_sam_vitt 11 | # from raw_efficient_sam import build_efficient_sam_vitt 12 | from FasterRCNN.build_FasterRCNN import create_model 13 | from PlaneDataset import PlaneDataset 14 | from Nyuv2Dataset import Nyuv2Dataset, ToTensor 15 | # from S2D3DSDataset import S2d3dsDataset, ToTensor 16 | from torch.utils.data import DataLoader 17 | from utils.make_prompt import preprocess 18 | from utils.utils import load_ckpt 19 | from utils.visual_tools import map_masks_to_colors 20 | from utils.eval_tools import box_to_prompt, MatchSegmentation, evaluateMasks, match_boxes_gt 21 | from utils.box_ops import masks_to_boxes 22 | from PIL import Image 23 | from tqdm import tqdm 24 | 25 | 26 | parser = argparse.ArgumentParser(description='Segment Any Planes') 27 | parser.add_argument('--data-dir', default='mp3d-plane', metavar='DIR', 28 | help='path to dataset') 29 | parser.add_argument('--num-workers', default=8, type=int, metavar='N', 30 | help='number of data loading workers (default: 8)') 31 | parser.add_argument('-o', '--output', default='result', metavar='DIR', 32 | help='path to output') 33 | parser.add_argument('--cuda', action='store_true', default=True, 34 | help='enables CUDA training') 35 | parser.add_argument('--last-ckpt', default='./model/usepre_best.pth', type=str, metavar='PATH', 36 | help='path to latest checkpoint (default: none)') 37 | parser.add_argument('--detector-ckpt', default='FasterRCNN_weights/resNet101Fpn-model-9.pth', type=str, metavar='PATH', 38 | help='path to detector checkpoint (default: none)') 39 | 40 | args = parser.parse_args() 41 | device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu") 42 | 43 | 44 | def inference(): 45 | model = build_efficient_sam_vitt() 46 | detector = create_model(num_classes=2, 
load_pretrain_weights=False) 47 | detector.load_state_dict(torch.load(args.detector_ckpt)['model']) 48 | 49 | if torch.cuda.device_count() >= 1: 50 | print("Let's use", torch.cuda.device_count(), "GPUs!") 51 | model = nn.DataParallel(model) 52 | 53 | model.to(device) 54 | detector.to(device) 55 | 56 | # val_data = PlaneDataset(subset="train", 57 | # transform=transforms.Compose([ 58 | # transforms.ToTensor()]), 59 | # root_dir=args.data_dir 60 | # ) 61 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 62 | # pin_memory=False) 63 | 64 | val_data = Nyuv2Dataset(subset="val", 65 | transform=transforms.Compose([ 66 | ToTensor()]), 67 | datafolder=args.data_dir 68 | ) 69 | val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 70 | pin_memory=False) 71 | 72 | # val_data = S2d3dsDataset(subset="val", 73 | # transform=transforms.Compose([ 74 | # ToTensor()]), 75 | # datafolder=args.data_dir 76 | # ) 77 | # val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=args.num_workers, 78 | # pin_memory=False) 79 | 80 | if args.last_ckpt: 81 | load_ckpt(model, None, args.last_ckpt, device) 82 | else: 83 | print('no ckpt!') 84 | 85 | model.eval() 86 | detector.eval() 87 | 88 | total_time = 0 89 | 90 | with torch.no_grad(): 91 | for batch_idx, sample in enumerate(tqdm(val_loader)): 92 | num_planes = sample['num_planes'][0] 93 | image = sample['image'].to(device) 94 | target = sample['instance'].to(device) 95 | target = target[:, :num_planes, :, :].permute(1, 0, 2, 3) # [num_planes, B, H, W] 96 | gt_boxes = masks_to_boxes(target) 97 | depth = sample['depth'].to(device) 98 | 99 | # # 预测边界框 to prompt 100 | # outputs = detector(image)[0] 101 | # boxes = outputs['boxes'] 102 | # batch_points, batch_labels = match_boxes_gt(boxes, gt_boxes) 103 | 104 | # # use preprocess 105 | # batch_points, batch_labels = preprocess(target.flatten(0, 1), num_points=0, box=True) 106 | 107 | pred_masks = [] 108 | pred_ious = [] 109 | 110 | start_time = time.time() 111 | 112 | for input_points, input_labels in zip(batch_points, batch_labels): 113 | input_points, input_labels = input_points.unsqueeze(0), input_labels.unsqueeze(0) 114 | # [B, num_prompt, per_prompt_mask, H, W] 115 | pred_mask, pred_iou = model(image, depth, input_points, input_labels) 116 | pred_masks.append(pred_mask) 117 | pred_ious.append(pred_iou) 118 | 119 | end_time = time.time() 120 | total_time += end_time - start_time 121 | 122 | # 不增加新的维度 123 | pred_masks = torch.cat(pred_masks, dim=0) 124 | pred_ious = torch.cat(pred_ious, dim=0) 125 | # 二值掩码 126 | pred_masks = (pred_masks > 0.).float() 127 | b, num_prompt, per_prompt_mask, h, w = pred_masks.shape 128 | # [3, B, H, W] 129 | # [3, B] 130 | pred_masks = pred_masks.view(b, -1, h, w).permute(1, 0, 2, 3) 131 | pred_ious = pred_ious.view(b, -1).permute(1, 0) 132 | best_id = torch.argmax(pred_ious, dim=0) 133 | # [num_planes, H, W] 134 | best_mask = pred_masks[best_id, torch.arange(b)] 135 | 136 | # # 将pred与gt匹配 137 | # target = target.squeeze(1) 138 | # matching = MatchSegmentation(best_mask, target) 139 | # matched_pred_indices = [] 140 | # matched_gt_indices = [] 141 | # used = [] 142 | # for i, a in enumerate(matching): 143 | # if a not in used: 144 | # matched_pred_indices.append(i) 145 | # matched_gt_indices.append(a) 146 | # used.append(a) 147 | # matched_pred_indices = torch.as_tensor(matched_pred_indices) 148 | # matched_gt_indices = torch.as_tensor(matched_gt_indices) 149 | # prediction = 

            # # Match predictions to ground-truth instances (optional re-ordering step)
            # target = target.squeeze(1)
            # matching = MatchSegmentation(best_mask, target)
            # matched_pred_indices = []
            # matched_gt_indices = []
            # used = []
            # for i, a in enumerate(matching):
            #     if a not in used:
            #         matched_pred_indices.append(i)
            #         matched_gt_indices.append(a)
            #         used.append(a)
            # matched_pred_indices = torch.as_tensor(matched_pred_indices)
            # matched_gt_indices = torch.as_tensor(matched_gt_indices)
            # prediction = torch.zeros_like(target)
            # prediction[matched_gt_indices] = best_mask[matched_pred_indices]

            # RI, VoI, SC = evaluateMasks(best_mask, target)
            # if SC > 0.7:
            #     continue

            # Visualization
            prediction = best_mask.cpu().numpy().astype(np.uint8)
            # predicted masks rendered as a color image
            rgb_image = map_masks_to_colors(prediction)
            # input RGB image
            image = image.squeeze(0).cpu().numpy()
            image *= 255
            image = np.clip(image, 0, 255)
            image = image.astype(np.uint8).transpose(1, 2, 0)
            # ground truth
            target = target.squeeze(1)
            target = target.cpu().numpy().astype(np.uint8)
            gt = map_masks_to_colors(target)
            # depth
            depth = depth.squeeze(0).squeeze(0).cpu().numpy()
            depth = (depth * 255 / (depth.max())).astype(np.uint8)
            depth = cv2.applyColorMap(depth, cv2.COLORMAP_JET)
            # # Concatenate input / depth / prediction / ground truth side by side
            # rgbd_image = np.concatenate((image, depth, rgb_image, gt), axis=1)
            # rgbd_image = Image.fromarray(rgbd_image)
            # rgbd_image.save("raw_sam_result/all/" + str(batch_idx) + '.jpg')
            # print(f"save {batch_idx}.jpg")

            tifffile.imwrite("mp3d_result/input/input_" + str(batch_idx) + '.tif', image, resolution=(600, 600))
            tifffile.imwrite("mp3d_result/gt/gt_" + str(batch_idx) + '.tif', gt, resolution=(600, 600))
            tifffile.imwrite("mp3d_result/predict/predict_" + str(batch_idx) + '.tif', rgb_image, resolution=(600, 600))
            tifffile.imwrite("mp3d_result/depth/depth_" + str(batch_idx) + '.tif', depth, resolution=(600, 600))
            print(f"save {batch_idx}")

    print("total_time: {}".format(total_time))


if __name__ == '__main__':
    if not os.path.exists(args.output):
        os.mkdir(args.output)
    # The writers above target fixed sub-folders; create them so imwrite does not fail
    for sub in ('input', 'gt', 'predict', 'depth'):
        os.makedirs(os.path.join('mp3d_result', sub), exist_ok=True)
    inference()

--------------------------------------------------------------------------------
/PlaneSAM/backbone/feature_pyramid_network.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch.nn as nn
import torch
from torch import Tensor
import torch.nn.functional as F

from torch.jit.annotations import Tuple, List, Dict


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.
    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
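    Example (illustrative sketch; assumes a standard torchvision ResNet-18 whose
    top-level children include ``layer1`` ... ``layer4``)::

        >>> import torch, torchvision
        >>> m = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, exposing them as 'feat1' and 'feat2'
        >>> getter = IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = getter(torch.rand(1, 3, 224, 224))
        >>> [(k, v.shape) for k, v in out.items()]
        [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]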
27 | """ 28 | __annotations__ = { 29 | "return_layers": Dict[str, str], 30 | } 31 | 32 | def __init__(self, model, return_layers): 33 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 34 | raise ValueError("return_layers are not present in model") 35 | 36 | orig_return_layers = return_layers 37 | return_layers = {str(k): str(v) for k, v in return_layers.items()} 38 | layers = OrderedDict() 39 | 40 | # 遍历模型子模块按顺序存入有序字典 41 | # 只保存layer4及其之前的结构,舍去之后不用的结构 42 | for name, module in model.named_children(): 43 | layers[name] = module 44 | if name in return_layers: 45 | del return_layers[name] 46 | if not return_layers: 47 | break 48 | 49 | super().__init__(layers) 50 | self.return_layers = orig_return_layers 51 | 52 | def forward(self, x): 53 | out = OrderedDict() 54 | # 依次遍历模型的所有子模块,并进行正向传播, 55 | # 收集layer1, layer2, layer3, layer4的输出 56 | for name, module in self.items(): 57 | x = module(x) 58 | if name in self.return_layers: 59 | out_name = self.return_layers[name] 60 | out[out_name] = x 61 | return out 62 | 63 | 64 | class FeaturePyramidNetwork(nn.Module): 65 | """ 66 | Module that adds a FPN from on top of a set of feature maps. This is based on 67 | `"Feature Pyramid Network for Object Detection"