├── .gitattributes ├── .gitignore ├── DATASET.md ├── README.md ├── __init__.py ├── configs ├── semantic_sam_only_sa-1b_swinL.yaml ├── semantic_sam_only_sa-1b_swinT.yaml └── semantic_sam_reproduce_sam_swinL.yaml ├── datasets ├── __init__.py ├── build.py ├── dataset_mappers │ ├── __init__.py │ ├── coco_instance_new_baseline_dataset_mapper.py │ ├── coco_interactive_panoptic_new_baseline_dataset_mapper.py │ ├── coco_panoptic_new_baseline_dataset_mapper.py │ ├── dataset_mapper_filterbybox.py │ ├── imagenet_dataset_mapper.py │ ├── inference_mapper_with_gt.py │ ├── lvis_dataset_mapper.py │ ├── mask_former_instance_dataset_mapper.py │ ├── mask_former_interactive_panoptic_dataset_mapper.py │ ├── mask_former_panoptic_dataset_mapper.py │ ├── mask_former_semantic_dataset_mapper.py │ ├── o365_instance_new_baseline_dataset_mapper.py │ ├── part_data_filter_whole_new_instance_dataset_mapper.py │ ├── pascal_instance_new_baseline_dataset_mapper.py │ ├── sam_baseline_dataset_mapper.py │ └── sam_baseline_dataset_mapper_json.py ├── evaluation │ ├── __init__.py │ ├── instance_evaluation.py │ ├── interactive_evaluation.py │ ├── panoptic_evaluation.py │ ├── pascal_part_evaluation.py │ └── segmentation_evaluation.py ├── registration │ ├── __init__.py │ ├── register_ade20k_full.py │ ├── register_ade20k_instance.py │ ├── register_ade20k_panoptic.py │ ├── register_bdd100k_panoseg.py │ ├── register_bdd100k_semseg.py │ ├── register_coco_panoptic_annos_caption.py │ ├── register_coco_panoptic_annos_caption_grounding.py │ ├── register_coco_panoptic_annos_caption_grounding_interactive.py │ ├── register_coco_panoptic_annos_caption_interactive.py │ ├── register_coco_panoptic_annos_semseg.py │ ├── register_coco_panoptic_annos_semseg_interactive.py │ ├── register_coco_panoptic_annos_semseg_interactive_jointboxpoint.py │ ├── register_coco_stuff_10k.py │ ├── register_imagenet_cls.py │ ├── register_lvis_eval.py │ ├── register_object365_od.py │ ├── register_paco_part_all.py │ ├── register_partimagenet_part_all.py │ ├── register_pascal_part_all.py │ ├── register_pascal_part_all_interactive.py │ ├── register_refcoco_dataset.py │ ├── register_sam.py │ ├── register_sam_json.py │ ├── register_sam_json_val.py │ ├── register_sam_mnode.py │ ├── register_scannet_panoptic.py │ ├── register_scannet_semseg.py │ ├── register_sunrgbd_semseg.py │ └── register_vlp_datasets.py └── utils │ ├── __init__.py │ ├── semseg_loader.py │ └── tsv │ ├── __init__.py │ ├── io_common.py │ └── tsv_io.py ├── demo.py ├── demo_auto_generation.py ├── examples ├── 4.png ├── 5.png ├── castle.png ├── corgi1.webp ├── corgi2.jpg ├── dog.jpg ├── fries1.png ├── fries2.png ├── img.png ├── levels_dog.png ├── minecraft1.jpg ├── minecraft2.png ├── placeholder.png ├── ref_cat.jpeg ├── ref_vase.JPG ├── river1.mp4 ├── river1.png ├── river1.wav ├── river1_mask.png ├── river2.png ├── river2_mask.png ├── tank.png ├── truck.jpg ├── zebras1.jpg └── zebras2.jpg ├── pyproject.toml ├── requirements.txt ├── semantic_sam ├── BaseModel.py ├── __init__.py ├── architectures │ ├── __init__.py │ ├── build.py │ ├── interactive_mask_dino.py │ └── registry.py ├── backbone │ ├── __init__.py │ ├── backbone.py │ ├── build.py │ ├── focal.py │ ├── focal_dw.py │ ├── registry.py │ ├── swin.py │ └── swin_new.py ├── body │ ├── __init__.py │ ├── build.py │ ├── decoder │ │ ├── __init__.py │ │ ├── build.py │ │ ├── interactive_mask_dino.py │ │ ├── modules.py │ │ ├── registry.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── dino_decoder.py │ │ │ └── utils.py │ ├── encoder │ │ ├── __init__.py │ │ ├── build.py │ 
│ ├── encoder_deform.py │ │ ├── ops │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn_func.py │ │ │ ├── make.sh │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ ├── setup.py │ │ │ ├── src │ │ │ │ ├── cpu │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ ├── cuda │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ ├── ms_deform_attn.h │ │ │ │ └── vision.cpp │ │ │ └── test.py │ │ ├── registry.py │ │ └── transformer_encoder_fpn.py │ ├── general_head.py │ ├── registry.py │ └── transformer_blocks.py ├── build_semantic_sam.py ├── language │ ├── LangEncoder │ │ ├── __init__.py │ │ ├── build.py │ │ ├── registry.py │ │ └── transformer.py │ ├── __init__.py │ ├── build.py │ ├── encoder.py │ ├── registry.py │ └── vlpencoder.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── criterion_interactive_many_to_many.py │ ├── criterion_interactive_many_to_one.py │ ├── many2many_matcher.py │ ├── matcher.py │ ├── point_features.py │ ├── position_encoding.py │ └── postprocessing.py └── utils │ ├── __init__.py │ ├── box_ops.py │ ├── config.py │ └── misc.py ├── tasks ├── __init__.py ├── automatic_mask_generator.py ├── interactive_idino_m2m.py ├── interactive_idino_m2m_auto.py └── interactive_predictor.py ├── train_net.py └── utils ├── Config.py ├── __init__.py ├── arguments.py ├── constants.py ├── dist.py ├── distributed.py ├── misc.py ├── model.py ├── prompt_engineering.py ├── sam_utils ├── __init__.py ├── amg.py ├── onnx.py └── transforms.py └── visualizer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .gitignore 2 | .gitattributes 3 | __pycache__ 4 | */__pycache__ 5 | */*/__pycache__ 6 | */*/*/__pycache__ 7 | 8 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # Preparing Dataset 2 | Our dataloader follows [Detectron2](https://github.com/facebookresearch/detectron2) and contains (1) a dataset registrator, (2) a dataset mapper, and (3) a dataset loader. We modify the dataset registrator and mapper for different datasets. 3 | 4 | ## SA-1B Training 5 | Please follow [SAM](https://github.com/facebookresearch/segment-anything) to prepare your datasets. 6 | We recommend converting the SAM data into the TSV format for faster data loading. We also provide a [tsv loader](datasets/dataset_mappers/sam_baseline_dataset_mapper.py) for you.
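Each record in the TSV file written by the example below holds three tab-separated fields: the image file name, the JSON annotation string, and the base64-encoded image bytes. The companion index file stores the start offset and length of every record, so one data piece can be fetched with a single seek. The repo's own TSV utilities live under `datasets/utils/tsv/`; purely to illustrate the on-disk layout, here is a minimal read-back sketch (the paths are placeholders matching the write example, and it assumes the records are plain ASCII so the character counts written below equal byte offsets):

```python
import base64
import json

# Placeholder paths, matching the write example in the next subsection.
tsv_file = '/your/save/path/data.tsv'
index_file = '/your/save/path/data.index'

# Each index line is "<start offset> <length>" for one record of the tsv file.
with open(index_file) as f:
    offsets = [tuple(map(int, line.split())) for line in f]

start, length = offsets[0]  # fetch the first data piece
with open(tsv_file, 'rb') as f:
    f.seek(start)
    record = f.read(length).decode('utf-8')

# A record is "<image name>\t<json annotation>\t<base64 image>\n".
image_name, anno, img_b64 = record.rstrip('\n').split('\t')
ann = json.loads(anno)                 # SA-1B style annotation dict
img_bytes = base64.b64decode(img_b64)  # raw image bytes, ready for PIL/cv2 decoding
```
For training, use the provided [tsv loader](datasets/dataset_mappers/sam_baseline_dataset_mapper.py) rather than this sketch.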
7 | ### TSV data preparation 8 | 9 | ```python 10 | import json 11 | import base64 12 | 13 | tsv_file = '/your/save/path/data.tsv' 14 | index_file = '/your/save/path/data.index' 15 | f1 = open(tsv_file, 'w') 16 | f2 = open(index_file, 'w') 17 | """ 18 | Example code: write a single image and its json annotation to 19 | tsv_file: save the image and annotation (forms one data piece) 20 | index_file: save the tsv index of each data piece 21 | """ 22 | ann_start = 0 23 | json_file = '/your/sam/path/json' 24 | image_file = '/your/sam/path/image' 25 | ann = json.load(open(json_file)) 26 | anno = json.dumps(ann) 27 | img = open(image_file, 'rb').read() 28 | img = base64.b64encode(img).decode('utf-8') 29 | lent = 0 30 | # save image_file name 31 | length = f1.write("%s\t"%image_file) 32 | lent += length 33 | # save annotation 34 | length = f1.write("%s\t"%anno) 35 | lent += length 36 | # save image 37 | length = f1.write("%s\n"%img) 38 | lent += length 39 | f2.write("%d %d\n"%(ann_start, lent)) 40 | ann_start += lent 41 | ``` 42 | You can follow this example to convert the original SAM data into the TSV format for faster data loading. 43 | ### JSON file 44 | If you want to use the original JSON format of SA-1B, you can use [this mapper](datasets/dataset_mappers/sam_baseline_dataset_mapper_json.py) we provide. 45 | You can build an `image_list.da` file that indexes all the JSON and image files in a directory. Here is the example code. 46 | ```python 47 | import torch 48 | import os 49 | sam_path='/your/sam/path/sa_000000' 50 | save_path='/your/sam/path/sa_000000/image_list.da' 51 | f_save = open(save_path, 'wb') 52 | a=[] 53 | files = os.listdir(sam_path) 54 | for f in files: 55 | if f.split('.')[-1]=='jpg': 56 | a.append({'img_name': os.path.join(sam_path, f), 'ann_name': os.path.join(sam_path, f.split('.')[0]+'.json')}) 57 | torch.save(a, f_save) 58 | f_save.close() 59 | ``` 60 | ## COCO 61 | Please refer to [MaskDINO](https://github.com/IDEA-Research/MaskDINO/blob/main/README.md). 62 | 63 | 64 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/__init__.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import registration 2 | from .build import * -------------------------------------------------------------------------------- /datasets/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
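# Collects the dataset mappers used by the Semantic-SAM dataloaders: COCO (instance / panoptic / interactive), MaskFormer-style mappers, ImageNet, Objects365, SA-1B (TSV and JSON), part datasets, and an inference mapper with ground truth.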
2 | from .coco_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper 3 | from .coco_panoptic_new_baseline_dataset_mapper import COCOPanopticNewBaselineDatasetMapper 4 | from .coco_interactive_panoptic_new_baseline_dataset_mapper import COCOInteractivePanopticNewBaselineDatasetMapper 5 | from .mask_former_instance_dataset_mapper import MaskFormerInstanceDatasetMapper 6 | from .mask_former_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapper 7 | from .mask_former_interactive_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapperInteractive 8 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper 9 | from .imagenet_dataset_mapper import ImageNetDatasetMapper 10 | from .o365_instance_new_baseline_dataset_mapper import O365InstanceNewBaselineDatasetMapper 11 | from .sam_baseline_dataset_mapper import build_transform_gen as sam_transform_gen 12 | from .sam_baseline_dataset_mapper import SamBaselineDatasetMapper 13 | from .sam_baseline_dataset_mapper_json import SamBaselineDatasetMapperJSON 14 | from .dataset_mapper_filterbybox import DatasetMapperFilterByBox 15 | from .pascal_instance_new_baseline_dataset_mapper import PascalInstanceNewBaselineDatasetMapper 16 | from .part_data_filter_whole_new_instance_dataset_mapper import PartFilterWholeInstanceNewBaselineDatasetMapper 17 | from .inference_mapper_with_gt import CoCoInferenceDatasetMapper -------------------------------------------------------------------------------- /datasets/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 IDEA. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from Mask2Former https://github.com/facebookresearch/Mask2Former by Feng Li. 6 | import copy 7 | import logging 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from detectron2.config import configurable 13 | from detectron2.data import detection_utils as utils 14 | from detectron2.data import transforms as T 15 | from detectron2.data.transforms import TransformGen 16 | from detectron2.structures import BitMasks, Boxes, Instances 17 | 18 | __all__ = ["COCOPanopticNewBaselineDatasetMapper"] 19 | 20 | 21 | def build_transform_gen(cfg, is_train): 22 | """ 23 | Create a list of default :class:`Augmentation` from config. 24 | Now it includes resizing and flipping. 25 | Returns: 26 | list[Augmentation] 27 | """ 28 | assert is_train, "Only support training augmentation" 29 | image_size = cfg.INPUT.IMAGE_SIZE 30 | min_scale = cfg.INPUT.MIN_SCALE 31 | max_scale = cfg.INPUT.MAX_SCALE 32 | 33 | augmentation = [] 34 | 35 | if cfg.INPUT.RANDOM_FLIP != "none": 36 | augmentation.append( 37 | T.RandomFlip( 38 | horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal", 39 | vertical=cfg.INPUT.RANDOM_FLIP == "vertical", 40 | ) 41 | ) 42 | 43 | augmentation.extend([ 44 | T.ResizeScale( 45 | min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size 46 | ), 47 | T.FixedSizeCrop(crop_size=(image_size, image_size)), 48 | ]) 49 | 50 | return augmentation 51 | 52 | 53 | # This is specifically designed for the COCO dataset. 
54 | class COCOPanopticNewBaselineDatasetMapper: 55 | """ 56 | A callable which takes a dataset dict in Detectron2 Dataset format, 57 | and map it into a format used by MaskFormer. 58 | 59 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 60 | 61 | The callable currently does the following: 62 | 63 | 1. Read the image from "file_name" 64 | 2. Applies geometric transforms to the image and annotation 65 | 3. Find and applies suitable cropping to the image and annotation 66 | 4. Prepare image and annotation to Tensors 67 | """ 68 | 69 | @configurable 70 | def __init__( 71 | self, 72 | is_train=True, 73 | *, 74 | tfm_gens, 75 | image_format, 76 | ): 77 | """ 78 | NOTE: this interface is experimental. 79 | Args: 80 | is_train: for training or inference 81 | augmentations: a list of augmentations or deterministic transforms to apply 82 | crop_gen: crop augmentation 83 | tfm_gens: data augmentation 84 | image_format: an image format supported by :func:`detection_utils.read_image`. 85 | """ 86 | self.tfm_gens = tfm_gens 87 | logging.getLogger(__name__).info( 88 | "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format( 89 | str(self.tfm_gens) 90 | ) 91 | ) 92 | 93 | self.img_format = image_format 94 | self.is_train = is_train 95 | 96 | @classmethod 97 | def from_config(cls, cfg, is_train=True): 98 | # Build augmentation 99 | tfm_gens = build_transform_gen(cfg, is_train) 100 | 101 | ret = { 102 | "is_train": is_train, 103 | "tfm_gens": tfm_gens, 104 | "image_format": cfg.INPUT.FORMAT, 105 | } 106 | return ret 107 | 108 | def __call__(self, dataset_dict): 109 | """ 110 | Args: 111 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 112 | 113 | Returns: 114 | dict: a format that builtin models in detectron2 accept 115 | """ 116 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 117 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 118 | utils.check_image_size(dataset_dict, image) 119 | 120 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 121 | image_shape = image.shape[:2] # h, w 122 | 123 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 124 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 125 | # Therefore it's important to use torch.Tensor. 126 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 127 | 128 | if not self.is_train: 129 | # USER: Modify this if you want to keep them for some reason. 
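# At inference time the mapper only needs the transformed image tensor, so ground-truth annotations are dropped here.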
130 | dataset_dict.pop("annotations", None) 131 | return dataset_dict 132 | 133 | if "pan_seg_file_name" in dataset_dict: 134 | pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") 135 | segments_info = dataset_dict["segments_info"] 136 | 137 | # apply the same transformation to panoptic segmentation 138 | pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) 139 | 140 | from panopticapi.utils import rgb2id 141 | 142 | pan_seg_gt = rgb2id(pan_seg_gt) 143 | 144 | instances = Instances(image_shape) 145 | classes = [] 146 | masks = [] 147 | for segment_info in segments_info: 148 | class_id = segment_info["category_id"] 149 | if not segment_info["iscrowd"]: 150 | classes.append(class_id) 151 | masks.append(pan_seg_gt == segment_info["id"]) 152 | 153 | classes = np.array(classes) 154 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 155 | if len(masks) == 0: 156 | # Some image does not have annotation (all ignored) 157 | instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) 158 | instances.gt_boxes = Boxes(torch.zeros((0, 4))) 159 | else: 160 | masks = BitMasks( 161 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 162 | ) 163 | instances.gt_masks = masks.tensor 164 | instances.gt_boxes = masks.get_bounding_boxes() 165 | 166 | dataset_dict["instances"] = instances 167 | 168 | return dataset_dict 169 | -------------------------------------------------------------------------------- /datasets/dataset_mappers/dataset_mapper_filterbybox.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | import numpy as np 5 | from typing import List, Optional, Union 6 | import torch 7 | 8 | from detectron2.config import configurable 9 | from detectron2.data import detection_utils as utils 10 | from detectron2.data import transforms as T 11 | from detectron2.data.dataset_mapper import DatasetMapper 12 | 13 | def filter_empty_instances_by_box( 14 | instances, by_box=True, by_mask=False, box_threshold=1e-5, return_mask=False 15 | ): 16 | assert by_box or by_mask 17 | r = [] 18 | if by_box: 19 | r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) 20 | if instances.has("gt_masks") and by_mask: 21 | r.append(instances.gt_masks.nonempty()) 22 | 23 | # TODO: can also filter visible keypoints 24 | 25 | if not r: 26 | return instances 27 | m = r[0] 28 | for x in r[1:]: 29 | m = m & x 30 | if return_mask: 31 | return instances[m], m 32 | return instances[m] 33 | 34 | 35 | class DatasetMapperFilterByBox(DatasetMapper): 36 | def _transform_annotations(self, dataset_dict, transforms, image_shape): 37 | # USER: Modify this if you want to keep them for some reason. 
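# Drop annotation fields (masks / keypoints) that the model is not configured to use before converting the remaining annotations to Instances.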
38 | for anno in dataset_dict["annotations"]: 39 | if not self.use_instance_mask: 40 | anno.pop("segmentation", None) 41 | if not self.use_keypoint: 42 | anno.pop("keypoints", None) 43 | 44 | # USER: Implement additional transformations if you have other types of data 45 | annos = [ 46 | utils.transform_instance_annotations( 47 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices 48 | ) 49 | for obj in dataset_dict.pop("annotations") 50 | if obj.get("iscrowd", 0) == 0 51 | ] 52 | instances = utils.annotations_to_instances( 53 | annos, image_shape, mask_format=self.instance_mask_format 54 | ) 55 | 56 | # After transforms such as cropping are applied, the bounding box may no longer 57 | # tightly bound the object. As an example, imagine a triangle object 58 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight 59 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to 60 | # the intersection of original bounding box and the cropping box. 61 | instances.gt_masks = instances.gt_masks.tensor 62 | if self.recompute_boxes: 63 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 64 | dataset_dict["instances"] = filter_empty_instances_by_box(instances) 65 | -------------------------------------------------------------------------------- /datasets/dataset_mappers/imagenet_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | from PIL import Image 10 | # import logging 11 | 12 | import cv2 13 | import numpy as np 14 | 15 | import torch 16 | from torchvision import transforms 17 | 18 | from semantic_sam.utils import configurable 19 | 20 | __all__ = ["ImageNetDatasetMapper"] 21 | 22 | 23 | # This is specifically designed for the COCO dataset. 24 | class ImageNetDatasetMapper: 25 | """ 26 | A callable which takes a dataset dict in Detectron2 Dataset format, 27 | and map it into a format used by MaskFormer. 28 | 29 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 30 | 31 | The callable currently does the following: 32 | 33 | 1. Read the image from "file_name" 34 | 2. Applies geometric transforms to the image and annotation 35 | 3. Find and applies suitable cropping to the image and annotation 36 | 4. Prepare image and annotation to Tensors 37 | """ 38 | 39 | @configurable 40 | def __init__( 41 | self, 42 | is_train=True, 43 | size_train=None, 44 | size_test=None, 45 | size_crop=None, 46 | ): 47 | """ 48 | NOTE: this interface is experimental. 49 | Args: 50 | is_train: for training or inference 51 | augmentations: a list of augmentations or deterministic transforms to apply 52 | tfm_gens: data augmentation 53 | image_format: an image format supported by :func:`detection_utils.read_image`. 
54 | """ 55 | self.is_train = is_train 56 | self.size_train = size_train 57 | self.size_test = size_test 58 | self.size_crop = size_crop 59 | 60 | t = [] 61 | t.append(transforms.Resize(size_crop, interpolation=Image.BICUBIC)) 62 | t.append(transforms.CenterCrop(size_test)) 63 | self.transform = transforms.Compose(t) 64 | 65 | @classmethod 66 | def from_config(cls, cfg, is_train=True): 67 | ret = { 68 | "is_train": is_train, 69 | "size_train": cfg['INPUT']['SIZE_TRAIN'], 70 | "size_test": cfg['INPUT']['SIZE_TEST'], 71 | "size_crop": cfg['INPUT']['SIZE_CROP'] 72 | } 73 | return ret 74 | 75 | def __call__(self, dataset_dict): 76 | """ 77 | Args: 78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 79 | 80 | Returns: 81 | dict: a format that builtin models in detectron2 accept 82 | """ 83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 84 | file_name = dataset_dict['file_name'] 85 | image = Image.open(file_name).convert('RGB') 86 | 87 | if self.is_train == False: 88 | image = self.transform(image) 89 | image = torch.from_numpy(np.asarray(image).copy()) 90 | image = image.permute(2,0,1) 91 | 92 | dataset_dict['image'] = image 93 | dataset_dict['height'] = image.shape[1] 94 | dataset_dict['width'] = image.shape[2] 95 | return dataset_dict -------------------------------------------------------------------------------- /datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .instance_evaluation import * 2 | from .segmentation_evaluation import * 3 | from .panoptic_evaluation import * 4 | from .pascal_part_evaluation import * 5 | from .interactive_evaluation import * -------------------------------------------------------------------------------- /datasets/evaluation/instance_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import contextlib 3 | import copy 4 | import io 5 | import itertools 6 | import json 7 | import logging 8 | import numpy as np 9 | import os 10 | import pickle 11 | from collections import OrderedDict 12 | import pycocotools.mask as mask_util 13 | import torch 14 | from pycocotools.coco import COCO 15 | from pycocotools.cocoeval import COCOeval 16 | from tabulate import tabulate 17 | 18 | import detectron2.utils.comm as comm 19 | from detectron2.config import CfgNode 20 | from detectron2.data import MetadataCatalog 21 | from detectron2.data.datasets.coco import convert_to_coco_json 22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco 23 | from detectron2.evaluation.fast_eval_api import COCOeval_opt 24 | from detectron2.structures import Boxes, BoxMode, pairwise_iou 25 | from detectron2.utils.file_io import PathManager 26 | from detectron2.utils.logger import create_small_table 27 | 28 | 29 | # modified from COCOEvaluator for instance segmetnat 30 | class InstanceSegEvaluator(COCOEvaluator): 31 | """ 32 | Evaluate AR for object proposals, AP for instance detection/segmentation, AP 33 | for keypoint detection outputs using COCO's metrics. 34 | See http://cocodataset.org/#detection-eval and 35 | http://cocodataset.org/#keypoints-eval to understand its metrics. 36 | The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means 37 | the metric cannot be computed (e.g. due to no predictions made). 
38 | 39 | In addition to COCO, this evaluator is able to support any bounding box detection, 40 | instance segmentation, or keypoint detection dataset. 41 | """ 42 | 43 | def _eval_predictions(self, predictions, img_ids=None): 44 | """ 45 | Evaluate predictions. Fill self._results with the metrics of the tasks. 46 | """ 47 | self._logger.info("Preparing results for COCO format ...") 48 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) 49 | tasks = self._tasks or self._tasks_from_predictions(coco_results) 50 | 51 | # unmap the category ids for COCO 52 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): 53 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id 54 | # all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) 55 | # num_classes = len(all_contiguous_ids) 56 | # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 57 | 58 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} 59 | for result in coco_results: 60 | category_id = result["category_id"] 61 | # assert category_id < num_classes, ( 62 | # f"A prediction has class={category_id}, " 63 | # f"but the dataset only has {num_classes} classes and " 64 | # f"predicted class id should be in [0, {num_classes - 1}]." 65 | # ) 66 | assert category_id in reverse_id_mapping, ( 67 | f"A prediction has class={category_id}, " 68 | f"but the dataset only has class ids in {dataset_id_to_contiguous_id}." 69 | ) 70 | result["category_id"] = reverse_id_mapping[category_id] 71 | 72 | if self._output_dir: 73 | file_path = os.path.join(self._output_dir, "coco_instances_results.json") 74 | self._logger.info("Saving results to {}".format(file_path)) 75 | with PathManager.open(file_path, "w") as f: 76 | f.write(json.dumps(coco_results)) 77 | f.flush() 78 | 79 | if not self._do_evaluation: 80 | self._logger.info("Annotations are not available for evaluation.") 81 | return 82 | 83 | self._logger.info( 84 | "Evaluating predictions with {} COCO API...".format( 85 | "unofficial" if self._use_fast_impl else "official" 86 | ) 87 | ) 88 | for task in sorted(tasks): 89 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" 90 | coco_eval = ( 91 | _evaluate_predictions_on_coco( 92 | self._coco_api, 93 | coco_results, 94 | task, 95 | kpt_oks_sigmas=self._kpt_oks_sigmas, 96 | use_fast_impl=self._use_fast_impl, 97 | img_ids=img_ids, 98 | max_dets_per_image=self._max_dets_per_image, 99 | ) 100 | if len(coco_results) > 0 101 | else None # cocoapi does not handle empty results very well 102 | ) 103 | 104 | res = self._derive_coco_results( 105 | coco_eval, task, class_names=self._metadata.get("thing_classes") 106 | ) 107 | self._results[task] = res 108 | -------------------------------------------------------------------------------- /datasets/evaluation/pascal_part_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
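# COCO-style evaluator for open-vocabulary Pascal-Part segmentation: besides the standard COCO metrics, it reports AP/AP50 averaged separately over the seen (base) and unseen (novel) categories defined in register_pascal_part_all.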
2 | import contextlib 3 | import copy 4 | import io 5 | import itertools 6 | import json 7 | import logging 8 | import numpy as np 9 | import os 10 | import pickle 11 | from collections import OrderedDict 12 | import pycocotools.mask as mask_util 13 | import torch 14 | from pycocotools.coco import COCO 15 | from pycocotools.cocoeval import COCOeval 16 | from tabulate import tabulate 17 | 18 | import detectron2.utils.comm as comm 19 | from detectron2.config import CfgNode 20 | from detectron2.data import MetadataCatalog 21 | from detectron2.data.datasets.coco import convert_to_coco_json 22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator 23 | from detectron2.structures import Boxes, BoxMode, pairwise_iou 24 | from detectron2.utils.file_io import PathManager 25 | from detectron2.utils.logger import create_small_table 26 | from ..registration.register_pascal_part_all import ( 27 | PASCAL_PART_BASE_CATEGORIES as categories_seen, 28 | PASCAL_PART_NOVEL_CATEGORIES as categories_unseen, 29 | ) 30 | 31 | 32 | class PASCALPARTEvaluator(COCOEvaluator): 33 | """ 34 | PASCALPARTEvaluator on open_vocabulary 35 | """ 36 | 37 | def _derive_coco_results(self, coco_eval, iou_type, class_names=None): 38 | """ 39 | Additionally plot mAP for 'seen classes' and 'unseen classes' 40 | """ 41 | 42 | metrics = { 43 | "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"], 44 | "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"], 45 | "keypoints": ["AP", "AP50", "AP75", "APm", "APl"], 46 | }[iou_type] 47 | 48 | if coco_eval is None: 49 | self._logger.warn("No predictions from the model!") 50 | return {metric: float("nan") for metric in metrics} 51 | 52 | # the standard metrics 53 | results = { 54 | metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan") 55 | for idx, metric in enumerate(metrics) 56 | } 57 | self._logger.info( 58 | "Evaluation results for {}: \n".format(iou_type) + create_small_table(results) 59 | ) 60 | if not np.isfinite(sum(results.values())): 61 | self._logger.info("Some metrics cannot be computed and is shown as NaN.") 62 | 63 | if class_names is None or len(class_names) <= 1: 64 | return results 65 | # Compute per-category AP 66 | # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa 67 | precisions = coco_eval.eval["precision"] 68 | # precision has dims (iou, recall, cls, area range, max dets) 69 | assert len(class_names) == precisions.shape[2] 70 | 71 | seen_names = set([x['name'] for x in categories_seen]) 72 | unseen_names = set([x['name'] for x in categories_unseen]) 73 | results_per_category = [] 74 | results_per_category50 = [] 75 | results_per_category_seen = [] 76 | results_per_category_unseen = [] 77 | results_per_category50_seen = [] 78 | results_per_category50_unseen = [] 79 | for idx, name in enumerate(class_names): 80 | # area range index 0: all area ranges 81 | # max dets index -1: typically 100 per image 82 | precision = precisions[:, :, idx, 0, -1] 83 | precision = precision[precision > -1] 84 | ap = np.mean(precision) if precision.size else float("nan") 85 | results_per_category.append(("{}".format(name), float(ap * 100))) 86 | precision50 = precisions[0, :, idx, 0, -1] 87 | precision50 = precision50[precision50 > -1] 88 | ap50 = np.mean(precision50) if precision50.size else float("nan") 89 | results_per_category50.append(("{}".format(name), float(ap50 * 100))) 90 | if name in seen_names: 91 | 
results_per_category_seen.append(float(ap * 100)) 92 | results_per_category50_seen.append(float(ap50 * 100)) 93 | if name in unseen_names: 94 | results_per_category_unseen.append(float(ap * 100)) 95 | results_per_category50_unseen.append(float(ap50 * 100)) 96 | 97 | # tabulate it 98 | N_COLS = min(6, len(results_per_category) * 2) 99 | results_flatten = list(itertools.chain(*results_per_category)) 100 | results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)]) 101 | table = tabulate( 102 | results_2d, 103 | tablefmt="pipe", 104 | floatfmt=".3f", 105 | headers=["category", "AP"] * (N_COLS // 2), 106 | numalign="left", 107 | ) 108 | self._logger.info("Per-category {} AP: \n".format(iou_type) + table) 109 | 110 | N_COLS = min(6, len(results_per_category50) * 2) 111 | results_flatten = list(itertools.chain(*results_per_category50)) 112 | results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)]) 113 | table = tabulate( 114 | results_2d, 115 | tablefmt="pipe", 116 | floatfmt=".3f", 117 | headers=["category", "AP50"] * (N_COLS // 2), 118 | numalign="left", 119 | ) 120 | self._logger.info("Per-category {} AP50: \n".format(iou_type) + table) 121 | 122 | self._logger.info( 123 | "Seen {} AP: {}".format( 124 | iou_type, 125 | sum(results_per_category_seen) / len(results_per_category_seen), 126 | )) 127 | self._logger.info( 128 | "Unseen {} AP: {}".format( 129 | iou_type, 130 | sum(results_per_category_unseen) / len(results_per_category_unseen), 131 | )) 132 | 133 | self._logger.info( 134 | "Seen {} AP50: {}".format( 135 | iou_type, 136 | sum(results_per_category50_seen) / len(results_per_category50_seen), 137 | )) 138 | self._logger.info( 139 | "Unseen {} AP50: {}".format( 140 | iou_type, 141 | sum(results_per_category50_unseen) / len(results_per_category50_unseen), 142 | )) 143 | 144 | results.update({"AP-" + name: ap for name, ap in results_per_category}) 145 | results["AP-seen"] = sum(results_per_category_seen) / len(results_per_category_seen) 146 | results["AP-unseen"] = sum(results_per_category_unseen) / len(results_per_category_unseen) 147 | results["AP50-seen"] = sum(results_per_category50_seen) / len(results_per_category50_seen) 148 | results["AP50-unseen"] = sum(results_per_category50_unseen) / len(results_per_category50_unseen) 149 | return results -------------------------------------------------------------------------------- /datasets/registration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Micorsoft, Inc. and its affiliates. 2 | from . import ( 3 | register_ade20k_full, 4 | register_ade20k_panoptic, 5 | register_coco_stuff_10k, 6 | register_coco_panoptic_annos_semseg, 7 | register_coco_panoptic_annos_semseg_interactive, 8 | register_coco_panoptic_annos_semseg_interactive_jointboxpoint, 9 | register_ade20k_instance, 10 | # register_object365_od, 11 | # register_sam, 12 | register_sam_mnode, 13 | register_sam_json_val, 14 | register_pascal_part_all, 15 | register_pascal_part_all_interactive, 16 | register_paco_part_all, 17 | register_partimagenet_part_all, 18 | ) 19 | -------------------------------------------------------------------------------- /datasets/registration/register_ade20k_instance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
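# Registers the ADE20K instance-segmentation splits (COCO-format json under ADEChallengeData2016/); the 100 ADE thing-category ids are remapped to contiguous ids in [0, 99].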
2 | import json 3 | import logging 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | 8 | from detectron2.data import DatasetCatalog, MetadataCatalog 9 | from detectron2.data.datasets.coco import load_coco_json, register_coco_instances 10 | from detectron2.utils.file_io import PathManager 11 | 12 | ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}] 13 | 14 | 15 | _PREDEFINED_SPLITS = { 16 | # point annotations without masks 17 | "ade20k_instance_train": ( 18 
| "ADEChallengeData2016/images/training", 19 | "ADEChallengeData2016/ade20k_instance_train.json", 20 | ), 21 | "ade20k_instance_val": ( 22 | "ADEChallengeData2016/images/validation", 23 | "ADEChallengeData2016/ade20k_instance_val.json", 24 | ), 25 | } 26 | 27 | 28 | def _get_ade_instances_meta(): 29 | thing_ids = [k["id"] for k in ADE_CATEGORIES] 30 | assert len(thing_ids) == 100, len(thing_ids) 31 | # Mapping from the incontiguous ADE category id to an id in [0, 99] 32 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 33 | thing_classes = [k["name"] for k in ADE_CATEGORIES] 34 | ret = { 35 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 36 | "thing_classes": thing_classes, 37 | } 38 | return ret 39 | 40 | 41 | def register_all_ade20k_instance(root): 42 | for key, (image_root, json_file) in _PREDEFINED_SPLITS.items(): 43 | # Assume pre-defined datasets live in `./datasets`. 44 | register_coco_instances( 45 | key, 46 | _get_ade_instances_meta(), 47 | os.path.join(root, json_file) if "://" not in json_file else json_file, 48 | os.path.join(root, image_root), 49 | ) 50 | 51 | 52 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 53 | register_all_ade20k_instance(_root) 54 | -------------------------------------------------------------------------------- /datasets/registration/register_bdd100k_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import numpy as np 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import BDD_SEM 17 | 18 | __all__ = ["load_scannet_instances", "register_scannet_context"] 19 | 20 | 21 | def load_bdd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load BDD annotations to Detectron2 format. 
24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | img_folder = os.path.join(dirname, 'images', '10k', split) 31 | img_pths = sorted(glob.glob(os.path.join(img_folder, '*.jpg'))) 32 | 33 | sem_folder = os.path.join(dirname, 'labels', 'sem_seg', 'masks', split) 34 | sem_pths = sorted(glob.glob(os.path.join(sem_folder, '*.png'))) 35 | 36 | assert len(img_pths) == len(sem_pths) 37 | 38 | dicts = [] 39 | for img_pth, sem_pth in zip(img_pths, sem_pths): 40 | r = { 41 | "file_name": img_pth, 42 | "sem_seg_file_name": sem_pth, 43 | "image_id": img_pth.split('/')[-1].split('.')[0], 44 | } 45 | dicts.append(r) 46 | return dicts 47 | 48 | 49 | def register_bdd_context(name, dirname, split, class_names=BDD_SEM): 50 | DatasetCatalog.register(name, lambda: load_bdd_instances(name, dirname, split, class_names)) 51 | MetadataCatalog.get(name).set( 52 | stuff_classes=class_names, 53 | dirname=dirname, 54 | split=split, 55 | ignore_label=[255], 56 | thing_dataset_id_to_contiguous_id={}, 57 | class_offset=0, 58 | keep_sem_bgd=False 59 | ) 60 | 61 | 62 | def register_all_sunrgbd_seg(root): 63 | SPLITS = [ 64 | ("bdd10k_val_sem_seg", "bdd100k", "val"), 65 | ] 66 | 67 | for name, dirname, split in SPLITS: 68 | register_bdd_context(name, os.path.join(root, dirname), split) 69 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 70 | 71 | 72 | _root = os.getenv("DATASET", "datasets") 73 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_imagenet_cls.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.structures import BoxMode 15 | from detectron2.utils.file_io import PathManager 16 | 17 | from utils.constants import IMAGENET_CLASSES, IMAGENET_FOLDER_NAMES 18 | 19 | __all__ = ["load_imagenet_images", "register_imagenet"] 20 | 21 | 22 | def load_imagenet_images(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 23 | """ 24 | Load ImageNet annotations to Detectron2 format. 
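Images are discovered under <dirname>/<split>/n*/*.JPEG and each sample's class id/name is derived from its synset folder name.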
25 | 26 | Args: 27 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 28 | split (str): one of "train", "test", "val", "trainval" 29 | class_names: list or tuple of class names 30 | """ 31 | image_folders = sorted(glob.glob(os.path.join(dirname, split, 'n*'))) 32 | 33 | dicts = [] 34 | for image_folder in image_folders: 35 | folder_name = image_folder.split('/')[-1] 36 | image_pths = sorted(glob.glob(os.path.join(image_folder, "*.JPEG"))) 37 | for img_pth in image_pths: 38 | r = { 39 | "file_name": img_pth, 40 | "class_name": IMAGENET_CLASSES[IMAGENET_FOLDER_NAMES.index(folder_name)], 41 | "class_id": IMAGENET_FOLDER_NAMES.index(folder_name), 42 | } 43 | dicts.append(r) 44 | return dicts 45 | 46 | 47 | def register_imagenet(name, dirname, split, year, class_names=IMAGENET_CLASSES): 48 | DatasetCatalog.register(name, lambda: load_imagenet_images(dirname, split, class_names)) 49 | MetadataCatalog.get(name).set( 50 | thing_classes=list(class_names), dirname=dirname, year=year, split=split 51 | ) 52 | 53 | 54 | def register_all_imagenet(root): 55 | SPLITS = [ 56 | ("imagenet_val", "imagenet", "val", "2012"), 57 | ] 58 | for name, dirname, split, year in SPLITS: 59 | register_imagenet(name, os.path.join(root, dirname), split, year) 60 | MetadataCatalog.get(name).evaluator_type = "classification" 61 | 62 | 63 | _root = os.getenv("DATASET", "datasets") 64 | register_all_imagenet(_root) -------------------------------------------------------------------------------- /datasets/registration/register_lvis_eval.py: -------------------------------------------------------------------------------- 1 | from detectron2.data.datasets import get_lvis_instances_meta, register_lvis_instances 2 | from detectron2.data import DatasetCatalog, MetadataCatalog 3 | from xy_utils.lvis_cat import LVIS_CATEGORIES as LVIS_V1_CATEGORIES 4 | import logging 5 | import os 6 | from detectron2.utils.file_io import PathManager 7 | from fvcore.common.timer import Timer 8 | import json 9 | 10 | 11 | 12 | _PREDEFINED_SPLITS_LVIS = { 13 | "lvis_v1": { 14 | "lvis_v1_minival": ("coco/", "coco/annotations/lvis_v1_minival_inserted_image_name.json"), 15 | # "lvis_v1_train": ("coco/", "lvis/lvis_v1_train.json"), 16 | # "lvis_v1_val": ("coco/", "lvis/lvis_v1_val.json"), 17 | # "lvis_v1_test_dev": ("coco/", "lvis/lvis_v1_image_info_test_dev.json"), 18 | # "lvis_v1_test_challenge": ("coco/", "lvis/lvis_v1_image_info_test_challenge.json"), 19 | }, 20 | # "lvis_v0.5": { 21 | # "lvis_v0.5_train": ("coco/", "lvis/lvis_v0.5_train.json"), 22 | # "lvis_v0.5_val": ("coco/", "lvis/lvis_v0.5_val.json"), 23 | # "lvis_v0.5_val_rand_100": ("coco/", "lvis/lvis_v0.5_val_rand_100.json"), 24 | # "lvis_v0.5_test": ("coco/", "lvis/lvis_v0.5_image_info_test.json"), 25 | # }, 26 | # "lvis_v0.5_cocofied": { 27 | # "lvis_v0.5_train_cocofied": ("coco/", "lvis/lvis_v0.5_train_cocofied.json"), 28 | # "lvis_v0.5_val_cocofied": ("coco/", "lvis/lvis_v0.5_val_cocofied.json"), 29 | # }, 30 | } 31 | 32 | def get_lvis_instances_meta_v1(): 33 | assert len(LVIS_V1_CATEGORIES) == 1203 34 | cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES] 35 | assert min(cat_ids) == 1 and max(cat_ids) == len( 36 | cat_ids 37 | ), "Category ids are not in [1, #categories], as expected" 38 | # Ensure that the category list is sorted by id 39 | thing_ids = [k["id"] for k in LVIS_V1_CATEGORIES] 40 | # lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"]) 41 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 42 | # thing_classes = [k["name"] for k in 
O365_CATEGORIES] 43 | def preprocess_name(name): 44 | name = name.lower().strip() 45 | name = name.replace('_', ' ') 46 | return name 47 | thing_classes = [preprocess_name(k["synonyms"][0]) for k in LVIS_V1_CATEGORIES] 48 | meta = { 49 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 50 | "thing_classes": thing_classes, 51 | } 52 | return meta 53 | 54 | 55 | def register_lvis_instances(name, metadata, json_file, image_root): 56 | """ 57 | Register a dataset in LVIS's json annotation format for instance detection and segmentation. 58 | 59 | Args: 60 | name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train". 61 | metadata (dict): extra metadata associated with this dataset. It can be an empty dict. 62 | json_file (str): path to the json instance annotation file. 63 | image_root (str or path-like): directory which contains all the images. 64 | """ 65 | DatasetCatalog.register(name, lambda: load_lvis_json(image_root, json_file, name)) 66 | MetadataCatalog.get(name).set( 67 | json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata 68 | ) 69 | 70 | 71 | def load_lvis_json(image_root, annot_json, metadata): 72 | """ 73 | Args: 74 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 75 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". 76 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". 77 | Returns: 78 | list[dict]: a list of dicts in Detectron2 standard format. (See 79 | `Using Custom Datasets `_ ) 80 | """ 81 | with PathManager.open(annot_json) as f: 82 | json_info = json.load(f) 83 | 84 | imageid2seg = {} 85 | imageid2box = {} 86 | imageid2lable = {} 87 | for anno in json_info["annotations"]: 88 | image_id = anno['image_id'] 89 | seg = anno["segmentation"] 90 | bbox = anno["bbox"] 91 | label = anno["category_id"] 92 | if image_id not in imageid2seg: 93 | imageid2seg[image_id] = [] 94 | if image_id not in imageid2box: 95 | imageid2box[image_id] = [] 96 | if image_id not in imageid2lable: 97 | imageid2lable[image_id] = [] 98 | imageid2seg[image_id] += [seg] 99 | imageid2box[image_id] += [bbox] 100 | imageid2lable[image_id] += [label] 101 | 102 | ret = [] 103 | cnt_empty = 0 104 | for image in json_info["images"]: 105 | image_file = os.path.join(image_root ,'/'.join(image["coco_url"].split('/')[-2:])) 106 | image_id = image['id'] 107 | if image_id not in imageid2lable: 108 | cnt_empty += 1 109 | continue 110 | ret.append( 111 | { 112 | "file_name": image_file, 113 | "image_id": image_id, 114 | "height": image['height'], 115 | "width": image['width'], 116 | "instance": imageid2seg[image_id], 117 | "box": imageid2box[image_id], 118 | "labels": imageid2lable[image_id], 119 | } 120 | ) 121 | 122 | print("Empty annotations: {}".format(cnt_empty)) 123 | assert len(ret), f"No images found in {image_root}!" 
124 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] 125 | return ret 126 | 127 | 128 | def register_all_lvis(root): 129 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items(): 130 | for key, (image_root, json_file) in splits_per_dataset.items(): 131 | register_lvis_instances( 132 | key, 133 | get_lvis_instances_meta_v1(), 134 | os.path.join(root, json_file) if "://" not in json_file else json_file, 135 | os.path.join(root, image_root), 136 | ) 137 | 138 | 139 | _root = os.getenv("DATASET3", "datasets") 140 | register_all_lvis(_root) -------------------------------------------------------------------------------- /datasets/registration/register_paco_part_all.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import os 4 | from detectron2.data import DatasetCatalog, MetadataCatalog 5 | import copy 6 | # from detectron2.data.datasets.register_coco import register_coco_instances 7 | from detectron2.data.datasets.coco import load_coco_json 8 | import json 9 | _root = os.getenv("PACO", "datasets") 10 | json_name = 'os.path.join(_root,"paco/annotations/paco_lvis_v1_val.json")' 11 | if os.path.exists(json_name): 12 | with open(os.path.join(_root,"paco/annotations/paco_lvis_v1_val.json")) as f: 13 | j=json.load(f) 14 | PACO_CATEGORIES=j['categories'] 15 | 16 | 17 | def _get_paco_metadata(key): 18 | # if '_base' in key: 19 | # id_to_name = {x['id']: x['name'] for x in PASCAL_PART_BASE_CATEGORIES} 20 | # else: 21 | id_to_name = {x['id']: x['name'] for x in PACO_CATEGORIES} 22 | 23 | thing_classes_ = [id_to_name[k] for k in sorted(id_to_name)] 24 | PACO_CATEGORIES_=copy.deepcopy(PACO_CATEGORIES) 25 | for cat in PACO_CATEGORIES_: 26 | if ':' not in cat['name']: 27 | cat['name']=cat['name']+':whole' 28 | if '_(' in cat['name']: 29 | cat['name']=cat['name'][:cat['name'].find('_(')]+cat['name'][cat['name'].find(')')+1:] 30 | if '_' in cat['name']: 31 | cat['name']=cat['name'].replace('_',' ') 32 | id_to_name = {x['id']: x['name'] for x in PACO_CATEGORIES_} 33 | thing_dataset_id_to_contiguous_id = { 34 | x: i for i, x in enumerate(sorted(id_to_name))} 35 | thing_classes = [id_to_name[k] for k in sorted(id_to_name)] 36 | 37 | part_classes = [a.split(":")[1].lower() for a in thing_classes] 38 | thing_clases_id_to_part_id={v: sorted(set(part_classes)).index(n) for v, n in enumerate(part_classes)} 39 | whole_classes = [a.split(":")[0].lower() for a in thing_classes] 40 | 41 | no_part_index = sorted(set(part_classes)).index('whole') 42 | thing_classes_id_without_part = [k for k, v in thing_clases_id_to_part_id.items() if no_part_index==v] 43 | 44 | thing_clases_id_to_whole_id={v: sorted(set(whole_classes)).index(n) for v, n in enumerate(whole_classes)} 45 | thing_clases_id_to_flattened_wholepart = {tid: thing_clases_id_to_whole_id[tid]*len(set(part_classes))+pid for tid, pid in thing_clases_id_to_part_id.items()} 46 | return { 47 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 48 | "thing_classes": thing_classes_, 49 | "thing_clases_id_to_part_id": thing_clases_id_to_part_id, 50 | "part_classes": sorted(set(part_classes)), 51 | "thing_clases_id_to_whole_id": thing_clases_id_to_whole_id, 52 | "whole_classes": sorted(set(whole_classes)), 53 | "thing_clases_id_to_flattened_wholepart": thing_clases_id_to_flattened_wholepart, 54 | "thing_classes_id_without_part": thing_classes_id_without_part, 55 | } 56 | 57 | 58 | def 
register_paco_part_instances(name, metadata, json_file, image_root): 59 | DatasetCatalog.register(name, lambda: load_coco_json( 60 | json_file, image_root, name)) 61 | MetadataCatalog.get(name).set( 62 | json_file=json_file, image_root=image_root, 63 | evaluator_type="pascal_part_interactive", **metadata 64 | ) 65 | 66 | _PACO = { 67 | "paco_train": ("coco", "paco/annotations/paco_lvis_v1_train.json"), 68 | # "pascal_part_train_one": ("pascal_part/VOCdevkit/VOC2010/JPEGImages", "pascal_part/train_one.json"), 69 | "paco_val_inter": ("coco", "paco/annotations/paco_lvis_v1_val_mini.json"), 70 | # "paco_test": ("paco/val2017", "paco/annotations/paco_lvis_v1_val.json"), 71 | # "pascal_part_base_train": ("pascal_part/VOCdevkit/VOC2010/JPEGImages", "pascal_part/train_base.json"), 72 | # "pascal_part_base_train_one": ("pascal_part/VOCdevkit/VOC2010/JPEGImages", "pascal_part/train_base_one.json"), 73 | # "imagenet_voc_parsed": ("imagenet/train", "imagenet/imagenet_voc_image_parsed.json"), 74 | # "imagenet_golden_pascal_parsed": ("imagenet/train", "imagenet/imagenet_golden_pascal_parsed.json"), 75 | # "imagenet_golden_pascal_parsed_swinbase": ("imagenet/train", "imagenet/imagenet_golden_pascal_parsed_swinbase.json"), 76 | } 77 | 78 | 79 | def register_paco_part(root): 80 | for key, (image_root, json_file) in _PACO.items(): 81 | register_paco_part_instances( 82 | key, 83 | _get_paco_metadata(key), 84 | os.path.join(root, json_file) if "://" not in json_file else json_file, 85 | os.path.join(root, image_root), 86 | ) 87 | 88 | if os.path.exists(json_name): 89 | register_paco_part(_root) 90 | -------------------------------------------------------------------------------- /datasets/registration/register_partimagenet_part_all.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
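# Registers the PartImageNet part-segmentation splits. Each '<Whole> <Part>' category name is split into a whole-object id and a part id, and a flattened index (whole_id * num_parts + part_id) is exposed for the interactive part evaluator.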
2 | import logging 3 | import os 4 | from detectron2.data import DatasetCatalog, MetadataCatalog 5 | # from detectron2.data.datasets.register_coco import register_coco_instances 6 | from detectron2.data.datasets.coco import load_coco_json 7 | 8 | PART_IN_CATEGORIES = [{'id': 0, 'name': 'Quadruped Head', 'supercategory': 'Quadruped'}, 9 | {'id': 1, 'name': 'Quadruped Body', 'supercategory': 'Quadruped'}, 10 | {'id': 2, 'name': 'Quadruped Foot', 'supercategory': 'Quadruped'}, 11 | {'id': 3, 'name': 'Quadruped Tail', 'supercategory': 'Quadruped'}, 12 | {'id': 4, 'name': 'Biped Head', 'supercategory': 'Biped'}, 13 | {'id': 5, 'name': 'Biped Body', 'supercategory': 'Biped'}, 14 | {'id': 6, 'name': 'Biped Hand', 'supercategory': 'Biped'}, 15 | {'id': 7, 'name': 'Biped Foot', 'supercategory': 'Biped'}, 16 | {'id': 8, 'name': 'Biped Tail', 'supercategory': 'Biped'}, 17 | {'id': 9, 'name': 'Fish Head', 'supercategory': 'Fish'}, 18 | {'id': 10, 'name': 'Fish Body', 'supercategory': 'Fish'}, 19 | {'id': 11, 'name': 'Fish Fin', 'supercategory': 'Fish'}, 20 | {'id': 12, 'name': 'Fish Tail', 'supercategory': 'Fish'}, 21 | {'id': 13, 'name': 'Bird Head', 'supercategory': 'Bird'}, 22 | {'id': 14, 'name': 'Bird Body', 'supercategory': 'Bird'}, 23 | {'id': 15, 'name': 'Bird Wing', 'supercategory': 'Bird'}, 24 | {'id': 16, 'name': 'Bird Foot', 'supercategory': 'Bird'}, 25 | {'id': 17, 'name': 'Bird Tail', 'supercategory': 'Bird'}, 26 | {'id': 18, 'name': 'Snake Head', 'supercategory': 'Snake'}, 27 | {'id': 19, 'name': 'Snake Body', 'supercategory': 'Snake'}, 28 | {'id': 20, 'name': 'Reptile Head', 'supercategory': 'Reptile'}, 29 | {'id': 21, 'name': 'Reptile Body', 'supercategory': 'Reptile'}, 30 | {'id': 22, 'name': 'Reptile Foot', 'supercategory': 'Reptile'}, 31 | {'id': 23, 'name': 'Reptile Tail', 'supercategory': 'Reptile'}, 32 | {'id': 24, 'name': 'Car Body', 'supercategory': 'Car'}, 33 | {'id': 25, 'name': 'Car Tier', 'supercategory': 'Car'}, 34 | {'id': 26, 'name': 'Car Side Mirror', 'supercategory': 'Car'}, 35 | {'id': 27, 'name': 'Bicycle Body', 'supercategory': 'Bicycle'}, 36 | {'id': 28, 'name': 'Bicycle Head', 'supercategory': 'Bicycle'}, 37 | {'id': 29, 'name': 'Bicycle Seat', 'supercategory': 'Bicycle'}, 38 | {'id': 30, 'name': 'Bicycle Tier', 'supercategory': 'Bicycle'}, 39 | {'id': 31, 'name': 'Boat Body', 'supercategory': 'Boat'}, 40 | {'id': 32, 'name': 'Boat Sail', 'supercategory': 'Boat'}, 41 | {'id': 33, 'name': 'Aeroplane Head', 'supercategory': 'Aeroplane'}, 42 | {'id': 34, 'name': 'Aeroplane Body', 'supercategory': 'Aeroplane'}, 43 | {'id': 35, 'name': 'Aeroplane Engine', 'supercategory': 'Aeroplane'}, 44 | {'id': 36, 'name': 'Aeroplane Wing', 'supercategory': 'Aeroplane'}, 45 | {'id': 37, 'name': 'Aeroplane Tail', 'supercategory': 'Aeroplane'}, 46 | {'id': 38, 'name': 'Bottle Mouth', 'supercategory': 'Bottle'}, 47 | {'id': 39, 'name': 'Bottle Body', 'supercategory': 'Bottle'}] 48 | 49 | 50 | def _get_partimagenet_metadata(key): 51 | # if '_base' in key: 52 | # id_to_name = {x['id']: x['name'] for x in PASCAL_PART_BASE_CATEGORIES} 53 | # else: 54 | id_to_name = {x['id']: x['name'] for x in PART_IN_CATEGORIES} 55 | thing_dataset_id_to_contiguous_id = { 56 | x: i for i, x in enumerate(sorted(id_to_name))} 57 | thing_classes = [id_to_name[k] for k in sorted(id_to_name)] 58 | 59 | part_classes = [a.split(" ")[1].lower() for a in thing_classes] 60 | thing_clases_id_to_part_id = {v: sorted(set(part_classes)).index(n) for v, n in enumerate(part_classes)} 61 | whole_classes = [a.split(" 
")[0].lower() for a in thing_classes] 62 | thing_clases_id_to_whole_id = {v: sorted(set(whole_classes)).index(n) for v, n in enumerate(whole_classes)} 63 | thing_clases_id_to_flattened_wholepart = {tid: thing_clases_id_to_whole_id[tid] * len(set(part_classes)) + pid for 64 | tid, pid in thing_clases_id_to_part_id.items()} 65 | return { 66 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 67 | "thing_classes": thing_classes, 68 | "thing_clases_id_to_part_id": thing_clases_id_to_part_id, 69 | "part_classes": sorted(set(part_classes)), 70 | "thing_clases_id_to_whole_id": thing_clases_id_to_whole_id, 71 | "whole_classes": sorted(set(whole_classes)), 72 | "thing_clases_id_to_flattened_wholepart": thing_clases_id_to_flattened_wholepart, 73 | } 74 | 75 | 76 | def register_partimagenet_part_instances(name, metadata, json_file, image_root): 77 | DatasetCatalog.register(name, lambda: load_coco_json( 78 | json_file, image_root, name)) 79 | MetadataCatalog.get(name).set( 80 | json_file=json_file, image_root=image_root, 81 | evaluator_type="pascal_part_interactive", **metadata 82 | ) 83 | 84 | 85 | _PART_IN = { 86 | "partimagenet_train": ("imagenet/train", "partimagenet/train_format.json"), 87 | # "pascal_part_train_one": ("pascal_part/VOCdevkit/VOC2010/JPEGImages", "pascal_part/train_one.json"), 88 | "partimagenet_val_inter": ("imagenet/val", "partimagenet/val_format_mini.json"), 89 | } 90 | 91 | 92 | def register_partimagenet_part(root): 93 | for key, (image_root, json_file) in _PART_IN.items(): 94 | register_partimagenet_part_instances( 95 | key, 96 | _get_partimagenet_metadata(key), 97 | os.path.join(root, json_file) if "://" not in json_file else json_file, 98 | os.path.join(root, image_root), 99 | ) 100 | 101 | 102 | _root = os.getenv("PART_IN", "datasets") 103 | register_partimagenet_part(_root) 104 | -------------------------------------------------------------------------------- /datasets/registration/register_refcoco_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | import json 9 | import os 10 | import collections 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.data.datasets import load_sem_seg 14 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 15 | from detectron2.utils.file_io import PathManager 16 | 17 | 18 | _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = { 19 | # "refcocog_train_umd": ( 20 | # "coco/train2017", # image_root 21 | # "coco/annotations/refcocog_umd_train.json", # annot_root 22 | # ), 23 | # "refcocog_val_google": ( 24 | # "coco/train2017", # image_root 25 | # "coco/annotations/refcocog_google.json", # annot_root 26 | # ), 27 | # "refcocop_val_unc": ( 28 | # "coco/train2017", # image_root 29 | # "coco/annotations/refcocop_unc.json", # annot_root 30 | # ), 31 | # "refcoco_val_unc": ( 32 | # "coco/train2017", # image_root 33 | # "coco/annotations/refcoco_unc.json", # annot_root 34 | # ), 35 | "refcocog_val_umd": ( 36 | "coco/train2017", # image_root 37 | "coco/annotations/refcocog_umd_val.json", # annot_root 38 | ), 39 | } 40 | 41 | 42 | def get_metadata(): 43 | meta = {} 44 | return meta 45 | 46 | 47 | def load_refcoco_json(image_root, annot_json, metadata): 48 | """ 49 | Args: 50 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 51 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". 52 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". 53 | Returns: 54 | list[dict]: a list of dicts in Detectron2 standard format. (See 55 | `Using Custom Datasets `_ ) 56 | """ 57 | 58 | with PathManager.open(annot_json) as f: 59 | json_info = json.load(f) 60 | 61 | # build dictionary for grounding 62 | grd_dict = collections.defaultdict(list) 63 | for grd_ann in json_info['annotations']: 64 | image_id = int(grd_ann["image_id"]) 65 | grd_dict[image_id].append(grd_ann) 66 | 67 | ret = [] 68 | for image in json_info["images"]: 69 | image_id = int(image["id"]) 70 | image_file = os.path.join(image_root, image['file_name']) 71 | grounding_anno = grd_dict[image_id] 72 | ret.append( 73 | { 74 | "file_name": image_file, 75 | "image_id": image_id, 76 | "grounding_info": grounding_anno, 77 | } 78 | ) 79 | assert len(ret), f"No images found in {image_root}!" 
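    # illustrative note (not in the original source): each record appended above has the shape
    #   {"file_name": <image path>, "image_id": <int>, "grounding_info": [refcoco annotation dicts]},
    # and the isfile() assert below spot-checks the first resolved path so that a wrong image_root
    # fails here rather than later inside the dataloader.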
80 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] 81 | return ret 82 | 83 | 84 | def register_refcoco( 85 | name, metadata, image_root, annot_json): 86 | DatasetCatalog.register( 87 | name, 88 | lambda: load_refcoco_json(image_root, annot_json, metadata), 89 | ) 90 | MetadataCatalog.get(name).set( 91 | image_root=image_root, 92 | json_file=annot_json, 93 | evaluator_type="grounding_refcoco", 94 | ignore_label=255, 95 | label_divisor=1000, 96 | **metadata, 97 | ) 98 | 99 | 100 | def register_all_refcoco(root): 101 | for ( 102 | prefix, 103 | (image_root, annot_root), 104 | ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items(): 105 | register_refcoco( 106 | prefix, 107 | get_metadata(), 108 | os.path.join(root, image_root), 109 | os.path.join(root, annot_root), 110 | ) 111 | 112 | 113 | _root = os.getenv("DATASET", "datasets") 114 | register_all_refcoco(_root) 115 | -------------------------------------------------------------------------------- /datasets/registration/register_sam_json.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft, Inc. and its affiliates. 2 | # Modified by Xueyan Zou and Jianwei Yang. 3 | import json 4 | import os 5 | import collections 6 | import glob 7 | import torch 8 | 9 | from detectron2.data import DatasetCatalog, MetadataCatalog 10 | from detectron2.data.datasets import load_sem_seg 11 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 12 | from detectron2.utils.file_io import PathManager 13 | 14 | 15 | _PREDEFINED_SPLITS_SAM_RAW = { 16 | "sam_train": ( 17 | "meta_sa", 18 | (901,910) 19 | ), 20 | "sam_minitrain": ( 21 | "meta_sa", 22 | (0,12) 23 | ), 24 | "sam_val": ( 25 | "meta_sa", 26 | (901,902) 27 | ), 28 | "sam_minival": ( 29 | "meta_sa", 30 | (998,999) 31 | ), 32 | } 33 | 34 | 35 | def load_sam_instances(name: str, dirname: str, id_range: tuple): 36 | """ 37 | Load SAM detection annotations to Detectron2 format. 38 | 39 | Args: 40 | name: name of split 41 | dirname: dataset directory path 42 | id_range: (start, end) tuple for dataset subfolders 43 | """ 44 | dicts = [] 45 | for id in range(*id_range): 46 | subfolder = os.path.join(dirname, 'sa_%06d' % id, 'image_list.da') 47 | dicts += torch.load(subfolder) 48 | return dicts 49 | 50 | def register_sam(name, dirname, id_range): 51 | DatasetCatalog.register("{}".format(name), lambda: load_sam_instances(name, dirname, id_range)) 52 | MetadataCatalog.get("{}".format(name)).set( 53 | dirname=dirname, 54 | thing_dataset_id_to_contiguous_id={}, 55 | ) 56 | 57 | def register_all_sam(root): 58 | for ( 59 | prefix, 60 | (image_root, id_range), 61 | ) in _PREDEFINED_SPLITS_SAM_RAW.items(): 62 | register_sam( 63 | prefix, 64 | os.path.join(root, image_root), 65 | id_range 66 | ) 67 | 68 | _root = os.getenv("SAM_JSON", "datasets") 69 | register_all_sam(_root) -------------------------------------------------------------------------------- /datasets/registration/register_sam_json_val.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft, Inc. and its affiliates. 2 | # Modified by Xueyan Zou and Jianwei Yang. 
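# illustrative note (not in the original source): this module is the evaluation-only counterpart of
# register_sam_json.py above. Instead of iterating over numbered sa_%06d shards it loads a single
# precomputed image_list.da from the directory named by the SAM_JSON environment variable (note the
# hard-coded developer fallback "/home/lifeng/" at the bottom of the file, so SAM_JSON should normally
# be set explicitly) and registers the split with evaluator_type="sam_interactive".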
3 | import json 4 | import os 5 | import collections 6 | import glob 7 | import torch 8 | 9 | from detectron2.data import DatasetCatalog, MetadataCatalog 10 | from detectron2.data.datasets import load_sem_seg 11 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 12 | from detectron2.utils.file_io import PathManager 13 | 14 | 15 | _PREDEFINED_SPLITS_SAM_RAW = { 16 | "sam_minival": ( 17 | "sam_val2", 18 | (-1,-1) 19 | ), 20 | } 21 | 22 | 23 | def load_sam_instances(name: str, dirname: str, id_range: tuple): 24 | """ 25 | Load SAM detection annotations to Detectron2 format. 26 | 27 | Args: 28 | name: name of split 29 | dirname: dataset directory path 30 | id_range: (start, end) tuple for dataset subfolders 31 | """ 32 | dicts = [] 33 | # for id in range(*id_range): 34 | subfolder = os.path.join(dirname, 'image_list.da') 35 | dicts += torch.load(subfolder) 36 | return dicts 37 | 38 | def register_sam(name, dirname, id_range): 39 | DatasetCatalog.register("{}".format(name), lambda: load_sam_instances(name, dirname, id_range)) 40 | MetadataCatalog.get("{}".format(name)).set( 41 | dirname=dirname, 42 | evaluator_type = "sam_interactive", 43 | thing_dataset_id_to_contiguous_id={}, 44 | ) 45 | 46 | def register_all_sam(root): 47 | for ( 48 | prefix, 49 | (image_root, id_range), 50 | ) in _PREDEFINED_SPLITS_SAM_RAW.items(): 51 | register_sam( 52 | prefix, 53 | os.path.join(root, image_root), 54 | id_range 55 | ) 56 | 57 | _root = os.getenv("SAM_JSON", "/home/lifeng/") 58 | register_all_sam(_root) -------------------------------------------------------------------------------- /datasets/registration/register_scannet_semseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | import numpy as np 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.structures import BoxMode 15 | from detectron2.utils.file_io import PathManager 16 | 17 | from utils.constants import SCAN_37, SCAN_40, SCAN_20 18 | 19 | __all__ = ["load_scannet_instances", "register_scannet_context"] 20 | 21 | name2folder = {"scannet_41_val_seg": "label41", 22 | "scannet_38_val_seg": "label38", 23 | "scannet_21_val_seg": "label21",} 24 | 25 | name2class = {"scannet_41_val_seg": SCAN_40, 26 | "scannet_38_val_seg": SCAN_37, 27 | "scannet_21_val_seg": SCAN_20} 28 | 29 | 30 | def load_scannet_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 31 | """ 32 | Load ScanNet annotations to Detectron2 format. 
33 | 34 | Args: 35 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 36 | split (str): one of "train", "test", "val", "trainval" 37 | class_names: list or tuple of class names 38 | """ 39 | with PathManager.open(os.path.join(dirname, "meta", split + ".txt")) as f: 40 | fileids = np.loadtxt(f, dtype=np.str) 41 | 42 | dicts = [] 43 | for field in fileids: 44 | image_dir = os.path.join(dirname, 'images', field[0]) 45 | semseg_dir = image_dir.replace('color', name2folder[name]).replace('jpg', 'png') 46 | r = { 47 | "file_name": image_dir, 48 | "sem_seg_file_name": semseg_dir, 49 | "image_id": semseg_dir.split('/')[-3] + semseg_dir.split('/')[-1].split('.')[0], 50 | } 51 | dicts.append(r) 52 | return dicts 53 | 54 | 55 | def register_scannet_context(name, dirname, split, class_names=name2class): 56 | DatasetCatalog.register(name, lambda: load_scannet_instances(name, dirname, split, class_names)) 57 | MetadataCatalog.get(name).set( 58 | stuff_classes=class_names[name], 59 | dirname=dirname, 60 | split=split, 61 | ignore_label=[0], 62 | thing_dataset_id_to_contiguous_id={}, 63 | class_offset=1, 64 | keep_sem_bgd=False 65 | ) 66 | 67 | 68 | def register_all_sunrgbd_seg(root): 69 | SPLITS = [ 70 | ("scannet_41_val_seg", "scannet_frames_25k", "val"), 71 | ("scannet_38_val_seg", "scannet_frames_25k", "val"), 72 | ("scannet_21_val_seg", "scannet_frames_25k", "val"), 73 | ] 74 | 75 | for name, dirname, split in SPLITS: 76 | register_scannet_context(name, os.path.join(root, dirname), split) 77 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 78 | 79 | 80 | _root = os.getenv("DATASET", "datasets") 81 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_sunrgbd_semseg.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # -------------------------------------------------------- 4 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 5 | # Copyright (c) 2022 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 8 | # -------------------------------------------------------- 9 | import numpy as np 10 | import os 11 | import glob 12 | from typing import List, Tuple, Union 13 | 14 | from detectron2.data import DatasetCatalog, MetadataCatalog 15 | from detectron2.structures import BoxMode 16 | from detectron2.utils.file_io import PathManager 17 | 18 | from utils.constants import SUN_RGBD_37 19 | 20 | __all__ = ["load_sunrgbd_instances", "register_sunrgbd_context"] 21 | 22 | def load_sunrgbd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 23 | """ 24 | Load SUN-RGBD detection annotations to Detectron2 format. 25 | 26 | Args: 27 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 28 | split (str): one of "train", "test", "val", "trainval" 29 | class_names: list or tuple of class names 30 | """ 31 | if split == 'val': 32 | split = 'test' 33 | 34 | # Needs to read many small annotation files. 
Makes sense at local 35 | image_pths = sorted(glob.glob(os.path.join(dirname, 'image', split, '*.jpg'))) 36 | semseg_pths = sorted(glob.glob(os.path.join(dirname, 'label37', split, '*.png'))) 37 | 38 | assert len(image_pths) == len(semseg_pths) 39 | 40 | dicts = [] 41 | for image_dir, semseg_dir in zip(image_pths, semseg_pths): 42 | r = { 43 | "file_name": image_dir, 44 | "sem_seg_file_name": semseg_dir, 45 | "image_id": semseg_dir.split('/')[-1].split('.')[0], 46 | } 47 | dicts.append(r) 48 | return dicts 49 | 50 | 51 | def register_sun_context(name, dirname, split, class_names=SUN_RGBD_37): 52 | DatasetCatalog.register(name, lambda: load_sunrgbd_instances(name, dirname, split, class_names)) 53 | MetadataCatalog.get(name).set( 54 | stuff_classes=class_names, 55 | dirname=dirname, 56 | split=split, 57 | ignore_label=[0], 58 | thing_dataset_id_to_contiguous_id={}, 59 | class_offset=1, 60 | keep_sem_bgd=False 61 | ) 62 | 63 | 64 | def register_all_sunrgbd_seg(root): 65 | SPLITS = [ 66 | ("sunrgbd_37_val_seg", "sun_rgbd", "val"), 67 | ] 68 | 69 | for name, dirname, split in SPLITS: 70 | register_sun_context(name, os.path.join(root, dirname), split) 71 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 72 | 73 | 74 | _root = os.getenv("DATASET", "datasets") 75 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_vlp_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | import os 9 | import logging 10 | 11 | from detectron2.data import DatasetCatalog, MetadataCatalog 12 | import pyarrow as pa 13 | 14 | _PREDEFINED_SPLITS_PRETRAIN = { 15 | # filt coco2017 val 16 | # "vlp_train": ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] + ["code224_vg.arrow"] + [f"code224_sbu_{i}.arrow" for i in range(9)] + [f"code224_conceptual_caption_train_{i}.arrow" for i in range(31)], 17 | # "vlp_val": ["coco_caption_karpathy_test.arrow"], 18 | # "vlp_captioning_val": ["coco_caption_karpathy_test.arrow"], 19 | # "vlp_val2017": ["coco_caption_karpathy_val2017.arrow"], 20 | # "vlp_captioning_val2017": ["coco_caption_karpathy_val2017.arrow"], 21 | # filt coco2017 and refcocog umd val 22 | # "vlp_train": ["filtrefval2017_coco_caption_karpathy_restval.arrow", "filtrefval2017_coco_caption_karpathy_train.arrow", "filtrefval2017_coco_caption_karpathy_val.arrow"] + ["code224_vg.arrow"] + [f"code224_sbu_{i}.arrow" for i in range(9)] + [f"code224_conceptual_caption_train_{i}.arrow" for i in range(31)], 23 | # "vlp_val": ["coco_caption_karpathy_test.arrow"], 24 | # "vlp_captioning_val": ["coco_caption_karpathy_test.arrow"], 25 | # "vlp_val2017": ["coco_caption_karpathy_val2017.arrow"], 26 | # "vlp_captioning_val2017": ["coco_caption_karpathy_val2017.arrow"], 27 | # the following is for local testing 28 | "vlp_train": ["coco_caption_karpathy_test.arrow"], 29 | "vlp_val": ["coco_caption_karpathy_test.arrow"], 30 | "vlp_captioning_val": ["coco_caption_karpathy_test.arrow"], 31 | } 32 | 33 | def 
get_metadata(name): 34 | if name in ['vlp_captioning_val', 'vlp_captioning_val2017']: 35 | return {'gt_json': os.path.join(_coco_root, 'coco/annotations/captions_val2014.json')} 36 | else: 37 | return {} 38 | 39 | evaluator_mapper = {'vlp_val': 'retrieval', 'vlp_train': 'retrieval', 'vlp_captioning_val': 'captioning', 'vlp_val2017': 'retrieval', 'vlp_captioning_val2017': 'captioning'} 40 | def load_pretrain_arrows(root, arrow_paths): 41 | """ 42 | Args: 43 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 44 | Returns: 45 | list[dict]: a list of dicts in Detectron2 standard format. (See 46 | `Using Custom Datasets `_ ) 47 | """ 48 | arrs = [] 49 | for arrow_path in arrow_paths: 50 | arr = pa.ipc.RecordBatchFileReader( 51 | pa.memory_map(os.path.join(root, arrow_path), "r") 52 | ).read_all() 53 | 54 | arrs.append(arr) 55 | return arrs 56 | 57 | def load_pretrain_data(arrow_root, meta, name, pretrain_arrows): 58 | ret = [] 59 | 60 | image_id = 0 61 | arr_id = 0 62 | for arr in pretrain_arrows: 63 | arr_len = len(arr) 64 | cur_id = 0 65 | for i in range(arr_len): 66 | captions = arr['caption'][i].as_py() 67 | image_id = arr['image_id'][i].as_py() 68 | if not isinstance(image_id, int): 69 | image_id = int(image_id.split('_')[-1].split('.')[0]) 70 | if 'val' in name: 71 | ret.append( { 72 | "image_id": image_id, 73 | "captions": captions, 74 | "arr_id": arr_id, 75 | "cur_id": cur_id, 76 | }) 77 | else: 78 | for caption in captions: 79 | ret.append( { 80 | "image_id": image_id, 81 | "captions": [caption], 82 | "arr_id": arr_id, 83 | "cur_id": cur_id, 84 | }) 85 | cur_id += 1 86 | image_id += 1 87 | 88 | arr_id += 1 89 | 90 | assert len(ret), f"No images found in pretraining" 91 | return ret 92 | 93 | 94 | def register_pretrain( 95 | name, metadata, arrow_root, arrow_paths 96 | ): 97 | # the name is "coco_2017_train/val_caption_only" 98 | semantic_name = name 99 | arrow_root = os.path.join(arrow_root, 'pretrain_arrows_code224') 100 | if os.path.exists(arrow_root): 101 | pretrain_arrows = load_pretrain_arrows(arrow_root, arrow_paths) 102 | DatasetCatalog.register( 103 | semantic_name, 104 | lambda: load_pretrain_data(arrow_root, metadata, name, pretrain_arrows), 105 | ) 106 | MetadataCatalog.get(semantic_name).set( 107 | arrow_root=arrow_root, 108 | evaluator_type=evaluator_mapper[name], 109 | arrows=pretrain_arrows, 110 | **metadata, 111 | ) 112 | else: 113 | logger = logging.getLogger(__name__) 114 | logger.warning("WARNING: Cannot find VLPreDataset. Make sure datasets are accessible if you want to use them for training or evaluation.") 115 | 116 | def register_all_pretrain(root): 117 | for ( 118 | prefix, 119 | arrow_paths, 120 | ) in _PREDEFINED_SPLITS_PRETRAIN.items(): 121 | register_pretrain( 122 | prefix, 123 | get_metadata(prefix), 124 | root, 125 | arrow_paths, 126 | ) 127 | 128 | 129 | # _root = os.getenv("VLDATASET", "datasets") #may need a different root name? 130 | _root = os.getenv("DATASET2", "datasets") #may need a different root name? 131 | _coco_root = os.getenv("DATASET", "datasets") #may need a different root name? 
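# illustrative sketch (assumed workflow, not in the original source): register_pretrain() above only
# registers a split when <DATASET2>/pretrain_arrows_code224/ exists; otherwise it logs a warning and
# skips it. A typical local setup therefore exports both roots before launching, e.g.
#   export DATASET2=/path/to/vlp_data    # must contain pretrain_arrows_code224/*.arrow
#   export DATASET=/path/to/coco_root    # used by get_metadata() for coco/annotations/captions_val2014.json
# where the concrete paths are placeholders.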
132 | register_all_pretrain(_root) -------------------------------------------------------------------------------- /datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .semseg_loader import * -------------------------------------------------------------------------------- /datasets/utils/semseg_loader.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import scipy.io 3 | import numpy as np 4 | 5 | def load_semseg(filename, loader_type): 6 | if loader_type == 'PIL': 7 | semseg = np.array(Image.open(filename), dtype=np.int) 8 | elif loader_type == 'MAT': 9 | semseg = scipy.io.loadmat(filename)['LabelMap'] 10 | return semseg -------------------------------------------------------------------------------- /datasets/utils/tsv/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-16 16:56:22 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2021-08-16 17:00:28 6 | 7 | from .io_common import FileProgressingbar, img_from_base64, generate_lineidx 8 | from .tsv_io import TSVFile 9 | 10 | __all__ = [ 11 | 'FileProgressingbar', 'img_from_base64', 'generate_lineidx', 'TSVFile' 12 | ] -------------------------------------------------------------------------------- /datasets/utils/tsv/io_common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-13 14:35:27 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2022-04-24 11:38:58 6 | 7 | import os 8 | import base64 9 | from io import BytesIO 10 | from PIL import Image 11 | 12 | import cv2 13 | import yaml 14 | import progressbar 15 | import numpy as np 16 | import torchvision.transforms as T 17 | 18 | class FileProgressingbar: 19 | fileobj = None 20 | pbar = None 21 | def __init__(self, fileobj, msg): 22 | fileobj.seek(0, os.SEEK_END) 23 | flen = fileobj.tell() 24 | fileobj.seek(0, os.SEEK_SET) 25 | self.fileobj = fileobj 26 | widgets = [msg, progressbar.AnimatedMarker(), ' ', progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()] 27 | self.pbar = progressbar.ProgressBar(widgets=widgets, maxval=flen).start() 28 | 29 | def update(self): 30 | self.pbar.update(self.fileobj.tell()) 31 | 32 | 33 | def img_from_base64(imagestring): 34 | jpgbytestring = base64.b64decode(imagestring) 35 | image = BytesIO(jpgbytestring) 36 | image = Image.open(image).convert("RGB") 37 | return image 38 | 39 | # jpgbytestring = base64.b64decode(imagestring) 40 | # nparr = np.frombuffer(jpgbytestring, np.uint8) 41 | # try: 42 | # r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 43 | # # r = cv2.cvtColor(r, cv2.COLOR_BGR2RGB) 44 | # return r 45 | # except: 46 | # return None 47 | 48 | 49 | def generate_lineidx(filein, idxout): 50 | assert not os.path.isfile(idxout) 51 | with open(filein, 'r') as tsvin, open(idxout, 'w') as tsvout: 52 | bar = FileProgressingbar(tsvin, 'Generating lineidx {0}: '.format(idxout)) 53 | fsize = os.fstat(tsvin.fileno()).st_size 54 | fpos = 0 55 | while fpos != fsize: 56 | tsvout.write(str(fpos)+"\n") 57 | tsvin.readline() 58 | fpos = tsvin.tell() 59 | bar.update() 60 | -------------------------------------------------------------------------------- /datasets/utils/tsv/tsv_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: 
utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-13 14:26:21 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2022-08-17 00:57:51 6 | import time 7 | import os 8 | import os.path as op 9 | from .io_common import generate_lineidx, FileProgressingbar 10 | 11 | 12 | class TSVFile(object): 13 | def __init__(self, tsv_file, silence=True): 14 | self.tsv_file = tsv_file 15 | self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' 16 | 17 | self.label_file = op.splitext(tsv_file)[0] + '.label' 18 | self.label_lineidx = op.splitext(tsv_file)[0] + '.label.lineidx' 19 | 20 | if os.path.exists(self.label_file): 21 | self.split_label = True 22 | else: 23 | self.split_label = False 24 | 25 | self._fp = None 26 | self._lineidx = None 27 | 28 | self._label_fp = None 29 | self._label_lineidx = None 30 | 31 | self.pid = None 32 | self.silence = silence 33 | self._ensure_lineidx_loaded() 34 | 35 | def num_rows(self): 36 | return len(self._lineidx) 37 | 38 | def seek(self, idx): 39 | self._ensure_tsv_opened() 40 | pos = self._lineidx[idx] 41 | self._fp.seek(pos) 42 | tsv_info = [s.strip() for s in self._fp.readline().split('\t')] 43 | 44 | if self.split_label: 45 | label_pos = self._label_lineidx[idx] 46 | self._label_fp.seek(label_pos) 47 | label_info = [s.strip() for s in self._label_fp.readline().split('\t')] 48 | 49 | assert tsv_info[0] == label_info[0] 50 | tsv_info = [tsv_info[0], label_info[-1], tsv_info[-1]] 51 | 52 | return tsv_info 53 | 54 | def close(self): 55 | if self._fp is not None: 56 | self._fp.close() 57 | del self._fp 58 | del self._lineidx 59 | 60 | self._fp = None 61 | self._lineidx = None 62 | 63 | def _ensure_lineidx_loaded(self): 64 | if not op.isfile(self.lineidx) and not op.islink(self.lineidx): 65 | generate_lineidx(self.tsv_file, self.lineidx) 66 | 67 | if self._lineidx is None: 68 | with open(self.lineidx, 'r') as fp: 69 | lines = fp.readlines() 70 | self._lineidx = [int(i.strip().split()[0]) for i in lines] 71 | 72 | if self.split_label: 73 | with open(self.label_lineidx, 'r') as fp: 74 | lines = fp.readlines() 75 | self._label_lineidx = [int(i.strip().split()[0]) for i in lines] 76 | 77 | 78 | def _ensure_tsv_opened(self): 79 | self._ensure_lineidx_loaded() 80 | if self._fp is None: 81 | self._fp = open(self.tsv_file, 'r') 82 | self.pid = os.getpid() 83 | 84 | if self.split_label: 85 | self._label_fp = open(self.label_file, 'r') 86 | 87 | if self.pid != os.getpid(): 88 | print('re-open {} because the process id changed'.format(self.tsv_file)) 89 | self._fp = open(self.tsv_file, 'r') 90 | self.pid = os.getpid() 91 | 92 | if self.split_label: 93 | self._label_fp = open(self.label_file, 'r') 94 | -------------------------------------------------------------------------------- /demo_auto_generation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Hao Zhang (hzhangcx@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | 9 | import gradio as gr 10 | import torch 11 | import argparse 12 | 13 | # from gradio import processing_utils 14 | from semantic_sam.BaseModel import BaseModel 15 | from semantic_sam import build_model 16 | from utils.dist import init_distributed_mode 17 | from utils.arguments import load_opt_from_config_file 18 | from utils.constants 
import COCO_PANOPTIC_CLASSES 19 | 20 | from tasks import interactive_infer_image_idino_m2m_auto, prompt_switch 21 | 22 | def parse_option(): 23 | parser = argparse.ArgumentParser('SemanticSAM Demo', add_help=False) 24 | parser.add_argument('--conf_files', default="configs/semantic_sam_only_sa-1b_swinL.yaml", metavar="FILE", help='path to config file', ) 25 | parser.add_argument('--ckpt', default="", metavar="FILE", help='path to ckpt', ) 26 | args = parser.parse_args() 27 | 28 | return args 29 | 30 | ''' 31 | build args 32 | ''' 33 | args = parse_option() 34 | 35 | 36 | cur_model = 'None' 37 | 38 | ''' 39 | build model 40 | ''' 41 | 42 | model=None 43 | model_size=None 44 | ckpt=None 45 | cfgs={'T':"configs/semantic_sam_only_sa-1b_swinT.yaml", 46 | 'L':"configs/semantic_sam_only_sa-1b_swinL.yaml"} 47 | 48 | sam_cfg=cfgs['L'] 49 | opt = load_opt_from_config_file(sam_cfg) 50 | model_sam = BaseModel(opt, build_model(opt)).from_pretrained(args.ckpt).eval().cuda() 51 | 52 | @torch.no_grad() 53 | def inference(image,level=[0],*args, **kwargs): 54 | if level == 'All Prompt': 55 | level = [1, 2, 3, 4, 5, 6] 56 | else: 57 | level = [level.split(' ')[-1]] 58 | print(level) 59 | text_size, hole_scale, island_scale=640,100,100 60 | text, text_part, text_thresh='','','0.0' 61 | with torch.autocast(device_type='cuda', dtype=torch.float16): 62 | semantic=False 63 | model=model_sam 64 | a= interactive_infer_image_idino_m2m_auto(model, image,level,text,text_part,text_thresh,text_size,hole_scale,island_scale,semantic, *args, **kwargs) 65 | return a 66 | 67 | 68 | 69 | class ImageMask(gr.components.Image): 70 | 71 | is_template = True 72 | 73 | def __init__(self, **kwargs): 74 | super().__init__(source="upload", **kwargs) 75 | 76 | def preprocess(self, x): 77 | return super().preprocess(x) 78 | 79 | 80 | 81 | 82 | ''' 83 | launch app 84 | ''' 85 | title = "SEMANTIC-SAM: SEGMENT AND RECOGNIZE ANYTHING AT ANY GRANULARITY" 86 | 87 | article = "The Demo is Run on SEMANTIC SAM." 88 | 89 | from detectron2.data import MetadataCatalog 90 | from utils.constants import COCO_PANOPTIC_CLASSES 91 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 92 | all_classes = [name.replace('-other','').replace('-merged','') for name in COCO_PANOPTIC_CLASSES] 93 | all_parts=['arm', 'beak', 'body', 'cap', 'door', 'ear', 'eye', 'foot', 'hair', 'hand', 'handlebar', 'head', 'headlight', 'horn', 'leg', 'license plate', 'mirror', 'mouth', 'muzzle', 'neck', 'nose', 'paw', 'plant', 'pot', 'saddle', 'tail', 'torso', 'wheel', 'window', 'wing'] 94 | 95 | demo = gr.Blocks() 96 | image=ImageMask(label="Click on Image (Please only click one point, or our model will take the center of all points as the clicked location. Remember to clear the click after each interaction, or we will take the center of the current click and previous ones as the clicked location.)",type="pil",brush_radius=15.0).style(height=512) 97 | text_model_level=gr.components.Textbox(label="Output levels",value="0,1,2,3,4,5",visible=True) 98 | text_model_select=gr.components.Radio(['Prompt 1', 'Prompt 2', 'Prompt 3', 'Prompt 4', 'Prompt 5', 'Prompt 6', 'All Prompt',],value='All Prompt', label="Our model learns 6 granularity prompts. [1-6] indicates output granularity from largest to smallest using different prompts. 
'All prompt' means using all 6 granularity prompts.") 99 | image_out=gr.components.Image(label="Auto generation",type="pil",brush_radius=15.0).style(height=512) 100 | 101 | title=''' 102 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 103 | 104 | # [[Read our arXiv Paper](https://arxiv.org/pdf/2307.04767.pdf)\]   \[[Github page](https://github.com/UX-Decoder/Semantic-SAM)\] 105 | 106 | # Auto generation demo. 107 | ''' 108 | def change_vocab(choice): 109 | if choice: 110 | return gr.update(visible=True) 111 | else: 112 | return gr.update(visible=False) 113 | 114 | 115 | with demo: 116 | with gr.Row(): 117 | with gr.Column(scale=9.0): 118 | generation_tittle = gr.Markdown(title) 119 | with gr.Row(scale=20.0): 120 | image.render() 121 | with gr.Column(scale=1.0): 122 | text_model_select.render() 123 | example1 = gr.Examples( 124 | examples=[ 125 | ["examples/levels_dog.png"], 126 | 127 | ], 128 | inputs=image, 129 | label='Example output of using different prompts for one image, output masks are from semantic, instance, to part level', 130 | 131 | cache_examples=False, 132 | ) 133 | example = gr.Examples( 134 | examples=[ 135 | ["examples/tank.png"], 136 | ["examples/castle.png"], 137 | ["examples/dog.jpg"], 138 | ["examples/fries1.png"], 139 | ["examples/4.png"], 140 | ["examples/5.png"], 141 | ["examples/corgi2.jpg"], 142 | ["examples/minecraft2.png"], 143 | ["examples/ref_cat.jpeg"], 144 | ["examples/img.png"], 145 | 146 | ], 147 | inputs=image, 148 | 149 | cache_examples=False, 150 | ) 151 | with gr.Row(scale=2.0): 152 | clearBtn = gr.ClearButton( 153 | components=[image]) 154 | runBtn = gr.Button("Run") 155 | 156 | with gr.Row(scale=9.0): 157 | image_out.render() 158 | 159 | 160 | title = title, 161 | article = article, 162 | allow_flagging = 'never', 163 | runBtn.click(inference, inputs=[image, text_model_select], 164 | outputs = image_out) 165 | 166 | 167 | 168 | demo.queue().launch(share=True,server_port=6081) 169 | 170 | -------------------------------------------------------------------------------- /examples/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/4.png -------------------------------------------------------------------------------- /examples/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/5.png -------------------------------------------------------------------------------- /examples/castle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/castle.png -------------------------------------------------------------------------------- /examples/corgi1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/corgi1.webp -------------------------------------------------------------------------------- /examples/corgi2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/corgi2.jpg 
-------------------------------------------------------------------------------- /examples/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/dog.jpg -------------------------------------------------------------------------------- /examples/fries1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/fries1.png -------------------------------------------------------------------------------- /examples/fries2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/fries2.png -------------------------------------------------------------------------------- /examples/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/img.png -------------------------------------------------------------------------------- /examples/levels_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/levels_dog.png -------------------------------------------------------------------------------- /examples/minecraft1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/minecraft1.jpg -------------------------------------------------------------------------------- /examples/minecraft2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/minecraft2.png -------------------------------------------------------------------------------- /examples/placeholder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/placeholder.png -------------------------------------------------------------------------------- /examples/ref_cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/ref_cat.jpeg -------------------------------------------------------------------------------- /examples/ref_vase.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/ref_vase.JPG -------------------------------------------------------------------------------- /examples/river1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/river1.mp4 -------------------------------------------------------------------------------- /examples/river1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/river1.png -------------------------------------------------------------------------------- /examples/river1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/river1.wav -------------------------------------------------------------------------------- /examples/river1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/river1_mask.png -------------------------------------------------------------------------------- /examples/river2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/river2.png -------------------------------------------------------------------------------- /examples/river2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/river2_mask.png -------------------------------------------------------------------------------- /examples/tank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/tank.png -------------------------------------------------------------------------------- /examples/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/truck.jpg -------------------------------------------------------------------------------- /examples/zebras1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/zebras1.jpg -------------------------------------------------------------------------------- /examples/zebras2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/Semantic-SAM/3d6a43a0f8e77167c0013d14067933a78e2d1f5a/examples/zebras2.jpg -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "Semantic-SAM" 7 | version = "1.0" 8 | description = "Segment and Recognize Anything at Any Granularity." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | 16 | dependencies = [ 17 | "torch", 18 | "torchvision", 19 | "pillow==9.4.0", 20 | "opencv-python==4.8.1.78", 21 | "pyyaml==6.0.1", 22 | "json_tricks==3.17.3", 23 | "yacs==0.1.8", 24 | "scikit-learn==1.3.1", 25 | "pandas==2.0.3", 26 | "timm==0.4.12", 27 | "numpy==1.23.1", 28 | "einops==0.7.0", 29 | "fvcore==0.1.5.post20221221", 30 | "transformers==4.34.0", 31 | "sentencepiece==0.1.99", 32 | "ftfy==6.1.1", 33 | "regex==2023.10.3", 34 | "nltk==3.8.1", 35 | "vision-datasets==0.2.2", 36 | "cython==3.0.2", 37 | "pycocotools==2.0.7", 38 | "diffdist==0.1", 39 | "pyarrow==13.0.0", 40 | "cityscapesscripts==2.2.2", 41 | "shapely==1.8.0", 42 | "scikit-image==0.21.0", 43 | "mup==1.0.0", 44 | "accelerate==0.23.0", 45 | "kornia==0.7.0", 46 | "deepspeed==0.10.3", 47 | "wandb==0.15.12", 48 | "infinibatch==0.1.1", 49 | "gradio==3.42.0", 50 | "openai-whisper", 51 | ] 52 | 53 | [tool.poetry.dependencies] 54 | detectron2 = {git = "https://github.com/MaureenZOU/detectron2-xyz.git"} 55 | 56 | 57 | [project.urls] 58 | "Paper" = "https://arxiv.org/abs/2307.04767" 59 | "Code" = "https://github.com/UX-Decoder/Semantic-SAM" 60 | "Bug Tracker" = "https://github.com/UX-Decoder/Semantic-SAM/issues" 61 | 62 | [tool.setuptools.packages.find] 63 | exclude = ["assets*"] 64 | 65 | [tool.wheel] 66 | exclude = ["assets*"] 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | opencv-python 4 | pyyaml 5 | json_tricks 6 | yacs 7 | scikit-learn 8 | pandas 9 | timm==0.4.12 10 | numpy==1.23.5 11 | einops 12 | fvcore 13 | transformers==4.19.2 14 | sentencepiece 15 | ftfy 16 | regex 17 | nltk 18 | vision-datasets==0.2.2 19 | pycocotools 20 | diffdist 21 | pyarrow 22 | cityscapesscripts 23 | shapely 24 | scikit-image 25 | mup 26 | gradio==3.35.2 27 | scann 28 | kornia==0.6.4 29 | torchmetrics==0.6.0 30 | progressbar 31 | pillow==9.4.0 32 | -------------------------------------------------------------------------------- /semantic_sam/BaseModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from utils.model import align_and_update_state_dicts 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BaseModel(nn.Module): 13 | def __init__(self, opt, module: nn.Module): 14 | super(BaseModel, self).__init__() 15 | self.opt = opt 16 | self.model = module 17 | 18 | def forward(self, *inputs, **kwargs): 19 | outputs = self.model(*inputs, **kwargs) 20 | return outputs 21 | 22 | def save_pretrained(self, save_path): 23 | torch.save(self.model.state_dict(), save_path) 24 | 25 | def from_pretrained(self, load_dir): 26 | state_dict = torch.load(load_dir, map_location='cpu') 27 | if 'model' in state_dict: 28 | state_dict=state_dict['model'] 29 | state_dict={k[6:]:v for k,v in state_dict.items() if k.startswith('model.')} 30 | # for k in self.model.state_dict(): 31 | # if k not in state_dict: 32 | # assert k[:-2] in state_dict 33 | # state_dict[k]=state_dict.pop(k[:-2]) 34 | state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict) 35 | self.model.load_state_dict(state_dict, strict=False) 36 | return self 
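A minimal usage sketch mirroring how demo_auto_generation.py builds the model; the config path is one of the YAMLs in configs/, while the checkpoint path is a placeholder you substitute:

    from semantic_sam import build_model
    from semantic_sam.BaseModel import BaseModel
    from utils.arguments import load_opt_from_config_file

    # Build the architecture from a YAML config, wrap it in BaseModel, then load weights:
    # from_pretrained() unwraps {"model": ...}-style checkpoints, aligns keys with
    # align_and_update_state_dicts, and calls load_state_dict(strict=False).
    opt = load_opt_from_config_file("configs/semantic_sam_only_sa-1b_swinL.yaml")
    model = BaseModel(opt, build_model(opt)).from_pretrained("path/to/checkpoint.pth").eval().cuda()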
-------------------------------------------------------------------------------- /semantic_sam/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .architectures import build_model 6 | from .build_semantic_sam import prepare_image, plot_results, build_semantic_sam, SemanticSamAutomaticMaskGenerator, SemanticSAMPredictor, plot_multi_results -------------------------------------------------------------------------------- /semantic_sam/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_mask_dino import * 2 | from .build import build_model -------------------------------------------------------------------------------- /semantic_sam/architectures/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | def build_model(config, **kwargs): 5 | model_name = config['MODEL']['NAME'] 6 | 7 | if not is_model(model_name): 8 | raise ValueError(f'Unkown model: {model_name}') 9 | 10 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /semantic_sam/architectures/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_model(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /semantic_sam/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone 2 | 3 | from .focal import * 4 | from .focal_dw import * 5 | from .swin import * 6 | from .backbone import * -------------------------------------------------------------------------------- /semantic_sam/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch.nn as nn 3 | 4 | from detectron2.modeling import ShapeSpec 5 | 6 | # from ..layers import ShapeSpec 7 | 8 | __all__ = ["Backbone"] 9 | 10 | 11 | class Backbone(nn.Module): 12 | """ 13 | Abstract base class for network backbones. 14 | """ 15 | 16 | def __init__(self): 17 | """ 18 | The `__init__` method of any subclass can specify its own set of arguments. 19 | """ 20 | super().__init__() 21 | 22 | def forward(self): 23 | """ 24 | Subclasses must override this method, but adhere to the same return type. 25 | 26 | Returns: 27 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor 28 | """ 29 | pass 30 | 31 | @property 32 | def size_divisibility(self) -> int: 33 | """ 34 | Some backbones require the input height and width to be divisible by a 35 | specific integer. This is typically true for encoder / decoder type networks 36 | with lateral connection (e.g., FPN) for which feature maps need to match 37 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific 38 | input size divisibility is required. 
39 | """ 40 | return 0 41 | 42 | def output_shape(self): 43 | """ 44 | Returns: 45 | dict[str->ShapeSpec] 46 | """ 47 | # this is a backward-compatible default 48 | return { 49 | name: ShapeSpec( 50 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 51 | ) 52 | for name in self._out_features 53 | } 54 | -------------------------------------------------------------------------------- /semantic_sam/backbone/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | from .backbone import * 5 | 6 | def build_backbone(config, **kwargs): 7 | model_name = config['MODEL']['BACKBONE']['NAME'] 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /semantic_sam/backbone/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_backbone(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints 15 | -------------------------------------------------------------------------------- /semantic_sam/body/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_semantic_sam_head -------------------------------------------------------------------------------- /semantic_sam/body/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | from .general_head import * 4 | 5 | 6 | def build_semantic_sam_head(config, *args, **kwargs): 7 | model_name = config['MODEL']['HEAD'] 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | body = model_entrypoints(model_name)(config, *args, **kwargs) 12 | return body -------------------------------------------------------------------------------- /semantic_sam/body/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_decoder 2 | from .interactive_mask_dino import * -------------------------------------------------------------------------------- /semantic_sam/body/decoder/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | 5 | def build_decoder(config, *args, **kwargs): 6 | model_name = config['MODEL']['DECODER']['NAME'] 7 | 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /semantic_sam/body/decoder/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_decoder(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def 
model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /semantic_sam/body/decoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * -------------------------------------------------------------------------------- /semantic_sam/body/decoder/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from torch import nn, Tensor 4 | import os 5 | 6 | import math 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class MLP(nn.Module): 12 | """ Very simple multi-layer perceptron (also called FFN)""" 13 | 14 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 15 | super().__init__() 16 | self.num_layers = num_layers 17 | h = [hidden_dim] * (num_layers - 1) 18 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 19 | 20 | def forward(self, x): 21 | for i, layer in enumerate(self.layers): 22 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 23 | return x 24 | 25 | 26 | def inverse_sigmoid(x, eps=1e-5): 27 | x = x.clamp(min=0, max=1) 28 | x1 = x.clamp(min=eps) 29 | x2 = (1 - x).clamp(min=eps) 30 | return torch.log(x1/x2) 31 | 32 | 33 | def gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor): 34 | """ 35 | Input: 36 | - memory: bs, \sum{hw}, d_model 37 | - memory_padding_mask: bs, \sum{hw} 38 | - spatial_shapes: nlevel, 2 39 | Output: 40 | - output_memory: bs, \sum{hw}, d_model 41 | - output_proposals: bs, \sum{hw}, 4 42 | """ 43 | N_, S_, C_ = memory.shape 44 | base_scale = 4.0 45 | proposals = [] 46 | _cur = 0 47 | for lvl, (H_, W_) in enumerate(spatial_shapes): 48 | mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) 49 | valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) 50 | valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) 51 | 52 | grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), 53 | torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) 54 | grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) 55 | 56 | scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) 57 | grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale 58 | wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) 59 | proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) 60 | proposals.append(proposal) 61 | _cur += (H_ * W_) 62 | output_proposals = torch.cat(proposals, 1) 63 | output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) 64 | output_proposals = torch.log(output_proposals / (1 - output_proposals)) 65 | output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) 66 | output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) 67 | 68 | output_memory = memory 69 | output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) 70 | output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) 71 | return output_memory, output_proposals 72 | 73 | 74 | def gen_sineembed_for_position(pos_tensor, dim=128): 75 | # n_query, bs, _ = pos_tensor.size() 76 | # 
sineembed_tensor = torch.zeros(n_query, bs, 256) 77 | scale = 2 * math.pi 78 | dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) 79 | dim_t = 10000 ** (2 * (dim_t // 2) / dim) 80 | x_embed = pos_tensor[:, :, 0] * scale 81 | y_embed = pos_tensor[:, :, 1] * scale 82 | pos_x = x_embed[:, :, None] / dim_t 83 | pos_y = y_embed[:, :, None] / dim_t 84 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 85 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 86 | if pos_tensor.size(-1) == 2: 87 | pos = torch.cat((pos_y, pos_x), dim=2) 88 | elif pos_tensor.size(-1) == 4: 89 | w_embed = pos_tensor[:, :, 2] * scale 90 | pos_w = w_embed[:, :, None] / dim_t 91 | pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) 92 | 93 | h_embed = pos_tensor[:, :, 3] * scale 94 | pos_h = h_embed[:, :, None] / dim_t 95 | pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) 96 | 97 | pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) 98 | else: 99 | raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) 100 | return pos 101 | 102 | 103 | def _get_activation_fn(activation): 104 | """Return an activation function given a string""" 105 | if activation == "relu": 106 | return F.relu 107 | if activation == "gelu": 108 | return F.gelu 109 | if activation == "glu": 110 | return F.glu 111 | if activation == "prelu": 112 | return nn.PReLU() 113 | if activation == "selu": 114 | return F.selu 115 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 116 | 117 | 118 | def _get_clones(module, N, layer_share=False): 119 | 120 | if layer_share: 121 | return nn.ModuleList([module for i in range(N)]) 122 | else: 123 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) -------------------------------------------------------------------------------- /semantic_sam/body/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_encoder -------------------------------------------------------------------------------- /semantic_sam/body/encoder/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | from .transformer_encoder_fpn import * 5 | from .encoder_deform import * 6 | 7 | def build_encoder(config, *args, **kwargs): 8 | model_name = config['MODEL']['ENCODER']['NAME'] 9 | 10 | if not is_model(model_name): 11 | raise ValueError(f'Unkown model: {model_name}') 12 | 13 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd mask2former/modeling/pixel_decoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for 
H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install --user 14 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
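# NOTE: this script builds the MultiScaleDeformableAttention extension used by the
# deformable encoder; make.sh invokes it as `python setup.py build install --user`.
# get_extensions() below only adds the CUDA sources when CUDA_HOME is set and either
# torch.cuda.is_available() or FORCE_CUDA is defined in the environment; otherwise it
# raises NotImplementedError, so a CUDA toolkit is effectively required to build it.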
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
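 * NOTE: ms_deform_attn_cpu_forward / ms_deform_attn_cpu_backward below are stubs
 * that immediately raise via AT_ERROR; there is no CPU implementation of this op,
 * so input tensors must live on a CUDA device (see the dispatch in ms_deform_attn.h).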
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
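 * NOTE: ms_deform_attn_forward / ms_deform_attn_backward defined in this header are
 * the entry points bound to Python in vision.cpp. They route CUDA tensors to the
 * ms_deform_attn_cuda_* kernels when the extension is compiled WITH_CUDA, and raise
 * AT_ERROR for CPU tensors, consistent with the CPU stubs above.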
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. 
and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 
| print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /semantic_sam/body/encoder/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_encoder(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints 14 | -------------------------------------------------------------------------------- /semantic_sam/body/general_head.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) MicroSoft, Inc. and its affiliates. 3 | # Modified from DINO https://github.com/IDEA-Research/MaskDINO by Feng Li. 4 | # ------------------------------------------------------------------------ 5 | import logging 6 | from typing import Callable, Dict, List, Optional, Tuple, Union 7 | 8 | from torch import nn 9 | 10 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 11 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 12 | 13 | from .registry import register_body 14 | from .encoder import build_encoder 15 | from .decoder import build_decoder 16 | from ..utils import configurable 17 | 18 | 19 | class IMaskDINOHead(nn.Module): 20 | @configurable 21 | def __init__( 22 | self, 23 | input_shape: Dict[str, ShapeSpec], 24 | *, 25 | num_classes: int, 26 | pixel_decoder: nn.Module, 27 | loss_weight: float = 1.0, 28 | ignore_value: int = -1, 29 | transformer_predictor: nn.Module, 30 | ): 31 | """ 32 | Args: 33 | input_shape: shapes (channels and stride) of the input features 34 | num_classes: number of classes to predict 35 | pixel_decoder: the pixel decoder module 36 | loss_weight: loss weight 37 | ignore_value: category id to be ignored during training. 
38 | transformer_predictor: the transformer decoder that makes prediction 39 | transformer_in_feature: input feature name to the transformer_predictor 40 | """ 41 | super().__init__() 42 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 43 | self.in_features = [k for k, v in input_shape] 44 | self.ignore_value = ignore_value 45 | self.common_stride = 4 46 | self.loss_weight = loss_weight 47 | 48 | self.pixel_decoder = pixel_decoder 49 | self.predictor = transformer_predictor 50 | 51 | self.num_classes = num_classes 52 | # store processed features 53 | self.processed_features = None 54 | 55 | @classmethod 56 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict): 57 | enc_cfg = cfg['MODEL']['ENCODER'] 58 | dec_cfg = cfg['MODEL']['DECODER'] 59 | transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] 60 | 61 | return { 62 | "input_shape": { 63 | k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] 64 | }, 65 | "ignore_value": enc_cfg['IGNORE_VALUE'], 66 | "num_classes": enc_cfg.get('NUM_CLASSES', None), 67 | "pixel_decoder": build_encoder(cfg, input_shape), 68 | "loss_weight": enc_cfg['LOSS_WEIGHT'], 69 | "transformer_predictor": build_decoder( 70 | cfg, 71 | transformer_predictor_in_channels, 72 | lang_encoder, 73 | mask_classification=True, 74 | extra=extra, 75 | ), 76 | } 77 | 78 | def forward_encoder(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}): 79 | mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( 80 | features, mask) 81 | self.processed_features = (mask_features, transformer_encoder_features, multi_scale_features) 82 | 83 | def forward_decoder(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}): 84 | assert self.processed_features is not None, "need to precess features first" 85 | mask_features, transformer_encoder_features, multi_scale_features = self.processed_features 86 | if task == 'teacher': 87 | predictions = self.predictor.forward_teacher(multi_scale_features, mask_features, mask, targets=targets, 88 | target_queries=target_queries, target_vlp=target_vlp, 89 | task=task, extra=extra) 90 | else: 91 | predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets, 92 | target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) 93 | return predictions 94 | 95 | def forward(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}): 96 | return self.layers(features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) 97 | 98 | def layers(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}): 99 | mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features, mask) 100 | predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets, 101 | target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) 102 | return predictions 103 | 104 | 105 | @register_body 106 | def get_interactive_maskdino_head(cfg, input_shape, lang_encoder, extra): 107 | return IMaskDINOHead(cfg, input_shape, lang_encoder, extra) -------------------------------------------------------------------------------- 
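The head above is wired into the model through the same lightweight registry pattern used by build_encoder: `@register_body` stores `get_interactive_maskdino_head` under its module name (`general_head`), and a builder looks the factory up by the name given in the config. Below is a minimal sketch of that dispatch, assuming a hypothetical `MODEL['HEAD']` config key (the body-side build.py is not shown in this excerpt); it simply mirrors the encoder builder:

    from .registry import model_entrypoints, is_model

    def build_body(cfg, *args, **kwargs):
        model_name = cfg['MODEL']['HEAD']  # hypothetical key, by analogy with MODEL.ENCODER.NAME
        if not is_model(model_name):
            raise ValueError(f'Unknown model: {model_name}')
        # e.g. 'general_head' -> get_interactive_maskdino_head(cfg, input_shape, lang_encoder, extra)
        return model_entrypoints(model_name)(cfg, *args, **kwargs)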
/semantic_sam/body/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_body(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /semantic_sam/build_semantic_sam.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Feng Li (fliay@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import matplotlib.pyplot as plt 9 | from PIL import Image 10 | import numpy as np 11 | from torchvision import transforms 12 | import torch 13 | import os 14 | 15 | from utils.arguments import load_opt_from_config_file 16 | from semantic_sam.BaseModel import BaseModel 17 | from semantic_sam import build_model 18 | from tasks.automatic_mask_generator import SemanticSamAutomaticMaskGenerator 19 | from tasks.interactive_idino_m2m_auto import show_anns 20 | from tasks.interactive_predictor import SemanticSAMPredictor 21 | 22 | 23 | def prepare_image(image_pth): 24 | """ 25 | apply transformation to the image. crop the image ot 640 short edge by default 26 | """ 27 | image = Image.open(image_pth).convert('RGB') 28 | t = [] 29 | t.append(transforms.Resize(640, interpolation=Image.BICUBIC)) 30 | transform1 = transforms.Compose(t) 31 | image_ori = transform1(image) 32 | 33 | image_ori = np.asarray(image_ori) 34 | images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda() 35 | 36 | return image_ori, images 37 | 38 | 39 | def build_semantic_sam(model_type, ckpt): 40 | """ 41 | build model 42 | """ 43 | cfgs={'T':"configs/semantic_sam_only_sa-1b_swinT.yaml", 44 | 'L':"configs/semantic_sam_only_sa-1b_swinL.yaml"} 45 | 46 | sam_cfg=cfgs[model_type] 47 | opt = load_opt_from_config_file(sam_cfg) 48 | model_semantic_sam = BaseModel(opt, build_model(opt)).from_pretrained(ckpt).eval().cuda() 49 | return model_semantic_sam 50 | 51 | 52 | def plot_results(outputs, image_ori, save_path='../vis/'): 53 | """ 54 | plot input image and its reuslts 55 | """ 56 | if os.path.isdir(save_path): 57 | image_ori_name = 'input.png' 58 | im_name = 'example.png' 59 | else: 60 | image_ori_name = os.path.basename(save_path).split('.')[0] + '_input.png' 61 | im_name = os.path.basename(save_path).split('.')[0]+ '_example.png' 62 | save_path = os.path.dirname(save_path) 63 | 64 | if not os.path.exists(save_path): 65 | os.mkdir(save_path) 66 | 67 | fig = plt.figure() 68 | plt.imshow(image_ori) 69 | plt.savefig(os.path.join(save_path, image_ori_name)) 70 | show_anns(outputs) 71 | fig.canvas.draw() 72 | im = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) 73 | plt.savefig(os.path.join(save_path, im_name)) 74 | return im 75 | 76 | def plot_multi_results(iou_sort_masks, area_sort_masks, image_ori, save_path='../vis/'): 77 | """ 78 | plot input image and its reuslts 79 | """ 80 | if not os.path.exists(save_path): 81 | os.mkdir(save_path) 82 | plt.imshow(image_ori) 83 | 
plt.savefig('../vis/input.png') 84 | def create_long_image(masks): 85 | ims = [] 86 | for img in masks: 87 | ims.append(img) 88 | width, height = ims[0].size 89 | result = Image.new(ims[0].mode, (width * len(ims), height)) 90 | for i, im in enumerate(ims): 91 | result.paste(im, box=(i * width, 0)) 92 | return result 93 | create_long_image(iou_sort_masks).save('../vis/all_results_sort_by_iou.png') 94 | create_long_image(area_sort_masks).save('../vis/all_results_sort_by_areas.png') 95 | -------------------------------------------------------------------------------- /semantic_sam/language/LangEncoder/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .build import build_lang_encoder 6 | from .build import build_tokenizer 7 | 8 | from .transformer import * -------------------------------------------------------------------------------- /semantic_sam/language/LangEncoder/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import CLIPTokenizer, CLIPTokenizerFast 4 | from transformers import AutoTokenizer 5 | 6 | from .registry import lang_encoders 7 | from .registry import is_lang_encoder 8 | 9 | 10 | def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): 11 | model_name = config_encoder['NAME'] 12 | 13 | if not is_lang_encoder(model_name): 14 | raise ValueError(f'Unkown model: {model_name}') 15 | 16 | return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) 17 | 18 | 19 | def build_tokenizer(config_encoder): 20 | tokenizer = None 21 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 22 | if config_encoder['TOKENIZER'] == 'clip': 23 | pretrained_tokenizer = config_encoder.get( 24 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 25 | ) 26 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) 27 | tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) 28 | elif config_encoder['TOKENIZER'] == 'clip-fast': 29 | pretrained_tokenizer = config_encoder.get( 30 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 31 | ) 32 | tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True) 33 | else: 34 | tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) 35 | 36 | return tokenizer 37 | -------------------------------------------------------------------------------- /semantic_sam/language/LangEncoder/registry.py: -------------------------------------------------------------------------------- 1 | _lang_encoders = {} 2 | 3 | 4 | def register_lang_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | 8 | _lang_encoders[model_name] = fn 9 | 10 | return fn 11 | 12 | 13 | def lang_encoders(model_name): 14 | return _lang_encoders[model_name] 15 | 16 | 17 | def is_lang_encoder(model_name): 18 | return model_name in _lang_encoders 19 | -------------------------------------------------------------------------------- /semantic_sam/language/__init__.py: -------------------------------------------------------------------------------- 1 | from .vlpencoder import * 2 | from .encoder import * 3 | from .build import build_language_encoder -------------------------------------------------------------------------------- /semantic_sam/language/build.py: 
-------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | 5 | def build_language_encoder(config, **kwargs): 6 | model_name = config['MODEL']['TEXT']['ARCH'] 7 | if model_name=='noencoder': 8 | return None 9 | 10 | if not is_model(model_name): 11 | raise ValueError(f'Unkown model: {model_name}') 12 | 13 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /semantic_sam/language/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from timm.models.layers import trunc_normal_ 6 | 7 | from .registry import register_model 8 | from ..utils import configurable 9 | from .LangEncoder import build_tokenizer, build_lang_encoder 10 | from utils.prompt_engineering import prompt_engineering, get_prompt_templates 11 | 12 | 13 | class LanguageEncoder(nn.Module): 14 | 15 | @configurable 16 | def __init__( 17 | self, 18 | tokenizer, 19 | tokenizer_type, 20 | lang_encoder, 21 | lang_projection, 22 | max_token_num, 23 | ): 24 | super().__init__() 25 | self.tokenizer = tokenizer 26 | self.tokenizer_type = tokenizer_type 27 | self.lang_encoder = lang_encoder 28 | self.lang_proj = lang_projection 29 | self.max_token_num = max_token_num 30 | self.logit_scale = nn.Parameter(torch.ones([])) 31 | 32 | @classmethod 33 | def from_config(cls, cfg): 34 | # build up text encoder 35 | tokenizer = build_tokenizer(cfg['MODEL']['TEXT']) 36 | tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER'] 37 | lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) 38 | max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] 39 | 40 | dim_lang = cfg['MODEL']['TEXT']['WIDTH'] 41 | dim_projection = cfg['MODEL']['DIM_PROJ'] 42 | lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) 43 | trunc_normal_(lang_projection, std=.02) 44 | 45 | return { 46 | "tokenizer": tokenizer, 47 | "tokenizer_type": tokenizer_type, 48 | "lang_encoder": lang_encoder, 49 | "lang_projection": lang_projection, 50 | "max_token_num": max_token_num, 51 | } 52 | 53 | # @torch.no_grad() 54 | def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): 55 | if not is_eval: 56 | if prompt: 57 | # randomly sample one template 58 | arbitary_concepts = [ 59 | prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ 60 | for label in range(len(class_names)) 61 | ] 62 | if add_bgd: 63 | arbitary_concepts.append("A background in coco.") 64 | else: 65 | arbitary_concepts = class_names 66 | 67 | input_ids = [] 68 | attention_masks = [] 69 | for txt in arbitary_concepts: 70 | tokens = self.tokenizer( 71 | txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' 72 | ) 73 | tokens['input_ids'].squeeze_() 74 | tokens['attention_mask'].squeeze_() 75 | 76 | input_ids.append(tokens['input_ids']) 77 | attention_masks.append(tokens['attention_mask']) 78 | 79 | arbitary_tokens = torch.stack(input_ids) 80 | arbitary_attention_masks = torch.stack(attention_masks) 81 | 82 | text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm) 83 | setattr(self, '{}_text_embeddings'.format(name), text_emb) 84 | else: 85 | with 
torch.no_grad(): 86 | def extract_mean_emb(txts): 87 | tokens = self.tokenizer( 88 | txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' 89 | ) 90 | clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm) 91 | clss_embedding = clss_embedding.mean(dim=0) 92 | clss_embedding /= clss_embedding.norm() 93 | return clss_embedding 94 | 95 | templates = get_prompt_templates() 96 | clss_embeddings = [] 97 | for clss in class_names: 98 | txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates] 99 | clss_embeddings.append(extract_mean_emb(txts)) 100 | 101 | if add_bgd: 102 | txts = ["A background in coco."] 103 | clss_embeddings.append(extract_mean_emb(txts)) 104 | 105 | text_emb = torch.stack(clss_embeddings, dim=0) 106 | setattr(self, '{}_text_embeddings'.format(name), text_emb) 107 | 108 | # @torch.no_grad() 109 | def forward_language(self, texts, norm=True): 110 | x = self.lang_encoder(*texts) 111 | x = x['last_hidden_state'] 112 | 113 | if self.tokenizer_type == 'clip': 114 | x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)] 115 | else: 116 | x = x[:, 0] 117 | 118 | x = x @ self.lang_proj 119 | if norm: 120 | x = x / (x.norm(dim=-1, keepdim=True) + 1e-7) 121 | return x 122 | 123 | def compute_similarity(self, v_emb, name='default'): 124 | v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) 125 | t_emb = getattr(self, '{}_text_embeddings'.format(name)) 126 | output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2) 127 | return output 128 | 129 | 130 | @register_model 131 | def get_language_model(cfg, **kwargs): 132 | return LanguageEncoder(cfg) -------------------------------------------------------------------------------- /semantic_sam/language/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_model(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /semantic_sam/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .point_features import * 2 | from .position_encoding import * 3 | from .postprocessing import * 4 | from .attention import * 5 | from .matcher import * 6 | from .criterion_interactive_many_to_one import * 7 | from .criterion_interactive_many_to_many import * 8 | from .many2many_matcher import * 9 | # from .hooks import HookBase -------------------------------------------------------------------------------- /semantic_sam/modules/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 
5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=x.dtype) 34 | x_embed = not_mask.cumsum(2, dtype=x.dtype) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | def __repr__(self, _repr_indent=4): 55 | head = "Positional encoding " + self.__class__.__name__ 56 | body = [ 57 | "num_pos_feats: {}".format(self.num_pos_feats), 58 | "temperature: {}".format(self.temperature), 59 | "normalize: {}".format(self.normalize), 60 | "scale: {}".format(self.scale), 61 | ] 62 | # _repr_indent = 4 63 | lines = [head] + [" " * _repr_indent + line for line in body] 64 | return "\n".join(lines) 65 | -------------------------------------------------------------------------------- /semantic_sam/modules/postprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from detectron2.structures import Instances, ROIMasks 6 | 7 | 8 | # perhaps should rename to "resize_instance" 9 | def detector_postprocess( 10 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 11 | ): 12 | """ 13 | Resize the output instances. 14 | The input images are often resized when entering an object detector. 15 | As a result, we often need the outputs of the detector in a different 16 | resolution from its inputs. 17 | 18 | This function will resize the raw outputs of an R-CNN detector 19 | to produce outputs according to the desired output resolution. 20 | 21 | Args: 22 | results (Instances): the raw outputs from the detector. 23 | `results.image_size` contains the input image resolution the detector sees. 24 | This object might be modified in-place. 25 | output_height, output_width: the desired output resolution. 
26 | 27 | Returns: 28 | Instances: the resized output from the model, based on the output resolution 29 | """ 30 | if isinstance(output_width, torch.Tensor): 31 | # This shape might (but not necessarily) be tensors during tracing. 32 | # Converts integer tensors to float temporaries to ensure true 33 | # division is performed when computing scale_x and scale_y. 34 | output_width_tmp = output_width.float() 35 | output_height_tmp = output_height.float() 36 | new_size = torch.stack([output_height, output_width]) 37 | else: 38 | new_size = (output_height, output_width) 39 | output_width_tmp = output_width 40 | output_height_tmp = output_height 41 | 42 | scale_x, scale_y = ( 43 | output_width_tmp / results.image_size[1], 44 | output_height_tmp / results.image_size[0], 45 | ) 46 | results = Instances(new_size, **results.get_fields()) 47 | 48 | if results.has("pred_boxes"): 49 | output_boxes = results.pred_boxes 50 | elif results.has("proposal_boxes"): 51 | output_boxes = results.proposal_boxes 52 | else: 53 | output_boxes = None 54 | assert output_boxes is not None, "Predictions must contain boxes!" 55 | 56 | output_boxes.scale(scale_x, scale_y) 57 | output_boxes.clip(results.image_size) 58 | 59 | results = results[output_boxes.nonempty()] 60 | 61 | if results.has("pred_masks"): 62 | if isinstance(results.pred_masks, ROIMasks): 63 | roi_masks = results.pred_masks 64 | else: 65 | # pred_masks is a tensor of shape (N, 1, M, M) 66 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) 67 | results.pred_masks = roi_masks.to_bitmasks( 68 | results.pred_boxes, output_height, output_width, mask_threshold 69 | ).tensor # TODO return ROIMasks/BitMask object in the future 70 | 71 | if results.has("pred_keypoints"): 72 | results.pred_keypoints[:, :, 0] *= scale_x 73 | results.pred_keypoints[:, :, 1] *= scale_y 74 | 75 | return results 76 | 77 | def bbox_postprocess(result, input_size, img_size, output_height, output_width): 78 | """ 79 | result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h] 80 | """ 81 | if result is None: 82 | return None 83 | 84 | scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device) 85 | result = result.sigmoid() * scale 86 | x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2 87 | h,w = img_size 88 | 89 | x1 = x1.clamp(min=0, max=w) 90 | y1 = y1.clamp(min=0, max=h) 91 | x2 = x2.clamp(min=0, max=w) 92 | y2 = y2.clamp(min=0, max=h) 93 | 94 | box = torch.stack([x1,y1,x2,y2]).permute(1,0) 95 | scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device) 96 | box = box*scale 97 | return box 98 | 99 | def sem_seg_postprocess(result, img_size, output_height, output_width): 100 | """ 101 | Return semantic segmentation predictions in the original resolution. 102 | 103 | The input images are often resized when entering semantic segmentor. Moreover, in same 104 | cases, they also padded inside segmentor to be divisible by maximum network stride. 105 | As a result, we often need the predictions of the segmentor in a different 106 | resolution from its inputs. 107 | 108 | Args: 109 | result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W), 110 | where C is the number of classes, and H, W are the height and width of the prediction. 111 | img_size (tuple): image size that segmentor is taking as input. 112 | output_height, output_width: the desired output resolution. 
113 | 114 | Returns: 115 | semantic segmentation prediction (Tensor): A tensor of the shape 116 | (C, output_height, output_width) that contains per-pixel soft predictions. 117 | """ 118 | result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) 119 | result = F.interpolate( 120 | result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True 121 | )[0] 122 | return result 123 | -------------------------------------------------------------------------------- /semantic_sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .misc import * 3 | # from .dist import * -------------------------------------------------------------------------------- /semantic_sam/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | def box_xywh_to_xyxy(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [x0, y0, (x0 + x1), (y0 + y1)] 25 | return torch.stack(b, dim=-1) 26 | 27 | 28 | # modified from torchvision to also return the union 29 | def box_iou(boxes1, boxes2): 30 | area1 = box_area(boxes1) 31 | area2 = box_area(boxes2) 32 | 33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clamp(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | iou = inter / (union+1e-6) 42 | return iou, union 43 | 44 | 45 | def generalized_box_iou(boxes1, boxes2): 46 | """ 47 | Generalized IoU from https://giou.stanford.edu/ 48 | 49 | The boxes should be in [x0, y0, x1, y1] format 50 | 51 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 52 | and M = len(boxes2) 53 | """ 54 | # degenerate boxes gives inf / nan results 55 | # so do an early check 56 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 57 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 58 | iou, union = box_iou(boxes1, boxes2) 59 | 60 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 61 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 62 | 63 | wh = (rb - lt).clamp(min=0) # [N,M,2] 64 | area = wh[:, :, 0] * wh[:, :, 1] 65 | 66 | return iou - (area - union) / (area+1e-6) 67 | 68 | def generalized_box_iou_padded(boxes1, boxes2): 69 | """ 70 | Generalized IoU from https://giou.stanford.edu/ 71 | 72 | The boxes should be in [x0, y0, x1, y1] format 73 | 74 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 75 | and M = len(boxes2) 76 | """ 77 | # degenerate boxes gives inf / nan results 78 | # so do an early check 79 | # assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 80 | # assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 81 | iou, union = box_iou(boxes1, boxes2) 82 | 83 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 84 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 85 | 86 | wh = (rb - lt).clamp(min=0) # [N,M,2] 87 | area = 
wh[:, :, 0] * wh[:, :, 1] 88 | 89 | return iou - (area - union) / (area+1e-6) 90 | 91 | 92 | def masks_to_boxes(masks): 93 | """Compute the bounding boxes around the provided masks 94 | 95 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 96 | 97 | Returns a [N, 4] tensors, with the boxes in xyxy format 98 | """ 99 | if masks.numel() == 0: 100 | return torch.zeros((0, 4), device=masks.device) 101 | 102 | h, w = masks.shape[-2:] 103 | 104 | y = torch.arange(0, h, dtype=torch.float) 105 | x = torch.arange(0, w, dtype=torch.float) 106 | y, x = torch.meshgrid(y, x) 107 | 108 | x_mask = (masks * x.unsqueeze(0)) 109 | x_max = x_mask.flatten(1).max(-1)[0] 110 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 111 | 112 | y_mask = (masks * y.unsqueeze(0)) 113 | y_max = y_mask.flatten(1).max(-1)[0] 114 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 115 | 116 | return torch.stack([x_min, y_min, x_max, y_max], 1) -------------------------------------------------------------------------------- /semantic_sam/utils/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import functools 5 | import inspect 6 | 7 | def configurable(init_func=None, *, from_config=None): 8 | """ 9 | Decorate a function or a class's __init__ method so that it can be called 10 | with a :class:`CfgNode` object using a :func:`from_config` function that translates 11 | :class:`CfgNode` to arguments. 12 | 13 | Examples: 14 | :: 15 | # Usage 1: Decorator on __init__: 16 | class A: 17 | @configurable 18 | def __init__(self, a, b=2, c=3): 19 | pass 20 | 21 | @classmethod 22 | def from_config(cls, cfg): # 'cfg' must be the first argument 23 | # Returns kwargs to be passed to __init__ 24 | return {"a": cfg.A, "b": cfg.B} 25 | 26 | a1 = A(a=1, b=2) # regular construction 27 | a2 = A(cfg) # construct with a cfg 28 | a3 = A(cfg, b=3, c=4) # construct with extra overwrite 29 | 30 | # Usage 2: Decorator on any function. Needs an extra from_config argument: 31 | @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B}) 32 | def a_func(a, b=2, c=3): 33 | pass 34 | 35 | a1 = a_func(a=1, b=2) # regular call 36 | a2 = a_func(cfg) # call with a cfg 37 | a3 = a_func(cfg, b=3, c=4) # call with extra overwrite 38 | 39 | Args: 40 | init_func (callable): a class's ``__init__`` method in usage 1. The 41 | class must have a ``from_config`` classmethod which takes `cfg` as 42 | the first argument. 43 | from_config (callable): the from_config function in usage 2. It must take `cfg` 44 | as its first argument. 45 | """ 46 | 47 | if init_func is not None: 48 | assert ( 49 | inspect.isfunction(init_func) 50 | and from_config is None 51 | and init_func.__name__ == "__init__" 52 | ), "Incorrect use of @configurable. Check API documentation for examples." 53 | 54 | @functools.wraps(init_func) 55 | def wrapped(self, *args, **kwargs): 56 | try: 57 | from_config_func = type(self).from_config 58 | except AttributeError as e: 59 | raise AttributeError( 60 | "Class with @configurable must have a 'from_config' classmethod." 
61 | ) from e 62 | if not inspect.ismethod(from_config_func): 63 | raise TypeError("Class with @configurable must have a 'from_config' classmethod.") 64 | 65 | # import ipdb; ipdb.set_trace() 66 | if _called_with_cfg(*args, **kwargs): 67 | explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) 68 | init_func(self, **explicit_args) 69 | else: 70 | init_func(self, *args, **kwargs) 71 | 72 | return wrapped 73 | 74 | else: 75 | if from_config is None: 76 | return configurable # @configurable() is made equivalent to @configurable 77 | assert inspect.isfunction( 78 | from_config 79 | ), "from_config argument of configurable must be a function!" 80 | 81 | def wrapper(orig_func): 82 | @functools.wraps(orig_func) 83 | def wrapped(*args, **kwargs): 84 | if _called_with_cfg(*args, **kwargs): 85 | explicit_args = _get_args_from_config(from_config, *args, **kwargs) 86 | return orig_func(**explicit_args) 87 | else: 88 | return orig_func(*args, **kwargs) 89 | 90 | wrapped.from_config = from_config 91 | return wrapped 92 | 93 | return wrapper 94 | 95 | def _called_with_cfg(*args, **kwargs): 96 | """ 97 | Returns: 98 | bool: whether the arguments contain CfgNode and should be considered 99 | forwarded to from_config. 100 | """ 101 | from omegaconf import DictConfig, OmegaConf, ListConfig 102 | # from detectron2.config import LazyConfig 103 | 104 | if len(args) and (isinstance(args[0], (dict)) or (isinstance(args[0], (DictConfig)))): 105 | return True 106 | if isinstance(kwargs.pop("cfg", None), (dict)): 107 | return True 108 | # `from_config`'s first argument is forced to be "cfg". 109 | # So the above check covers all cases. 110 | return False 111 | 112 | def _get_args_from_config(from_config_func, *args, **kwargs): 113 | """ 114 | Use `from_config` to obtain explicit arguments. 
115 | 116 | Returns: 117 | dict: arguments to be used for cls.__init__ 118 | """ 119 | signature = inspect.signature(from_config_func) 120 | if list(signature.parameters.keys())[0] != "cfg": 121 | if inspect.isfunction(from_config_func): 122 | name = from_config_func.__name__ 123 | else: 124 | name = f"{from_config_func.__self__}.from_config" 125 | raise TypeError(f"{name} must take 'cfg' as the first argument!") 126 | support_var_arg = any( 127 | param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] 128 | for param in signature.parameters.values() 129 | ) 130 | if support_var_arg: # forward all arguments to from_config, if from_config accepts them 131 | ret = from_config_func(*args, **kwargs) 132 | else: 133 | # forward supported arguments to from_config 134 | supported_arg_names = set(signature.parameters.keys()) 135 | extra_kwargs = {} 136 | for name in list(kwargs.keys()): 137 | if name not in supported_arg_names: 138 | extra_kwargs[name] = kwargs.pop(name) 139 | ret = from_config_func(*args, **kwargs) 140 | # forward the other arguments to __init__ 141 | ret.update(extra_kwargs) 142 | return ret -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m 2 | from .interactive_idino_m2m_auto import interactive_infer_image as interactive_infer_image_idino_m2m_auto 3 | from .automatic_mask_generator import prompt_switch 4 | from .interactive_predictor import SemanticSAMPredictor -------------------------------------------------------------------------------- /tasks/interactive_idino_m2m.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Hao Zhang (hzhangcx@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import numpy as np 10 | from torchvision import transforms 11 | from utils.visualizer import Visualizer 12 | from typing import Tuple 13 | from PIL import Image 14 | from detectron2.data import MetadataCatalog 15 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 16 | 17 | def interactive_infer_image(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None): 18 | t = [] 19 | t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) 20 | transform1 = transforms.Compose(t) 21 | image_ori = transform1(image['image']) 22 | mask_ori = transform1(image['mask']) 23 | width = image_ori.size[0] 24 | height = image_ori.size[1] 25 | image_ori = np.asarray(image_ori) 26 | images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() 27 | all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':') 28 | 29 | 30 | data = {"image": images, "height": height, "width": width} 31 | 32 | mask_ori = np.asarray(mask_ori)[:,:,0:1].copy() 33 | mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0] 34 | points=mask_ori.nonzero().float().to(images.device) 35 | if len(points)==0: 36 | point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]]) 37 | else: 38 | point_=points.mean(0)[None] 39 | 
point=point_.clone() 40 | point[0, 0] = point_[0, 0] / mask_ori.shape[0] 41 | point[0, 1] = point_[0, 1] / mask_ori.shape[1] 42 | point = point[:, [1, 0]] 43 | point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1) 44 | data['targets'] = [dict()] 45 | data['targets'][0]['points']=point 46 | data['targets'][0]['pb']=point.new_tensor([0.]) 47 | 48 | 49 | batch_inputs = [data] 50 | masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts) 51 | 52 | pred_masks_poses = masks 53 | reses=[] 54 | ious=ious[0,0] 55 | ids=torch.argsort(ious,descending=True) 56 | 57 | text_res='' 58 | try: 59 | thresh=float(thresh) 60 | except Exception: 61 | thresh=0.0 62 | mask_ls=[] 63 | ious_res=[] 64 | areas=[] 65 | for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])): 66 | iou=round(float(iou),2) 67 | texts=f'{iou}' 68 | mask=(pred_masks_pos>0.0).cpu().numpy() 69 | area=mask.sum() 70 | conti=False 71 | if iou<thresh: 72 | conti=True 73 | for m in mask_ls: 74 | if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95: 75 | conti=True 76 | break 77 | if i == len(pred_masks_poses[ids])-1 and mask_ls==[]: 78 | conti=False 79 | if conti: 80 | continue 81 | ious_res.append(iou) 82 | mask_ls.append(mask) 83 | areas.append(area) 84 | mask,_=remove_small_regions(mask,int(hole_scale),mode="holes") 85 | mask,_=remove_small_regions(mask,int(island_scale),mode="islands") 86 | mask=(mask).astype(np.float) 87 | out_txt = texts 88 | visual = Visualizer(image_ori, metadata=metadata) 89 | color=[0.,0.,1.0] 90 | demo = visual.draw_binary_mask(mask, color=color, text=texts) 91 | res = demo.get_image() 92 | point_x0=max(0,int(point_[0, 1])-3) 93 | point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3) 94 | point_y0 = max(0, int(point_[0, 0]) - 3) 95 | point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3) 96 | res[point_y0:point_y1,point_x0:point_x1,0]=255 97 | res[point_y0:point_y1,point_x0:point_x1,1]=0 98 | res[point_y0:point_y1,point_x0:point_x1,2]=0 99 | reses.append(Image.fromarray(res)) 100 | text_res=text_res+';'+out_txt 101 | ids=list(torch.argsort(torch.tensor(areas),descending=False)) 102 | ids = [int(i) for i in ids] 103 | 104 | torch.cuda.empty_cache() 105 | 106 | return reses,[reses[i] for i in ids] 107 | 108 | def remove_small_regions( 109 | mask: np.ndarray, area_thresh: float, mode: str 110 | ) -> Tuple[np.ndarray, bool]: 111 | """ 112 | Removes small disconnected regions and holes in a mask. Returns the 113 | mask and an indicator of if the mask has been modified.
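A worked sketch of the point-prompt encoding used in interactive_infer_image above: the clicked pixel (row, col) is normalized by the mask height and width, reordered to (x, y), and padded with a small box size before being handed to the model. The image size and click position below are made up for illustration.

import torch

mask_h, mask_w = 480, 640
click = torch.tensor([[240.0, 320.0]])      # (row, col) of the clicked pixel
point = click.clone()
point[0, 0] = click[0, 0] / mask_h          # normalize row by image height
point[0, 1] = click[0, 1] / mask_w          # normalize col by image width
point = point[:, [1, 0]]                    # reorder to (x, y)
point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)
# point -> tensor([[0.5000, 0.5000, 0.0050, 0.0050]]), i.e. a tiny box centered on the click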
114 | """ 115 | import cv2 # type: ignore 116 | 117 | assert mode in ["holes", "islands"] 118 | correct_holes = mode == "holes" 119 | working_mask = (correct_holes ^ mask).astype(np.uint8) 120 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 121 | sizes = stats[:, -1][1:] # Row 0 is background label 122 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 123 | if len(small_regions) == 0: 124 | return mask, False 125 | fill_labels = [0] + small_regions 126 | if not correct_holes: 127 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 128 | # If every region is below threshold, keep largest 129 | if len(fill_labels) == 0: 130 | fill_labels = [int(np.argmax(sizes)) + 1] 131 | mask = np.isin(regions, fill_labels) 132 | return mask, True -------------------------------------------------------------------------------- /tasks/interactive_idino_m2m_auto.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Hao Zhang (hzhangcx@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import numpy as np 10 | from torchvision import transforms 11 | from utils.visualizer import Visualizer 12 | from typing import Tuple 13 | from PIL import Image 14 | from detectron2.data import MetadataCatalog 15 | import matplotlib.pyplot as plt 16 | import cv2 17 | import io 18 | from .automatic_mask_generator import SemanticSamAutomaticMaskGenerator 19 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 20 | 21 | def interactive_infer_image(model, image,level,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None): 22 | t = [] 23 | t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) 24 | transform1 = transforms.Compose(t) 25 | image_ori = transform1(image) 26 | 27 | image_ori = np.asarray(image_ori) 28 | images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() 29 | 30 | mask_generator = SemanticSamAutomaticMaskGenerator(model,points_per_side=32, 31 | pred_iou_thresh=0.88, 32 | stability_score_thresh=0.92, 33 | min_mask_region_area=10, 34 | level=level, 35 | ) 36 | 37 | outputs = mask_generator.generate(images) 38 | 39 | fig=plt.figure(figsize=(10, 10)) 40 | plt.imshow(image_ori) 41 | show_anns(outputs) 42 | fig.canvas.draw() 43 | im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) 44 | return im 45 | 46 | 47 | def remove_small_regions( 48 | mask: np.ndarray, area_thresh: float, mode: str 49 | ) -> Tuple[np.ndarray, bool]: 50 | """ 51 | Removes small disconnected regions and holes in a mask. Returns the 52 | mask and an indicator of if the mask has been modified. 
53 | """ 54 | import cv2 # type: ignore 55 | 56 | assert mode in ["holes", "islands"] 57 | correct_holes = mode == "holes" 58 | working_mask = (correct_holes ^ mask).astype(np.uint8) 59 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 60 | sizes = stats[:, -1][1:] # Row 0 is background label 61 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 62 | if len(small_regions) == 0: 63 | return mask, False 64 | fill_labels = [0] + small_regions 65 | if not correct_holes: 66 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 67 | # If every region is below threshold, keep largest 68 | if len(fill_labels) == 0: 69 | fill_labels = [int(np.argmax(sizes)) + 1] 70 | mask = np.isin(regions, fill_labels) 71 | return mask, True 72 | 73 | def show_anns(anns): 74 | if len(anns) == 0: 75 | return 76 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 77 | ax = plt.gca() 78 | ax.set_autoscale_on(False) 79 | polygons = [] 80 | color = [] 81 | for ann in sorted_anns: 82 | m = ann['segmentation'] 83 | img = np.ones((m.shape[0], m.shape[1], 3)) 84 | color_mask = np.random.random((1, 3)).tolist()[0] 85 | for i in range(3): 86 | img[:,:,i] = color_mask[i] 87 | ax.imshow(np.dstack((img, m*0.35))) -------------------------------------------------------------------------------- /tasks/interactive_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision import transforms 4 | from utils.visualizer import Visualizer 5 | from typing import Tuple 6 | from PIL import Image 7 | from detectron2.data import MetadataCatalog 8 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 9 | 10 | 11 | class SemanticSAMPredictor: 12 | def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100): 13 | """ 14 | thresh: iou thresh to filter low confidence objects 15 | text_size: resize the input image short edge for the model to process 16 | hole_scale: fill in small holes as in SAM 17 | island_scale: remove small regions as in SAM 18 | """ 19 | self.model = model 20 | self.thresh = thresh 21 | self.text_size = hole_scale 22 | self.hole_scale = hole_scale 23 | self.island_scale = island_scale 24 | self.point = None 25 | 26 | def predict(self, image_ori, image, point=None): 27 | """ 28 | produce up to 6 prediction results for each click 29 | """ 30 | width = image_ori.shape[1] 31 | height = image_ori.shape[0] 32 | 33 | data = {"image": image, "height": height, "width": width} 34 | # import ipdb; ipdb.set_trace() 35 | if point is None: 36 | point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda() 37 | else: 38 | point = torch.tensor(point).cuda() 39 | point_ = point 40 | point = point_.clone() 41 | point[0, 0] = point_[0, 0] 42 | point[0, 1] = point_[0, 1] 43 | # point = point[:, [1, 0]] 44 | point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1) 45 | 46 | self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point)) 47 | 48 | data['targets'] = [dict()] 49 | data['targets'][0]['points'] = point 50 | data['targets'][0]['pb'] = point.new_tensor([0.]) 51 | 52 | batch_inputs = [data] 53 | masks, ious = self.model.model.evaluate_demo(batch_inputs) 54 | 55 | return masks, ious 56 | 57 | def process_multi_mask(self, masks, ious, image_ori): 58 | pred_masks_poses = masks 59 | reses = [] 60 | ious = ious[0, 0] 61 | ids = torch.argsort(ious, descending=True) 62 | 63 | text_res = '' 64 | mask_ls = [] 65 
| ious_res = [] 66 | areas = [] 67 | for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])): 68 | iou = round(float(iou), 2) 69 | texts = f'{iou}' 70 | mask = (pred_masks_pos > 0.0).cpu().numpy() 71 | area = mask.sum() 72 | conti = False 73 | if iou < self.thresh: 74 | conti = True 75 | for m in mask_ls: 76 | if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95: 77 | conti = True 78 | break 79 | if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []: 80 | conti = False 81 | if conti: 82 | continue 83 | ious_res.append(iou) 84 | mask_ls.append(mask) 85 | areas.append(area) 86 | mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes") 87 | mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands") 88 | mask = (mask).astype(np.float) 89 | out_txt = texts 90 | visual = Visualizer(image_ori, metadata=metadata) 91 | color = [0., 0., 1.0] 92 | demo = visual.draw_binary_mask(mask, color=color, text=texts) 93 | res = demo.get_image() 94 | point_x0 = max(0, int(self.point[0, 0]) - 3) 95 | point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3) 96 | point_y0 = max(0, int(self.point[0, 1]) - 3) 97 | point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3) 98 | res[point_y0:point_y1, point_x0:point_x1, 0] = 255 99 | res[point_y0:point_y1, point_x0:point_x1, 1] = 0 100 | res[point_y0:point_y1, point_x0:point_x1, 2] = 0 101 | reses.append(Image.fromarray(res)) 102 | text_res = text_res + ';' + out_txt 103 | ids = list(torch.argsort(torch.tensor(areas), descending=False)) 104 | ids = [int(i) for i in ids] 105 | 106 | torch.cuda.empty_cache() 107 | 108 | return reses, [reses[i] for i in ids] 109 | 110 | def predict_masks(self, image_ori, image, point=None): 111 | masks, ious = self.predict(image_ori, image, point) 112 | return self.process_multi_mask(masks, ious, image_ori) 113 | 114 | @staticmethod 115 | def remove_small_regions( 116 | mask: np.ndarray, area_thresh: float, mode: str 117 | ) -> Tuple[np.ndarray, bool]: 118 | """ 119 | Removes small disconnected regions and holes in a mask. Returns the 120 | mask and an indicator of if the mask has been modified. 121 | """ 122 | import cv2 # type: ignore 123 | 124 | assert mode in ["holes", "islands"] 125 | correct_holes = mode == "holes" 126 | working_mask = (correct_holes ^ mask).astype(np.uint8) 127 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 128 | sizes = stats[:, -1][1:] # Row 0 is background label 129 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 130 | if len(small_regions) == 0: 131 | return mask, False 132 | fill_labels = [0] + small_regions 133 | if not correct_holes: 134 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 135 | # If every region is below threshold, keep largest 136 | if len(fill_labels) == 0: 137 | fill_labels = [int(np.argmax(sizes)) + 1] 138 | mask = np.isin(regions, fill_labels) 139 | return mask, True 140 | -------------------------------------------------------------------------------- /utils/Config.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode as _CfgNode 2 | 3 | class CfgNode(_CfgNode): 4 | """ 5 | The same as `fvcore.common.config.CfgNode`, but different in: 6 | 7 | 1. Use unsafe yaml loading by default. 
8 | Note that this may lead to arbitrary code execution: you must not 9 | load a config file from untrusted sources before manually inspecting 10 | the content of the file. 11 | 2. Support config versioning. 12 | When attempting to merge an old config, it will convert the old config automatically. 13 | 14 | .. automethod:: clone 15 | .. automethod:: freeze 16 | .. automethod:: defrost 17 | .. automethod:: is_frozen 18 | .. automethod:: load_yaml_with_base 19 | .. automethod:: merge_from_list 20 | .. automethod:: merge_from_other_cfg 21 | """ 22 | 23 | def merge_from_dict(self, dict): 24 | pass 25 | 26 | node = CfgNode() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist import * -------------------------------------------------------------------------------- /utils/arguments.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import argparse 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def load_config_dict_to_opt(opt, config_dict): 10 | """ 11 | Load the key, value pairs from config_dict to opt, overriding existing values in opt 12 | if there is any. 13 | """ 14 | if not isinstance(config_dict, dict): 15 | raise TypeError("Config must be a Python dictionary") 16 | for k, v in config_dict.items(): 17 | k_parts = k.split('.') 18 | pointer = opt 19 | for k_part in k_parts[:-1]: 20 | if k_part not in pointer: 21 | pointer[k_part] = {} 22 | pointer = pointer[k_part] 23 | assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict." 24 | ori_value = pointer.get(k_parts[-1]) 25 | pointer[k_parts[-1]] = v 26 | if ori_value: 27 | logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}") 28 | 29 | def load_opt_from_config_file(conf_file): 30 | """ 31 | Load opt from the config files, settings in later files can override those in previous files. 32 | 33 | Args: 34 | conf_files: config file path 35 | 36 | Returns: 37 | dict: a dictionary of opt settings 38 | """ 39 | opt = {} 40 | with open(conf_file, encoding='utf-8') as f: 41 | config_dict = yaml.safe_load(f) 42 | 43 | load_config_dict_to_opt(opt, config_dict) 44 | 45 | return opt 46 | 47 | def load_opt_from_config_files(conf_files): 48 | """ 49 | Load opt from the config files, settings in later files can override those in previous files. 50 | 51 | Args: 52 | conf_files (list): a list of config file paths 53 | 54 | Returns: 55 | dict: a dictionary of opt settings 56 | """ 57 | opt = {} 58 | for conf_file in conf_files: 59 | with open(conf_file, encoding='utf-8') as f: 60 | config_dict = yaml.safe_load(f) 61 | 62 | load_config_dict_to_opt(opt, config_dict) 63 | 64 | return opt 65 | 66 | 67 | def load_opt_command(args): 68 | parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.') 69 | parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate') 70 | parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).') 71 | parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.') 72 | parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. 
Remember to escape " in command line.') 73 | parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER) 74 | 75 | cmdline_args = parser.parse_args() if not args else parser.parse_args(args) 76 | 77 | opt = load_opt_from_config_files(cmdline_args.conf_files) 78 | 79 | if cmdline_args.config_overrides: 80 | config_overrides_string = ' '.join(cmdline_args.config_overrides) 81 | logger.warning(f"Command line config overrides: {config_overrides_string}") 82 | config_dict = json.loads(config_overrides_string) 83 | load_config_dict_to_opt(opt, config_dict) 84 | 85 | if cmdline_args.overrides: 86 | assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments is not paired, required: key value" 87 | keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)] 88 | vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)] 89 | vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals] 90 | 91 | types = [] 92 | for key in keys: 93 | key = key.split('.') 94 | ele = opt.copy() 95 | while len(key) > 0: 96 | ele = ele[key.pop(0)] 97 | types.append(type(ele)) 98 | 99 | config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)} 100 | load_config_dict_to_opt(opt, config_dict) 101 | 102 | # combine cmdline_args into opt dictionary 103 | for key, val in cmdline_args.__dict__.items(): 104 | if val is not None: 105 | opt[key] = val 106 | 107 | return opt, cmdline_args 108 | 109 | 110 | def save_opt_to_json(opt, conf_file): 111 | with open(conf_file, 'w', encoding='utf-8') as f: 112 | json.dump(opt, f, indent=4) 113 | 114 | 115 | def save_opt_to_yaml(opt, conf_file): 116 | with open(conf_file, 'w', encoding='utf-8') as f: 117 | yaml.dump(opt, f) 118 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import os 4 | import random 5 | import subprocess 6 | import time 7 | from collections import OrderedDict, defaultdict, deque 8 | import datetime 9 | import pickle 10 | from typing import Optional, List 11 | 12 | import json, time 13 | import numpy as np 14 | import torch 15 | import torch.distributed as dist 16 | from torch import Tensor 17 | 18 | import colorsys 19 | def init_distributed_mode(args): 20 | if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and 21 | args.rank = int(os.environ["RANK"]) 22 | args.world_size = int(os.environ['WORLD_SIZE']) 23 | args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) 24 | 25 | # launch by torch.distributed.launch 26 | # Single node 27 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ... 28 | # Multi nodes 29 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 30 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 
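A worked illustration of the --overrides handling in utils/arguments.py above: each command-line key/value pair is cast to the type of the value already present in opt before being merged back in. The opt contents and keys here are hypothetical.

opt = {'SOLVER': {'BASE_LR': 0.0001, 'MAX_ITER': 90000}}
overrides = ['SOLVER.BASE_LR', '0.00005', 'SOLVER.MAX_ITER', '30000']
keys = [overrides[i * 2] for i in range(len(overrides) // 2)]
vals = [overrides[i * 2 + 1] for i in range(len(overrides) // 2)]

types = []
for key in keys:
    parts = key.split('.')
    ele = opt
    while parts:
        ele = ele[parts.pop(0)]   # walk the nested dict to the existing value
    types.append(type(ele))       # [float, int]

config_dict = {k: t(v) for k, v, t in zip(keys, vals, types)}
# {'SOLVER.BASE_LR': 5e-05, 'SOLVER.MAX_ITER': 30000}, then merged via load_config_dict_to_opt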
31 | # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK')) 32 | # local_world_size = int(os.environ['GPU_PER_NODE_COUNT']) 33 | # args.world_size = args.world_size * local_world_size 34 | # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) 35 | # args.rank = args.rank * local_world_size + args.local_rank 36 | print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank)) 37 | print(json.dumps(dict(os.environ), indent=2)) 38 | elif 'SLURM_PROCID' in os.environ: 39 | args.rank = int(os.environ['SLURM_PROCID']) 40 | args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID']) 41 | args.world_size = int(os.environ['SLURM_NPROCS']) 42 | 43 | if os.environ.get('HAND_DEFINE_DIST_URL', 0) == '1': 44 | pass 45 | else: 46 | import util.hostlist as uh 47 | nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST']) 48 | gpu_ids = [int(node[3:]) for node in nodenames] 49 | fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0)) 50 | # fixid += random.randint(0, 300) 51 | port = str(3137 + int(min(gpu_ids)) + fixid) 52 | args.dist_url = "tcp://{ip}:{port}".format(ip=uh.nodename_to_ip(nodenames[0]), port=port) 53 | 54 | print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count())) 55 | 56 | 57 | else: 58 | print('Not using distributed mode') 59 | args.distributed = False 60 | args.world_size = 1 61 | args.rank = 0 62 | args.local_rank = 0 63 | return 64 | 65 | print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) 66 | args.distributed = True 67 | torch.cuda.set_device(args.local_rank) 68 | args.dist_backend = 'nccl' 69 | print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) 70 | 71 | torch.distributed.init_process_group( 72 | backend=args.dist_backend, 73 | world_size=args.world_size, 74 | rank=args.rank, 75 | init_method=args.dist_url, 76 | ) 77 | 78 | print("Before torch.distributed.barrier()") 79 | torch.distributed.barrier() 80 | print("End torch.distributed.barrier()") -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import time 3 | # import torch 4 | # import pickle 5 | # import subprocess 6 | 7 | # from mpi4py import MPI 8 | # import torch.distributed as dist 9 | 10 | 11 | # def apply_distributed(opt): 12 | # if opt['rank'] == 0: 13 | # hostname_cmd = ["hostname -I"] 14 | # result = subprocess.check_output(hostname_cmd, shell=True) 15 | # master_address = result.decode('utf-8').split()[0] 16 | # master_port = opt['PORT'] 17 | # else: 18 | # master_address = None 19 | # master_port = None 20 | 21 | # master_address = MPI.COMM_WORLD.bcast(master_address, root=0) 22 | # master_port = MPI.COMM_WORLD.bcast(master_port, root=0) 23 | 24 | # if torch.distributed.is_available() and opt['world_size'] > 1: 25 | # init_method_url = 'tcp://{}:{}'.format(master_address, master_port) 26 | # backend = 'nccl' 27 | # world_size = opt['world_size'] 28 | # rank = opt['rank'] 29 | # torch.distributed.init_process_group(backend=backend, 30 | # init_method=init_method_url, 31 | # world_size=world_size, 32 | # rank=rank) 33 | 34 | # def init_distributed(opt): 35 | # opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available() 36 | # if 'OMPI_COMM_WORLD_SIZE' not in os.environ: 37 | # # application was started without MPI 38 | # # 
default to single node with single process 39 | # opt['env_info'] = 'no MPI' 40 | # opt['world_size'] = 1 41 | # opt['local_size'] = 1 42 | # opt['rank'] = 0 43 | # opt['local_rank'] = 0 44 | # opt['master_address'] = '127.0.0.1' 45 | # opt['master_port'] = '8673' 46 | # else: 47 | # # application was started with MPI 48 | # # get MPI parameters 49 | # opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE']) 50 | # opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) 51 | # opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK']) 52 | # opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 53 | 54 | # # set up device 55 | # if not opt['CUDA']: 56 | # assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend' 57 | # opt['device'] = torch.device("cpu") 58 | # else: 59 | # torch.cuda.set_device(opt['local_rank']) 60 | # opt['device'] = torch.device("cuda", opt['local_rank']) 61 | 62 | # apply_distributed(opt) 63 | # return opt 64 | 65 | # def is_main_process(): 66 | # rank = 0 67 | # if 'OMPI_COMM_WORLD_SIZE' in os.environ: 68 | # rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 69 | 70 | # return rank == 0 71 | 72 | # def get_world_size(): 73 | # if not dist.is_available(): 74 | # return 1 75 | # if not dist.is_initialized(): 76 | # return 1 77 | # return dist.get_world_size() 78 | 79 | # def get_rank(): 80 | # if not dist.is_available(): 81 | # return 0 82 | # if not dist.is_initialized(): 83 | # return 0 84 | # return dist.get_rank() 85 | 86 | 87 | # def synchronize(): 88 | # """ 89 | # Helper function to synchronize (barrier) among all processes when 90 | # using distributed training 91 | # """ 92 | # if not dist.is_available(): 93 | # return 94 | # if not dist.is_initialized(): 95 | # return 96 | # world_size = dist.get_world_size() 97 | # rank = dist.get_rank() 98 | # if world_size == 1: 99 | # return 100 | 101 | # def _send_and_wait(r): 102 | # if rank == r: 103 | # tensor = torch.tensor(0, device="cuda") 104 | # else: 105 | # tensor = torch.tensor(1, device="cuda") 106 | # dist.broadcast(tensor, r) 107 | # while tensor.item() == 1: 108 | # time.sleep(1) 109 | 110 | # _send_and_wait(0) 111 | # # now sync on the main process 112 | # _send_and_wait(1) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import math 8 | 9 | 10 | # HACK for evalution 11 | def hook_metadata(metadata, name): 12 | if name == 'cityscapes_fine_sem_seg_val': 13 | metadata.__setattr__("keep_sem_bgd", False) 14 | return metadata 15 | 16 | def hook_opt(model, name): 17 | if name in ['cityscapes_fine_panoptic_val', 'ade20k_panoptic_val', 'bdd10k_40_panoptic_val', 'cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val']: 18 | model.model.object_mask_threshold = 0.4 19 | else: 20 | model.model.object_mask_threshold = 0.8 21 | 22 | # HACK for evalution 23 | def hook_switcher(model, name): 24 | mappings = {} 25 | if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'bdd10k_val_sem_seg', 
'ade20k_full_sem_seg_val']: 26 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False} 27 | elif name in ['cityscapes_fine_instance_seg_val', 'pascal_part_val_interactive', 'pascal_part_val', 'pascal_part_train'] or 'seginw' in name: 28 | mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False} 29 | elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']: 30 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True} 31 | elif 'coco_2017_val_panoptic_with_sem_seg' in name or name in ['ade20k_panoptic_val', 'coco_2017_test-dev', 'sam_val', 'sam_minival']: 32 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True} 33 | else: 34 | if name not in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd"]: 35 | assert False, "dataset switcher is not defined" 36 | for key, value in mappings.items(): 37 | if key == 'SEMANTIC_ON': 38 | model.model.semantic_on = value 39 | if key == 'INSTANCE_ON': 40 | model.model.instance_on = value 41 | if key == 'PANOPTIC_ON': 42 | model.model.panoptic_on = value 43 | 44 | class AverageMeter(object): 45 | """Computes and stores the average and current value.""" 46 | def __init__(self): 47 | self.reset() 48 | 49 | def reset(self): 50 | self.val = 0 51 | self.avg = 0 52 | self.sum = 0 53 | self.count = 0 54 | 55 | def update(self, val, n=1, decay=0): 56 | self.val = val 57 | if decay: 58 | alpha = math.exp(-n / decay) # exponential decay over 100 updates 59 | self.sum = alpha * self.sum + (1 - alpha) * val * n 60 | self.count = alpha * self.count + (1 - alpha) * n 61 | else: 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import pickle 5 | import torch 6 | from detectron2.utils.comm import is_main_process 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | NORM_MODULES = [ 12 | torch.nn.BatchNorm1d, 13 | torch.nn.BatchNorm2d, 14 | torch.nn.BatchNorm3d, 15 | torch.nn.SyncBatchNorm, 16 | # NaiveSyncBatchNorm inherits from BatchNorm2d 17 | torch.nn.GroupNorm, 18 | torch.nn.InstanceNorm1d, 19 | torch.nn.InstanceNorm2d, 20 | torch.nn.InstanceNorm3d, 21 | torch.nn.LayerNorm, 22 | torch.nn.LocalResponseNorm, 23 | ] 24 | 25 | def register_norm_module(cls): 26 | NORM_MODULES.append(cls) 27 | return cls 28 | 29 | def align_and_update_state_dicts(model_state_dict, ckpt_state_dict): 30 | model_keys = sorted(model_state_dict.keys()) 31 | ckpt_keys = sorted(ckpt_state_dict.keys()) 32 | result_dicts = {} 33 | matched_log = [] 34 | unmatched_log = [] 35 | unloaded_log = [] 36 | for model_key in model_keys: 37 | model_weight = model_state_dict[model_key] 38 | if model_key in ckpt_keys: 39 | ckpt_weight = ckpt_state_dict[model_key] 40 | if model_weight.shape == ckpt_weight.shape: 41 | result_dicts[model_key] = ckpt_weight 42 | ckpt_keys.pop(ckpt_keys.index(model_key)) 43 | matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 44 | else: 45 | unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, 
ckpt_weight.shape)) 46 | else: 47 | unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape)) 48 | 49 | if is_main_process(): 50 | for info in matched_log: 51 | logger.info(info) 52 | for info in unloaded_log: 53 | logger.warning(info) 54 | for key in ckpt_keys: 55 | logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape)) 56 | for info in unmatched_log: 57 | logger.warning(info) 58 | return result_dicts -------------------------------------------------------------------------------- /utils/prompt_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_prompt_templates(): 5 | prompt_templates = [ 6 | '{}.', 7 | 'a photo of a {}.', 8 | 'a bad photo of a {}.', 9 | 'a photo of many {}.', 10 | 'a sculpture of a {}.', 11 | 'a photo of the hard to see {}.', 12 | 'a low resolution photo of the {}.', 13 | 'a rendering of a {}.', 14 | 'graffiti of a {}.', 15 | 'a bad photo of the {}.', 16 | 'a cropped photo of the {}.', 17 | 'a tattoo of a {}.', 18 | 'the embroidered {}.', 19 | 'a photo of a hard to see {}.', 20 | 'a bright photo of a {}.', 21 | 'a photo of a clean {}.', 22 | 'a photo of a dirty {}.', 23 | 'a dark photo of the {}.', 24 | 'a drawing of a {}.', 25 | 'a photo of my {}.', 26 | 'the plastic {}.', 27 | 'a photo of the cool {}.', 28 | 'a close-up photo of a {}.', 29 | 'a black and white photo of the {}.', 30 | 'a painting of the {}.', 31 | 'a painting of a {}.', 32 | 'a pixelated photo of the {}.', 33 | 'a sculpture of the {}.', 34 | 'a bright photo of the {}.', 35 | 'a cropped photo of a {}.', 36 | 'a plastic {}.', 37 | 'a photo of the dirty {}.', 38 | 'a jpeg corrupted photo of a {}.', 39 | 'a blurry photo of the {}.', 40 | 'a photo of the {}.', 41 | 'a good photo of the {}.', 42 | 'a rendering of the {}.', 43 | 'a {} in a video game.', 44 | 'a photo of one {}.', 45 | 'a doodle of a {}.', 46 | 'a close-up photo of the {}.', 47 | 'the origami {}.', 48 | 'the {} in a video game.', 49 | 'a sketch of a {}.', 50 | 'a doodle of the {}.', 51 | 'a origami {}.', 52 | 'a low resolution photo of a {}.', 53 | 'the toy {}.', 54 | 'a rendition of the {}.', 55 | 'a photo of the clean {}.', 56 | 'a photo of a large {}.', 57 | 'a rendition of a {}.', 58 | 'a photo of a nice {}.', 59 | 'a photo of a weird {}.', 60 | 'a blurry photo of a {}.', 61 | 'a cartoon {}.', 62 | 'art of a {}.', 63 | 'a sketch of the {}.', 64 | 'a embroidered {}.', 65 | 'a pixelated photo of a {}.', 66 | 'itap of the {}.', 67 | 'a jpeg corrupted photo of the {}.', 68 | 'a good photo of a {}.', 69 | 'a plushie {}.', 70 | 'a photo of the nice {}.', 71 | 'a photo of the small {}.', 72 | 'a photo of the weird {}.', 73 | 'the cartoon {}.', 74 | 'art of the {}.', 75 | 'a drawing of the {}.', 76 | 'a photo of the large {}.', 77 | 'a black and white photo of a {}.', 78 | 'the plushie {}.', 79 | 'a dark photo of a {}.', 80 | 'itap of a {}.', 81 | 'graffiti of the {}.', 82 | 'a toy {}.', 83 | 'itap of my {}.', 84 | 'a photo of a cool {}.', 85 | 'a photo of a small {}.', 86 | 'a tattoo of the {}.', 87 | ] 88 | return prompt_templates 89 | 90 | def prompt_engineering(classnames, topk=1, suffix='.'): 91 | prompt_templates = get_prompt_templates() 92 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 93 | 94 | if isinstance(classnames, list): 95 | classname = random.choice(classnames) 96 | else: 97 | classname = classnames 98 | 99 | return prompt_templates[temp_idx].replace('.', 
suffix).format(classname.replace(',', '').replace('+', ' ')) -------------------------------------------------------------------------------- /utils/sam_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /utils/sam_utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 
| def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /utils/sam_utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
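A worked illustration of the score reweighting in SamOnnxModel.select_masks above, assuming four mask tokens; the IoU predictions are made-up numbers. With a single click the +1000 offset is subtracted from the first token so the multimask outputs compete; with several clicks it is added, forcing the single-mask token.

import torch

iou_preds = torch.tensor([[0.80, 0.70, 0.90, 0.60]])
score_reweight = torch.tensor([[1000.0, 0.0, 0.0, 0.0]])

(iou_preds + (1 - 2.5) * score_reweight).argmax(dim=1)  # tensor([2]): one point, best multimask output
(iou_preds + (3 - 2.5) * score_reweight).argmax(dim=1)  # tensor([0]): multiple points, single-mask token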
6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | --------------------------------------------------------------------------------
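A quick check of ResizeLongestSide.get_preprocess_shape above with illustrative sizes: the longer side is scaled to target_length and both sides are rounded to the nearest integer.

ResizeLongestSide.get_preprocess_shape(480, 640, 1024)    # (768, 1024)
ResizeLongestSide.get_preprocess_shape(2000, 1000, 1024)  # (1024, 512)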