├── .gitignore ├── INSTALL.md ├── README.md ├── datasets └── README.md ├── divide_and_conquer ├── cascadepsp.py ├── check.py ├── coco_annotator.py ├── data │ ├── __init__.py │ ├── build.py │ ├── dataset_mapper.py │ ├── datasets │ │ ├── __init__.py │ │ ├── builtin.py │ │ ├── builtin_meta.py │ │ └── coco.py │ ├── detection_utils.py │ └── transforms │ │ ├── __init__.py │ │ ├── augmentation_impl.py │ │ └── transform.py ├── demo.sh ├── demo_dico.ipynb ├── demo_dico.py ├── dino.py ├── divide_conquer.py ├── engine │ ├── __init__.py │ ├── defaults.py │ └── train_loop.py ├── generate_pseudo_masks.sh ├── iterative_merging.py ├── model_zoo │ └── configs │ │ ├── Base-RCNN-FPN.yaml │ │ ├── COCO-Semisupervised │ │ ├── cascade_mask_rcnn_R_50_FPN_100perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_10perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_1perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_20perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_2perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_30perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_40perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_50perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_5perc.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_60perc.yaml │ │ └── cascade_mask_rcnn_R_50_FPN_80perc.yaml │ │ └── CutLER-ImageNet │ │ ├── cascade_mask_rcnn_R_50_FPN.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_demo.yaml │ │ ├── cascade_mask_rcnn_R_50_FPN_self_train.yaml │ │ └── mask_rcnn_R_50_FPN.yaml ├── modeling │ ├── __init__.py │ ├── meta_arch │ │ ├── __init__.py │ │ ├── build.py │ │ └── rcnn.py │ └── roi_heads │ │ ├── __init__.py │ │ ├── custom_cascade_rcnn.py │ │ ├── fast_rcnn.py │ │ └── roi_heads.py ├── solver │ ├── __init__.py │ └── build.py └── structures │ ├── __init__.py │ └── boxes.py ├── docs └── demos │ ├── sa_000001.jpg │ ├── sa_121371.jpg │ ├── sa_163514.jpg │ ├── sa_193210.jpg │ ├── sa_224132.jpg │ ├── sa_234337.jpg │ ├── sa_412497.jpg │ ├── sa_434709.jpg │ ├── sa_479726.jpg │ ├── sa_527160.jpg │ └── sa_562217.jpg ├── promptable_segmentation ├── configs │ └── semantic_sam_only_sa-1b_swinT.yaml ├── datasets │ ├── __init__.py │ ├── build.py │ ├── dataset_mappers │ │ ├── __init__.py │ │ ├── inference_mapper_with_gt.py │ │ └── sam_baseline_dataset_mapper.py │ ├── evaluation │ │ ├── __init__.py │ │ └── interactive_evaluation.py │ ├── registration │ │ ├── __init__.py │ │ ├── register_coco_panoptic_annos_semseg_interactive_jointboxpoint.py │ │ └── register_sam_mnode.py │ └── utils │ │ ├── __init__.py │ │ ├── semseg_loader.py │ │ └── tsv │ │ ├── __init__.py │ │ ├── io_common.py │ │ └── tsv_io.py ├── demo.sh ├── demo_auto_generation.py ├── demo_promptable.ipynb ├── demo_promptable.py ├── examples │ ├── sa_121371.jpg │ ├── sa_412497.jpg │ └── sa_562217.jpg ├── model │ ├── BaseModel.py │ ├── __init__.py │ ├── architectures │ │ ├── __init__.py │ │ └── interactive_mask_dino.py │ ├── backbone │ │ ├── __init__.py │ │ └── swin.py │ ├── body │ │ ├── __init__.py │ │ ├── build.py │ │ ├── decoder │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── interactive_mask_dino.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ └── dino_decoder.py │ │ ├── encoder │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── encoder_deform.py │ │ │ └── ops │ │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ ├── make.sh │ │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ │ ├── setup.py │ │ │ │ ├── src │ │ │ │ ├── cpu │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ ├── cuda │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ ├── ms_deform_attn.h │ │ │ │ └── vision.cpp │ │ │ │ └── test.py │ │ └── general_head.py │ ├── build_semantic_sam.py │ └── utils │ │ ├── __init__.py │ │ └── box_ops.py ├── tasks │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── interactive_idino_m2m.py │ ├── interactive_idino_m2m_auto.py │ └── interactive_predictor.py ├── train.sh ├── train_net.py └── utils │ ├── Config.py │ ├── __init__.py │ ├── arguments.py │ ├── constants.py │ ├── dist.py │ ├── distributed.py │ ├── misc.py │ ├── model.py │ ├── prompt_engineering.py │ ├── sam_utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py │ └── visualizer.py ├── requirements.txt ├── tools ├── promptable_eval.sh └── whole_image_eval.sh └── whole_image_segmentation ├── configs ├── Base-COCO-InstanceSegmentation.yaml └── maskformer2_R50_bs16_50ep.yaml ├── data ├── __init__.py ├── build.py └── datasets │ ├── __init__.py │ ├── builtin.py │ └── builtin_meta.py ├── demo.sh ├── demo_whole_image.ipynb ├── demo_whole_image.py ├── mask2former ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ └── dataset_mappers │ │ ├── __init__.py │ │ ├── sam_instance_tsv_dataset_mapper.py │ │ └── sam_instance_tsv_self_train_dataset_mapper.py ├── evaluation │ ├── __init__.py │ └── coco_evaluation.py ├── maskformer_model.py ├── modeling │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ └── swin.py │ ├── criterion.py │ ├── matcher.py │ ├── meta_arch │ │ ├── __init__.py │ │ ├── mask_former_head.py │ │ └── per_pixel_baseline.py │ ├── pixel_decoder │ │ ├── __init__.py │ │ ├── fpn.py │ │ ├── msdeformattn.py │ │ └── ops │ │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ │ ├── make.sh │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ │ ├── setup.py │ │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ │ └── test.py │ └── transformer_decoder │ │ ├── __init__.py │ │ ├── mask2former_transformer_decoder.py │ │ ├── maskformer_transformer_decoder.py │ │ ├── position_encoding.py │ │ └── transformer.py ├── test_time_augmentation.py └── utils │ ├── __init__.py │ ├── misc.py │ └── tsv │ ├── __init__.py │ ├── io_common.py │ └── tsv_io.py ├── train.sh └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.json 3 | *.pyc 4 | 5 | eval_output 6 | 7 | promptable_segmentation/data 8 | promptable_segmentation/datasets/coco 9 | 10 | whole_image_segmentation/ex_eval 11 | whole_image_segmentation/ex_train 12 | whole_image_segmentation/output 13 | 14 | # compilation and distribution 15 | *.egg-info/ 16 | build/ 17 | dist/ -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Example conda environment setup 4 | ```bash 5 | conda create --name UnSAM python=3.8 -y 6 | conda activate UnSAM 7 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 8 | 9 | python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git' 10 | pip install git+https://github.com/cocodataset/panopticapi.git 11 | python -m pip install 'git+https://github.com/UX-Decoder/Semantic-SAM.git' 12 | 13 | git clone git@github.com:frank-xwang/UnSAM.git 14 | cd promptable_segmentation/model/body/encoder/ops 15 | sh make.sh 16 | cd whole_image_segmentation/mask2former/modeling/pixel_decoder/ops 17 | sh make.sh 18 | 19 | python -m pip install -r requirements.txt 20 | ``` -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Dataset 2 | ## Preparing Training Dataset 3 | 4 | To generate psuedo-masks for your datasets, please follow our `divide_and_conquer' stage, which generates pseudo-masks for all images in your dataset with unsupervised learning methods. Please run the script `divide_and_conquer/generate_pseudo_masks.sh' — it handles the entire pseudo-mask generation process and will save all the masks produced by our unsupervised labeling pipeline. You can choose to only generate pseudo-masks for a subset of your TRAIN_DATSETS via specifying the start-id and end-id. 5 | 6 | Note that this script expects the following data structure: 7 | 8 | ``` 9 | UNLABELED_IMAGES/ 10 | FOLDER1/ 11 | Image1 12 | ....... 13 | Imagen 14 | FOLDER2/ 15 | Image1 16 | ....... 17 | Imagen 18 | ....... 19 | ``` 20 | 21 | We transfer our training pseudo labels to tsv format for faster dataloading. Please refer to [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM/blob/main/DATASET.md) for a detailed script to transform json format data to tsv format data. 22 | 23 | Please set the environment variable 24 | ```shell 25 | export TRAIN_DATASETS=path_to_your_tsv_directory 26 | ``` 27 | The structure and file names of the tsv directory should follow 28 | ``` 29 | $TRAIN_DATASETS/ 30 | SAM-1.tsv 31 | SAM-1.lineidx 32 | SAM-2.tsv 33 | SAM-2.lineidx 34 | ...... 35 | SAM-N.tsv 36 | SAM-N.lineidx 37 | ``` 38 | 39 | ## Preparing Evaluation Dataset (Whole-Image-Segmentation) 40 | By default we support 7 evaluation datasets for whole-image-segmentation evaluation: SA-1B, COCO, LVIS, ADE20K, EntitySeg, PartImagenet, and PACO. 41 | Please set the root directory environment variable first 42 | ```shell 43 | export DETECTRON2_DATASETS=path_to_your_root_evaluation_directory 44 | ``` 45 | and change the test dataset name in config file 46 | ``` 47 | cfg.DATASETS.TEST: ("unsam_{sa1b, ade20k, entity, paco, partimagenet, coco, lvis}_val") 48 | ``` 49 | The structure and file names of the tsv directory should follow 50 | ``` 51 | $DETECTRON2_DATASETS/ 52 | sa1b/ 53 | images/ 54 | annotations/ 55 | sa1b_val.json 56 | ade/ 57 | images/ 58 | annotations/ 59 | ade_val.json 60 | entity/ 61 | images/ 62 | annotations/ 63 | entityseg_val.json 64 | paco/ 65 | images/ 66 | annotations/ 67 | paco_val.json 68 | partimagenet/ 69 | images/ 70 | annotations/ 71 | partimagenet_val.json 72 | coco/ 73 | val2017/ 74 | annotations/ 75 | instances_val2017.json 76 | lvis/ 77 | images/ 78 | annotations/ 79 | lvis_v1_val.json 80 | ``` 81 | Since ade20k, entity, partimagenet don't have standard image id format, please name each image as {image_id}.jpg. You can adjust them in whole_image_segmentation/data/build.py/get_test_detection_datasets 82 | 83 | ## Preparing Evaluation Dataset (Promptable-Segmentation) 84 | We support COCO evaluation for promptable segmentation. Please set the root directory environment variable first 85 | ```shell 86 | export DETECTRON2_DATASETS=path_to_your_root_evaluation_directory 87 | ``` 88 | and refer to [MaskDINO](https://github.com/IDEA-Research/MaskDINO/blob/main/README.md) to prepare files under the evaluation directory 89 | -------------------------------------------------------------------------------- /divide_and_conquer/cascadepsp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import cv2 4 | import numpy as np 5 | from pycocotools import mask as mask_util 6 | from tqdm import tqdm 7 | 8 | def area(mask): 9 | return np.count_nonzero(mask) / mask.size 10 | 11 | def iou(mask1, mask2): 12 | intersection = np.count_nonzero(np.logical_and(mask1, mask2)) 13 | union = np.count_nonzero(mask1) + np.count_nonzero(mask2) - intersection 14 | if union == 0: return 0 15 | return intersection / union 16 | 17 | def postprocess(args, refiner, annotations, image): 18 | H, W = image.shape[:2] 19 | 20 | start_id = annotations["annotations"][0]['id'] 21 | curr_id = 0 22 | refined_annotations = [] 23 | 24 | for annotation in tqdm(annotations["annotations"]): 25 | mask = mask_util.decode(annotation['segmentation']) 26 | 27 | bbox = annotation['bbox'] 28 | x1, y1, w, h = bbox 29 | x_center = x1 + w / 2 30 | y_center = y1 + h / 2 31 | 32 | longer_side = max(w, h) 33 | x1_resized = int(max(0, x_center - longer_side)) 34 | y1_resized = int(max(0, y_center - longer_side)) 35 | x2_resized = int(min(W, x_center + longer_side)) 36 | y2_resized = int(min(H, y_center + longer_side)) 37 | 38 | image_crop = image[y1_resized:y2_resized, x1_resized:x2_resized, :] 39 | mask_crop = mask[y1_resized:y2_resized, x1_resized:x2_resized] 40 | 41 | L = max(min(max(x2_resized-x1_resized, y2_resized-y1_resized) * args.refine_scale, args.refine_max_L), args.refine_min_L) 42 | refined_mask_crop = refiner.refine(image_crop, mask_crop * 255, fast=True, L=L) 43 | refined_mask_crop = (refined_mask_crop > 128).astype(np.uint8) 44 | 45 | refined_mask = np.zeros((H, W), dtype=np.uint8) 46 | refined_mask[y1_resized:y2_resized, x1_resized:x2_resized] = refined_mask_crop 47 | 48 | if area(refined_mask) < args.min_area_thresh or area(refined_mask) > args.max_area_thresh: 49 | continue 50 | if iou(mask, refined_mask) < args.iou_thresh: 51 | continue 52 | 53 | binary_mask_encoded = mask_util.encode(np.asfortranarray(refined_mask)) 54 | binary_mask_encoded['counts'] = binary_mask_encoded['counts'].decode('ascii') 55 | 56 | annotation['segmentation'] = binary_mask_encoded 57 | annotation['bbox'] = mask_util.toBbox(binary_mask_encoded).tolist() 58 | annotation['area'] = mask_util.area(binary_mask_encoded).tolist() 59 | annotation['id'] = start_id + curr_id 60 | curr_id += 0 61 | 62 | refined_annotations.append(annotation) 63 | 64 | annotations["annotations"] = refined_annotations 65 | return annotations -------------------------------------------------------------------------------- /divide_and_conquer/check.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pycocotools.coco import COCO 4 | from pycocotools.cocoeval import COCOeval 5 | from pycocotools import mask as mask_util 6 | 7 | import json 8 | import os 9 | from tqdm import tqdm 10 | 11 | 12 | def merge_dt_list(list): 13 | merged_data = { 14 | "info": { 15 | "year": 2023, 16 | "version": "1", 17 | "date_created": "no need record" 18 | }, 19 | "images": [], 20 | "annotations": [], 21 | "licenses": [ 22 | { 23 | "id": 1, 24 | "name": "Unknown", 25 | "url": "" 26 | } 27 | ], 28 | "categories": [ 29 | { 30 | "id": 1, 31 | "name": "hd", 32 | "supercategory": "" 33 | } 34 | ] 35 | } 36 | 37 | annotation_id_counter = 1 38 | 39 | image_ids = [] 40 | 41 | for ann_path in tqdm(os.listdir(list)): 42 | if not ann_path.startswith('p'): continue 43 | 44 | ann = json.load(open(os.path.join(list, ann_path))) 45 | for data in ann["annotations"]: 46 | 47 | if data["image_id"] not in image_ids: 48 | image_ids.append(data["image_id"]) 49 | merged_data["images"].append( 50 | { 51 | "id": data["image_id"], 52 | "height": data["segmentation"]["size"][0], 53 | "width": data["segmentation"]["size"][1], 54 | } 55 | ) 56 | 57 | ann_area = mask_util.area(data["segmentation"]).tolist() 58 | 59 | # one json only has one image in SA-1B 60 | data["id"] = annotation_id_counter 61 | annotation_id_counter += 1 62 | data["category_id"] = 1 63 | data["iscrowd"] = 0 64 | data["score"] = 1 65 | data["area"] = ann_area 66 | 67 | # Append the updated annotation to the merged_data 68 | merged_data["annotations"].append(data) 69 | 70 | # Save the merged data to a new JSON file 71 | output_path = "merged_dt.json" 72 | print(f"Saving merged data to {output_path}") 73 | with open(output_path, 'w') as output_file: 74 | json.dump(merged_data, output_file, indent=4) 75 | return output_path 76 | 77 | def merge_gt_files(coco_dt): 78 | merged_data = { 79 | "info": { 80 | "year": 2023, 81 | "version": "1", 82 | "date_created": "no need record" 83 | }, 84 | "images": [], 85 | "annotations": [], 86 | "licenses": [ 87 | { 88 | "id": 1, 89 | "name": "Unknown", 90 | "url": "" 91 | } 92 | ], 93 | "categories": [ 94 | { 95 | "id": 1, 96 | "name": "hd", 97 | "supercategory": "" 98 | } 99 | ] 100 | } 101 | 102 | annotation_id_counter = 1 103 | 104 | for dt_img in tqdm(coco_dt.dataset["images"]): 105 | name = "sa_"+str(dt_img["id"])+".json" 106 | f = open(os.path.join("datasets/sa1b/annotations/val_gt_2", name)) 107 | data = json.load(f) 108 | if data == None: continue 109 | 110 | # one json only has one image in SA-1B 111 | if type(data["image"]) is not list: 112 | # SA only has image_id, not id 113 | data["image"]["id"] = data["image"]["image_id"] 114 | merged_data["images"].append(data["image"]) 115 | 116 | # Update annotation IDs and image IDs 117 | for annotation in data["annotations"]: 118 | annotation["id"] = annotation_id_counter 119 | annotation_id_counter += 1 120 | annotation["score"] = 1.0 121 | annotation["category_id"] = 1 122 | annotation["iscrowd"] = 0 123 | 124 | # one json only has one image in SA-1B, and annotations don't have image_id 125 | if type(data["image"]) is not list: 126 | annotation["image_id"] = data["image"]["id"] 127 | 128 | # Append the updated annotation to the merged_data 129 | merged_data["annotations"].append(annotation) 130 | 131 | # Save the merged data to a new JSON file 132 | output_path = "merged_gt.json" 133 | with open(output_path, 'w') as output_file: 134 | json.dump(merged_data, output_file, indent=4) 135 | return output_path 136 | 137 | 138 | 139 | if __name__ == "__main__": 140 | 141 | parser = argparse.ArgumentParser('') 142 | parser.add_argument('--predict-directory', type=str, default="divide_and_conquer/pseudo_masks", help='predict-directory') 143 | parser.add_argument('--iou-type', type=str, default="segm", help='iou_type') 144 | 145 | args = parser.parse_args() 146 | 147 | path = merge_dt_list(args.predict_directory) 148 | 149 | coco_dt = COCO(path) 150 | coco_gt = COCO(merge_gt_files(coco_dt)) 151 | 152 | # for ann in coco_dt.dataset["annotations"]: 153 | # ann["score"] = 1 154 | 155 | coco_eval = COCOeval(coco_gt, coco_dt, args.iou_type) 156 | coco_eval.params.useCats = 0 157 | coco_eval.params.maxDets = [1, 100, 1000] 158 | 159 | coco_eval.evaluate() 160 | coco_eval.accumulate() 161 | coco_eval.summarize() 162 | ap = coco_eval.stats[:6] 163 | ar = coco_eval.stats[6:12] 164 | 165 | mAP_copypaste = ( 166 | f'{ap[0]*100:.2f} {ap[1]*100:.2f} {ap[2]*100:.2f} {ap[3]*100:.2f} {ap[4]*100:.2f} {ap[5]*100:.2f}') 167 | mAR_copypaste = ( 168 | f'{ar[0]*100:.2f} {ar[1]*100:.2f} {ar[3]*100:.2f} {ar[4]*100:.2f} {ar[5]*100:.2f} {ar[2]*100:.2f}') 169 | 170 | print("mAP copy-paste: ", mAP_copypaste) 171 | print("mAR copy-paste: ", mAR_copypaste) 172 | print("All in one copy-paste: ", mAP_copypaste + " " + mAR_copypaste) 173 | print("num of masks: ", len(coco_dt.dataset['annotations'])) -------------------------------------------------------------------------------- /divide_and_conquer/coco_annotator.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import re 4 | from pycocotools import mask 5 | import pycocotools.mask as mask_util 6 | 7 | def create_image_info(image_id, file_name, image_size, 8 | date_captured=datetime.datetime.utcnow().isoformat(' '), 9 | license_id=1, coco_url="", flickr_url=""): 10 | """Return image_info in COCO style 11 | Args: 12 | image_id: the image ID 13 | file_name: the file name of each image 14 | image_size: image size in the format of (width, height) 15 | date_captured: the date this image info is created 16 | license: license of this image 17 | coco_url: url to COCO images if there is any 18 | flickr_url: url to flickr if there is any 19 | """ 20 | image_info = { 21 | "id": image_id, 22 | "file_name": file_name, 23 | "width": image_size[1], 24 | "height": image_size[0], 25 | "date_captured": date_captured, 26 | "license": license_id, 27 | "coco_url": coco_url, 28 | "flickr_url": flickr_url 29 | } 30 | return image_info 31 | 32 | 33 | def create_annotation_info(annotation_id, image_id, category_info, binary_mask, 34 | image_size=None, bounding_box=None): 35 | """Return annotation info in COCO style 36 | Args: 37 | annotation_id: the annotation ID 38 | image_id: the image ID 39 | category_info: the information on categories 40 | binary_mask: a 2D binary numpy array where '1's represent the object 41 | file_name: the file name of each image 42 | image_size: image size in the format of (width, height) 43 | bounding_box: the bounding box for detection task. If bounding_box is not provided, 44 | we will generate one according to the binary mask. 45 | """ 46 | binary_mask_encoded = mask.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 47 | 48 | area = mask.area(binary_mask_encoded) 49 | if area < 1: 50 | return None 51 | 52 | if bounding_box is None: 53 | bounding_box = mask.toBbox(binary_mask_encoded) 54 | 55 | rle = mask_util.encode(np.array(binary_mask[...,None], order="F", dtype="uint8"))[0] 56 | rle['counts'] = rle['counts'].decode('ascii') 57 | segmentation = rle 58 | 59 | annotation_info = { 60 | "id": annotation_id, 61 | "image_id": image_id, 62 | "category_id": category_info["id"], 63 | "iscrowd": 0, 64 | "area": area.tolist(), 65 | "bbox": bounding_box.tolist(), 66 | "segmentation": segmentation, 67 | "width": binary_mask.shape[1], 68 | "height": binary_mask.shape[0], 69 | } 70 | 71 | return annotation_info 72 | 73 | # necessay info used for coco style annotations 74 | INFO = { 75 | #"description": "ImageNet-1K: pseudo-masks with MaskCut", 76 | #"url": "https://github.com/facebookresearch/CutLER", 77 | "version": "1.0", 78 | "year": 2023, 79 | #"contributor": "Xudong Wang", 80 | "date_created": datetime.datetime.utcnow().isoformat(' ') 81 | } 82 | 83 | LICENSES = [ 84 | { 85 | "id": 1, 86 | "name": "Apache License", 87 | #"url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE" 88 | } 89 | ] 90 | 91 | # only one class, i.e. foreground 92 | CATEGORIES = [ 93 | { 94 | 'id': 1, 95 | 'name': 'fg', 96 | 'supercategory': 'fg', 97 | }, 98 | ] 99 | 100 | convert = lambda text: int(text) if text.isdigit() else text.lower() 101 | natrual_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 102 | 103 | output = { 104 | "info": INFO, 105 | "licenses": LICENSES, 106 | "categories": CATEGORIES, 107 | "image": {}, 108 | "annotations": []} 109 | 110 | category_info = { 111 | "is_crowd": 0, 112 | "id": 1 113 | } -------------------------------------------------------------------------------- /divide_and_conquer/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from . import datasets # ensure the builtin datasets are registered 4 | from .detection_utils import * # isort:skip 5 | from .build import ( 6 | build_batch_data_loader, 7 | build_detection_train_loader, 8 | build_detection_test_loader, 9 | get_detection_dataset_dicts, 10 | load_proposals_into_dataset, 11 | print_instances_class_histogram, 12 | ) 13 | from detectron2.data.common import * 14 | 15 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /divide_and_conquer/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json 3 | from .builtin import ( 4 | register_all_imagenet, 5 | register_all_uvo, 6 | register_all_coco_ca, 7 | register_all_coco_semi, 8 | register_all_lvis, 9 | register_all_voc, 10 | register_all_cross_domain, 11 | register_all_kitti, 12 | register_all_objects365, 13 | register_all_openimages, 14 | ) 15 | 16 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /divide_and_conquer/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/__init__.py 3 | 4 | from fvcore.transforms.transform import * 5 | from .transform import * 6 | from detectron2.data.transforms.augmentation import * 7 | from .augmentation_impl import * 8 | 9 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 10 | 11 | 12 | from detectron2.utils.env import fixup_module_metadata 13 | 14 | fixup_module_metadata(__name__, globals(), __all__) 15 | del fixup_module_metadata -------------------------------------------------------------------------------- /divide_and_conquer/demo.sh: -------------------------------------------------------------------------------- 1 | python demo_dico.py --input ../docs/demos/sa_234337.jpg \ 2 | --output demo.jpg \ 3 | --postprocess true \ 4 | --opts MODEL.WEIGHTS cutler_cascade_final.pth \ -------------------------------------------------------------------------------- /divide_and_conquer/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .train_loop import * 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | 7 | from .defaults import * -------------------------------------------------------------------------------- /divide_and_conquer/generate_pseudo_masks.sh: -------------------------------------------------------------------------------- 1 | python divide_conquer.py \ 2 | --input-dir /PATH/TO/DATASETS \ 3 | --output-dir pseudo_masks \ 4 | --start-id 50 \ 5 | --end-id 100 \ 6 | --preprocess True \ 7 | --opts MODEL.WEIGHTS cutler_cascade_final.pth 8 | 9 | python divide_conquer.py \ 10 | --input-dir /PATH/TO/DATASETS \ 11 | --output-dir pseudo_masks\ 12 | --postprocess True \ 13 | --opts MODEL.WEIGHTS cutler_cascade_final.pth -------------------------------------------------------------------------------- /divide_and_conquer/iterative_merging.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import numpy as np 3 | import torch 4 | 5 | def merged_clusters(i, j, clusters): 6 | c1, c2 = clusters[i], clusters[j] 7 | weighted_sum = ((c1['feature'] + c2['feature']) / (c1['num_of_patch'] + c2['num_of_patch'])).float() 8 | #weighted_sum = ((c1['num_of_patch']*c1['feature'] + c2['num_of_patch']*c2['feature']) / (c1['num_of_patch'] + c2['num_of_patch'])).float() 9 | return { 10 | 'feature': c1['feature'] + c2['feature'], 11 | 'normalized_feature': F.normalize(weighted_sum, dim=0), 12 | 'mask': (c1['mask'] + c2['mask']) > 0, 13 | 'num_of_patch': c1['num_of_patch'] + c2['num_of_patch'], 14 | 'neighbors': c1['neighbors'].union(c2['neighbors']).difference(set([i, j])) 15 | } 16 | 17 | def iterative_merge(features, threshes, min_size=4): 18 | 19 | clusters = [] 20 | similarities = {} 21 | H, W = features.shape[:2] 22 | 23 | cluster_idx = 0 24 | for y in range(H): 25 | for x in range(W): 26 | mask = np.zeros((H, W)) 27 | mask[y, x] = 1 28 | clusters.append({ 29 | 'feature': features[y, x], 30 | 'normalized_feature': F.normalize(features[y, x].float(), dim=0), 31 | 'mask': mask, 32 | 'num_of_patch': 1, 33 | 'neighbors': set() 34 | }) 35 | 36 | if (cluster_idx % W) != 0: 37 | clusters[cluster_idx]['neighbors'].add(cluster_idx-1) 38 | clusters[cluster_idx-1]['neighbors'].add(cluster_idx) 39 | similarities[(cluster_idx-1, cluster_idx)] = \ 40 | torch.dot(clusters[cluster_idx-1]['normalized_feature'], clusters[cluster_idx]['normalized_feature']).item() 41 | if (cluster_idx - W) >= 0: 42 | clusters[cluster_idx]['neighbors'].add(cluster_idx-W) 43 | clusters[cluster_idx-W]['neighbors'].add(cluster_idx) 44 | similarities[(cluster_idx-W, cluster_idx)] = \ 45 | torch.dot(clusters[cluster_idx-W]['normalized_feature'], clusters[cluster_idx]['normalized_feature']).item() 46 | 47 | cluster_idx += 1 48 | 49 | all_masks = [] 50 | for thresh in threshes: 51 | while len(similarities): 52 | i, j = max(similarities, key=similarities.get) 53 | if similarities[(i, j)] < thresh: break 54 | 55 | merged = merged_clusters(i, j, clusters) 56 | clusters.append(merged) 57 | 58 | del similarities[(i, j)] 59 | for neighbor in merged['neighbors']: 60 | if i in clusters[neighbor]['neighbors']: 61 | if neighbor < i: del similarities[(neighbor, i)] 62 | else: del similarities[(i, neighbor)] 63 | clusters[neighbor]['neighbors'].discard(i) 64 | if j in clusters[neighbor]['neighbors']: 65 | if neighbor < j: del similarities[(neighbor, j)] 66 | else: del similarities[(j, neighbor)] 67 | clusters[neighbor]['neighbors'].discard(j) 68 | 69 | similarities[(neighbor, cluster_idx)] = \ 70 | torch.dot(clusters[neighbor]['normalized_feature'], clusters[cluster_idx]['normalized_feature']).item() 71 | clusters[neighbor]['neighbors'].add(cluster_idx) 72 | 73 | cluster_idx += 1 74 | 75 | single_level_masks = [] 76 | counted_cluster = set() 77 | for (m, n) in similarities: 78 | if m not in counted_cluster: 79 | counted_cluster.add(m) 80 | single_level_masks.append(clusters[m]['mask']) if clusters[m]['num_of_patch'] >= min_size else None 81 | if n not in counted_cluster: 82 | counted_cluster.add(n) 83 | single_level_masks.append(clusters[n]['mask']) if clusters[n]['num_of_patch'] >= min_size else None 84 | all_masks.append(np.stack(single_level_masks)) if len(single_level_masks) else None 85 | 86 | return all_masks -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_100perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_2017_train",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (60000, 80000) 28 | MAX_ITER: 90000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/100perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_10perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_10perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (6000, 8000) 28 | MAX_ITER: 9000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/10perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_1perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_1perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (2400, 3200) 28 | MAX_ITER: 3600 29 | WARMUP_FACTOR: 0.001 30 | WARMUP_ITERS: 1000 31 | BASE_LR_MULTIPLIER: 4 32 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 33 | INPUT: 34 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 35 | MAX_SIZE_TRAIN: 1333 36 | MASK_FORMAT: "bitmask" 37 | FORMAT: "RGB" 38 | TEST: 39 | PRECISE_BN: 40 | ENABLED: True 41 | EVAL_PERIOD: 5000 42 | OUTPUT_DIR: "output/1perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_20perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_20perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (12000, 16000) 28 | MAX_ITER: 18000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/20perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_2perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_2perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (2400, 3200) 28 | MAX_ITER: 3600 29 | WARMUP_FACTOR: 0.001 30 | WARMUP_ITERS: 1000 31 | BASE_LR_MULTIPLIER: 4 32 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 33 | INPUT: 34 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 35 | MAX_SIZE_TRAIN: 1333 36 | MASK_FORMAT: "bitmask" 37 | FORMAT: "RGB" 38 | TEST: 39 | PRECISE_BN: 40 | ENABLED: True 41 | EVAL_PERIOD: 5000 42 | OUTPUT_DIR: "output/2perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_30perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_30perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (18000, 24000) 28 | MAX_ITER: 27000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/30perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_40perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_40perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (24000, 32000) 28 | MAX_ITER: 36000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/40perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_50perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_50perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (30000, 40000) 28 | MAX_ITER: 45000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/50perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_5perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_5perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (3000, 4000) 28 | MAX_ITER: 4500 29 | WARMUP_FACTOR: 0.001 30 | WARMUP_ITERS: 1000 31 | BASE_LR_MULTIPLIER: 4 32 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 33 | INPUT: 34 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 35 | MAX_SIZE_TRAIN: 1333 36 | MASK_FORMAT: "bitmask" 37 | FORMAT: "RGB" 38 | TEST: 39 | PRECISE_BN: 40 | ENABLED: True 41 | EVAL_PERIOD: 5000 42 | OUTPUT_DIR: "output/5perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_60perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_60perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (36000, 48000) 28 | MAX_ITER: 54000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/60perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_80perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_80perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (48000, 64000) 28 | MAX_ITER: 72000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/80perc" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 0 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 14 | MASK_ON: True 15 | BACKBONE: 16 | FREEZE_AT: 0 17 | RESNETS: 18 | DEPTH: 50 19 | NORM: "SyncBN" 20 | STRIDE_IN_1X1: False 21 | FPN: 22 | NORM: "SyncBN" 23 | ROI_BOX_HEAD: 24 | CLS_AGNOSTIC_BBOX_REG: True 25 | ROI_HEADS: 26 | NAME: CustomCascadeROIHeads 27 | NUM_CLASSES: 1 28 | SCORE_THRESH_TEST: 0.0 29 | POSITIVE_FRACTION: 0.25 30 | USE_DROPLOSS: True 31 | DROPLOSS_IOU_THRESH: 0.01 32 | RPN: 33 | POST_NMS_TOPK_TRAIN: 4000 34 | NMS_THRESH: 0.65 35 | DATASETS: 36 | TRAIN: ("imagenet_train",) 37 | SOLVER: 38 | IMS_PER_BATCH: 16 39 | BASE_LR: 0.01 40 | WEIGHT_DECAY: 0.00005 41 | STEPS: (80000,) 42 | MAX_ITER: 160000 43 | GAMMA: 0.02 44 | CLIP_GRADIENTS: 45 | CLIP_TYPE: norm 46 | CLIP_VALUE: 1.0 47 | ENABLED: true 48 | NORM_TYPE: 2.0 49 | AMP: 50 | ENABLED: True 51 | INPUT: 52 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 53 | MAX_SIZE_TRAIN: 1333 54 | MASK_FORMAT: "bitmask" 55 | FORMAT: "RGB" 56 | TEST: 57 | PRECISE_BN: 58 | ENABLED: True 59 | NUM_ITER: 200 60 | DETECTIONS_PER_IMAGE: 100 61 | OUTPUT_DIR: "output/" -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_demo.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 0 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 14 | MASK_ON: True 15 | BACKBONE: 16 | FREEZE_AT: 0 17 | RESNETS: 18 | DEPTH: 50 19 | NORM: "SyncBN" 20 | STRIDE_IN_1X1: False 21 | FPN: 22 | NORM: "SyncBN" 23 | ROI_BOX_HEAD: 24 | CLS_AGNOSTIC_BBOX_REG: True 25 | ROI_HEADS: 26 | NAME: CustomCascadeROIHeads 27 | NUM_CLASSES: 1 28 | SCORE_THRESH_TEST: 0.0 29 | POSITIVE_FRACTION: 0.25 30 | USE_DROPLOSS: True 31 | DROPLOSS_IOU_THRESH: 0.01 32 | RPN: 33 | POST_NMS_TOPK_TRAIN: 4000 34 | NMS_THRESH: 0.65 35 | DATASETS: 36 | TRAIN: ("imagenet_train",) 37 | TEST: ("imagenet_train",) 38 | SOLVER: 39 | IMS_PER_BATCH: 16 40 | BASE_LR: 0.01 41 | WEIGHT_DECAY: 0.00005 42 | STEPS: (80000,) 43 | MAX_ITER: 160000 44 | GAMMA: 0.02 45 | CLIP_GRADIENTS: 46 | CLIP_TYPE: norm 47 | CLIP_VALUE: 1.0 48 | ENABLED: true 49 | NORM_TYPE: 2.0 50 | AMP: 51 | ENABLED: True 52 | INPUT: 53 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 54 | MAX_SIZE_TRAIN: 1333 55 | MASK_FORMAT: "bitmask" 56 | FORMAT: "RGB" 57 | TEST: 58 | PRECISE_BN: 59 | ENABLED: True 60 | NUM_ITER: 200 61 | DETECTIONS_PER_IMAGE: 100 62 | OUTPUT_DIR: "output/" 63 | -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_self_train.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.5 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 2 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r1.pth' # round 1 14 | # WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r2.pth' # round 2 15 | MASK_ON: True 16 | BACKBONE: 17 | FREEZE_AT: 0 18 | RESNETS: 19 | DEPTH: 50 20 | NORM: "SyncBN" 21 | STRIDE_IN_1X1: False 22 | FPN: 23 | NORM: "SyncBN" 24 | ROI_BOX_HEAD: 25 | CLS_AGNOSTIC_BBOX_REG: True 26 | ROI_HEADS: 27 | NAME: CustomCascadeROIHeads 28 | NUM_CLASSES: 1 29 | SCORE_THRESH_TEST: 0.0 30 | POSITIVE_FRACTION: 0.25 31 | USE_DROPLOSS: False 32 | DROPLOSS_IOU_THRESH: 0.01 33 | DATASETS: 34 | TRAIN: ("imagenet_train_r1",) # round 1 35 | # TRAIN: ("imagenet_train_r2",) # round 2 36 | SOLVER: 37 | IMS_PER_BATCH: 16 38 | BASE_LR: 0.005 39 | STEPS: (79999,) 40 | MAX_ITER: 80000 41 | GAMMA: 1.0 42 | CLIP_GRADIENTS: 43 | CLIP_TYPE: norm 44 | CLIP_VALUE: 1.0 45 | ENABLED: true 46 | NORM_TYPE: 2.0 47 | AMP: 48 | ENABLED: True 49 | INPUT: 50 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 51 | MAX_SIZE_TRAIN: 1333 52 | MASK_FORMAT: "bitmask" 53 | FORMAT: "RGB" 54 | TEST: 55 | PRECISE_BN: 56 | ENABLED: True 57 | NUM_ITER: 200 58 | DETECTIONS_PER_IMAGE: 100 59 | OUTPUT_DIR: "output/self-train-r1/" # round 1 60 | # OUTPUT_DIR: "output/self-train-r2/" # round 2 -------------------------------------------------------------------------------- /divide_and_conquer/model_zoo/configs/CutLER-ImageNet/mask_rcnn_R_50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | MODEL: 10 | PIXEL_MEAN: [123.675, 116.280, 103.530] 11 | PIXEL_STD: [58.395, 57.120, 57.375] 12 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 13 | MASK_ON: True 14 | BACKBONE: 15 | FREEZE_AT: 0 16 | RESNETS: 17 | DEPTH: 50 18 | NORM: "SyncBN" 19 | STRIDE_IN_1X1: False 20 | FPN: 21 | NORM: "SyncBN" 22 | ROI_HEADS: 23 | NAME: "CustomStandardROIHeads" 24 | NUM_CLASSES: 1 25 | SCORE_THRESH_TEST: 0.0 26 | USE_DROPLOSS: True 27 | DROPLOSS_IOU_THRESH: 0.01 28 | RPN: 29 | POST_NMS_TOPK_TRAIN: 4000 30 | NMS_THRESH: 0.65 31 | DATASETS: 32 | TRAIN: ("imagenet_train",) 33 | SOLVER: 34 | IMS_PER_BATCH: 16 35 | BASE_LR: 0.01 36 | WEIGHT_DECAY: 0.00005 37 | STEPS: (80000,) 38 | MAX_ITER: 160000 39 | CLIP_GRADIENTS: 40 | CLIP_TYPE: norm 41 | CLIP_VALUE: 1.0 42 | ENABLED: true 43 | NORM_TYPE: 2.0 44 | INPUT: 45 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 46 | MAX_SIZE_TRAIN: 1333 47 | MASK_FORMAT: "bitmask" 48 | FORMAT: "RGB" 49 | TEST: 50 | PRECISE_BN: 51 | ENABLED: True 52 | OUTPUT_DIR: "output/" -------------------------------------------------------------------------------- /divide_and_conquer/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .roi_heads import ( 4 | ROI_HEADS_REGISTRY, 5 | ROIHeads, 6 | CustomStandardROIHeads, 7 | FastRCNNOutputLayers, 8 | build_roi_heads, 9 | ) 10 | from .roi_heads.custom_cascade_rcnn import CustomCascadeROIHeads 11 | from .roi_heads.fast_rcnn import FastRCNNOutputLayers 12 | from .meta_arch.rcnn import GeneralizedRCNN, ProposalNetwork 13 | from .meta_arch.build import build_model 14 | 15 | _EXCLUDE = {"ShapeSpec"} 16 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] -------------------------------------------------------------------------------- /divide_and_conquer/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/meta_arch/__init__.py 4 | 5 | from .build import META_ARCH_REGISTRY, build_model # isort:skip 6 | 7 | __all__ = list(globals().keys()) 8 | -------------------------------------------------------------------------------- /divide_and_conquer/modeling/meta_arch/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/meta_arch/build.py 3 | 4 | import torch 5 | 6 | from detectron2.utils.logger import _log_api_usage 7 | from detectron2.utils.registry import Registry 8 | 9 | META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip 10 | META_ARCH_REGISTRY.__doc__ = """ 11 | Registry for meta-architectures, i.e. the whole model. 12 | 13 | The registered object will be called with `obj(cfg)` 14 | and expected to return a `nn.Module` object. 15 | """ 16 | 17 | 18 | def build_model(cfg): 19 | """ 20 | Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. 21 | Note that it does not load any weights from ``cfg``. 22 | """ 23 | meta_arch = cfg.MODEL.META_ARCHITECTURE 24 | model = META_ARCH_REGISTRY.get(meta_arch)(cfg) 25 | model.to(torch.device(cfg.MODEL.DEVICE)) 26 | _log_api_usage("modeling.meta_arch." + meta_arch) 27 | return model 28 | -------------------------------------------------------------------------------- /divide_and_conquer/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .roi_heads import ( 4 | ROI_HEADS_REGISTRY, 5 | ROIHeads, 6 | Res5ROIHeads, 7 | CustomStandardROIHeads, 8 | build_roi_heads, 9 | select_foreground_proposals, 10 | ) 11 | from .custom_cascade_rcnn import CustomCascadeROIHeads 12 | from .fast_rcnn import FastRCNNOutputLayers 13 | 14 | from . import custom_cascade_rcnn # isort:skip 15 | 16 | __all__ = list(globals().keys()) 17 | -------------------------------------------------------------------------------- /divide_and_conquer/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /divide_and_conquer/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .boxes import pairwise_iou_max_scores 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | 7 | 8 | from detectron2.utils.env import fixup_module_metadata 9 | 10 | fixup_module_metadata(__name__, globals(), __all__) 11 | del fixup_module_metadata 12 | -------------------------------------------------------------------------------- /divide_and_conquer/structures/boxes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/structures/boxes.py 3 | 4 | import torch 5 | 6 | def pairwise_iou_max_scores(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: 7 | """ 8 | Given two lists of boxes of size N and M, compute the IoU 9 | (intersection over union) between **all** N x M pairs of boxes. 10 | The box order must be (xmin, ymin, xmax, ymax). 11 | 12 | Args: 13 | boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. 14 | 15 | Returns: 16 | Tensor: IoU, sized [N,M]. 17 | """ 18 | area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) # [N] 19 | area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) # [M] 20 | 21 | width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max( 22 | boxes1[:, None, :2], boxes2[:, :2] 23 | ) # [N,M,2] 24 | 25 | width_height.clamp_(min=0) # [N,M,2] 26 | inter = width_height.prod(dim=2) # [N,M] 27 | 28 | # handle empty boxes 29 | iou = torch.where( 30 | inter > 0, 31 | inter / (area1[:, None] + area2 - inter), 32 | torch.zeros(1, dtype=inter.dtype, device=inter.device), 33 | ) 34 | iou_max, _ = torch.max(iou, dim=1) 35 | return iou_max -------------------------------------------------------------------------------- /docs/demos/sa_000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_000001.jpg -------------------------------------------------------------------------------- /docs/demos/sa_121371.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_121371.jpg -------------------------------------------------------------------------------- /docs/demos/sa_163514.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_163514.jpg -------------------------------------------------------------------------------- /docs/demos/sa_193210.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_193210.jpg -------------------------------------------------------------------------------- /docs/demos/sa_224132.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_224132.jpg -------------------------------------------------------------------------------- /docs/demos/sa_234337.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_234337.jpg -------------------------------------------------------------------------------- /docs/demos/sa_412497.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_412497.jpg -------------------------------------------------------------------------------- /docs/demos/sa_434709.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_434709.jpg -------------------------------------------------------------------------------- /docs/demos/sa_479726.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_479726.jpg -------------------------------------------------------------------------------- /docs/demos/sa_527160.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_527160.jpg -------------------------------------------------------------------------------- /docs/demos/sa_562217.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/docs/demos/sa_562217.jpg -------------------------------------------------------------------------------- /promptable_segmentation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import registration 2 | from .build import * -------------------------------------------------------------------------------- /promptable_segmentation/datasets/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .sam_baseline_dataset_mapper import build_transform_gen as sam_transform_gen 3 | from .sam_baseline_dataset_mapper import SamBaselineDatasetMapper 4 | from .inference_mapper_with_gt import CoCoInferenceDatasetMapper -------------------------------------------------------------------------------- /promptable_segmentation/datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_evaluation import * -------------------------------------------------------------------------------- /promptable_segmentation/datasets/registration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Micorsoft, Inc. and its affiliates. 2 | from . import ( 3 | register_coco_panoptic_annos_semseg_interactive_jointboxpoint, 4 | register_sam_mnode, 5 | # register_object365_od, 6 | # register_sam, 7 | ) 8 | -------------------------------------------------------------------------------- /promptable_segmentation/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .semseg_loader import * -------------------------------------------------------------------------------- /promptable_segmentation/datasets/utils/semseg_loader.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import scipy.io 3 | import numpy as np 4 | 5 | def load_semseg(filename, loader_type): 6 | if loader_type == 'PIL': 7 | semseg = np.array(Image.open(filename), dtype=np.int) 8 | elif loader_type == 'MAT': 9 | semseg = scipy.io.loadmat(filename)['LabelMap'] 10 | return semseg -------------------------------------------------------------------------------- /promptable_segmentation/datasets/utils/tsv/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-16 16:56:22 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2021-08-16 17:00:28 6 | 7 | from .io_common import FileProgressingbar, img_from_base64, generate_lineidx 8 | from .tsv_io import TSVFile 9 | 10 | __all__ = [ 11 | 'FileProgressingbar', 'img_from_base64', 'generate_lineidx', 'TSVFile' 12 | ] -------------------------------------------------------------------------------- /promptable_segmentation/datasets/utils/tsv/io_common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-13 14:35:27 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2022-04-24 11:38:58 6 | 7 | import os 8 | import base64 9 | from io import BytesIO 10 | from PIL import Image 11 | 12 | import cv2 13 | import yaml 14 | import progressbar 15 | import numpy as np 16 | import torchvision.transforms as T 17 | 18 | class FileProgressingbar: 19 | fileobj = None 20 | pbar = None 21 | def __init__(self, fileobj, msg): 22 | fileobj.seek(0, os.SEEK_END) 23 | flen = fileobj.tell() 24 | fileobj.seek(0, os.SEEK_SET) 25 | self.fileobj = fileobj 26 | widgets = [msg, progressbar.AnimatedMarker(), ' ', progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()] 27 | self.pbar = progressbar.ProgressBar(widgets=widgets, maxval=flen).start() 28 | 29 | def update(self): 30 | self.pbar.update(self.fileobj.tell()) 31 | 32 | 33 | def img_from_base64(imagestring): 34 | jpgbytestring = base64.b64decode(imagestring) 35 | image = BytesIO(jpgbytestring) 36 | image = Image.open(image).convert("RGB") 37 | return image 38 | 39 | # jpgbytestring = base64.b64decode(imagestring) 40 | # nparr = np.frombuffer(jpgbytestring, np.uint8) 41 | # try: 42 | # r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 43 | # # r = cv2.cvtColor(r, cv2.COLOR_BGR2RGB) 44 | # return r 45 | # except: 46 | # return None 47 | 48 | 49 | def generate_lineidx(filein, idxout): 50 | assert not os.path.isfile(idxout) 51 | with open(filein, 'r') as tsvin, open(idxout, 'w') as tsvout: 52 | bar = FileProgressingbar(tsvin, 'Generating lineidx {0}: '.format(idxout)) 53 | fsize = os.fstat(tsvin.fileno()).st_size 54 | fpos = 0 55 | while fpos != fsize: 56 | tsvout.write(str(fpos)+"\n") 57 | tsvin.readline() 58 | fpos = tsvin.tell() 59 | bar.update() 60 | -------------------------------------------------------------------------------- /promptable_segmentation/datasets/utils/tsv/tsv_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-13 14:26:21 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2022-08-17 00:57:51 6 | import time 7 | import os 8 | import os.path as op 9 | from .io_common import generate_lineidx, FileProgressingbar 10 | 11 | 12 | class TSVFile(object): 13 | def __init__(self, tsv_file, silence=True): 14 | self.tsv_file = tsv_file 15 | self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' 16 | 17 | self.label_file = op.splitext(tsv_file)[0] + '.label' 18 | self.label_lineidx = op.splitext(tsv_file)[0] + '.label.lineidx' 19 | 20 | if os.path.exists(self.label_file): 21 | self.split_label = True 22 | else: 23 | self.split_label = False 24 | 25 | self._fp = None 26 | self._lineidx = None 27 | 28 | self._label_fp = None 29 | self._label_lineidx = None 30 | 31 | self.pid = None 32 | self.silence = silence 33 | self._ensure_lineidx_loaded() 34 | 35 | def num_rows(self): 36 | return len(self._lineidx) 37 | 38 | def seek(self, idx): 39 | self._ensure_tsv_opened() 40 | pos = self._lineidx[idx] 41 | self._fp.seek(pos) 42 | tsv_info = [s.strip() for s in self._fp.readline().split('\t')] 43 | 44 | if self.split_label: 45 | label_pos = self._label_lineidx[idx] 46 | self._label_fp.seek(label_pos) 47 | label_info = [s.strip() for s in self._label_fp.readline().split('\t')] 48 | 49 | assert tsv_info[0] == label_info[0] 50 | tsv_info = [tsv_info[0], label_info[-1], tsv_info[-1]] 51 | 52 | return tsv_info 53 | 54 | def close(self): 55 | if self._fp is not None: 56 | self._fp.close() 57 | del self._fp 58 | del self._lineidx 59 | 60 | self._fp = None 61 | self._lineidx = None 62 | 63 | def _ensure_lineidx_loaded(self): 64 | if not op.isfile(self.lineidx) and not op.islink(self.lineidx): 65 | generate_lineidx(self.tsv_file, self.lineidx) 66 | 67 | if self._lineidx is None: 68 | with open(self.lineidx, 'r') as fp: 69 | lines = fp.readlines() 70 | self._lineidx = [int(i.strip().split()[0]) for i in lines] 71 | 72 | if self.split_label: 73 | with open(self.label_lineidx, 'r') as fp: 74 | lines = fp.readlines() 75 | self._label_lineidx = [int(i.strip().split()[0]) for i in lines] 76 | 77 | 78 | def _ensure_tsv_opened(self): 79 | self._ensure_lineidx_loaded() 80 | if self._fp is None: 81 | self._fp = open(self.tsv_file, 'r') 82 | self.pid = os.getpid() 83 | 84 | if self.split_label: 85 | self._label_fp = open(self.label_file, 'r') 86 | 87 | if self.pid != os.getpid(): 88 | print('re-open {} because the process id changed'.format(self.tsv_file)) 89 | self._fp = open(self.tsv_file, 'r') 90 | self.pid = os.getpid() 91 | 92 | if self.split_label: 93 | self._label_fp = open(self.label_file, 'r') 94 | -------------------------------------------------------------------------------- /promptable_segmentation/demo.sh: -------------------------------------------------------------------------------- 1 | # #interactive gradio demo 2 | CUDA_VISIBLE_DEVICES=4 python demo_promptable.py \ 3 | --ckpt /home/xudongw/UnSAM-Semantic/data/output/gpu4-bs4-lr1e-4-iter100k-12jsons-SSL-AllAnnos-NMask120-Thresh0.02/10k_0099999.pth \ 4 | --conf_files configs/semantic_sam_only_sa-1b_swinT.yaml \ 5 | --device cpu -------------------------------------------------------------------------------- /promptable_segmentation/examples/sa_121371.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/promptable_segmentation/examples/sa_121371.jpg -------------------------------------------------------------------------------- /promptable_segmentation/examples/sa_412497.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/promptable_segmentation/examples/sa_412497.jpg -------------------------------------------------------------------------------- /promptable_segmentation/examples/sa_562217.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/promptable_segmentation/examples/sa_562217.jpg -------------------------------------------------------------------------------- /promptable_segmentation/model/BaseModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from utils.model import align_and_update_state_dicts 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BaseModel(nn.Module): 13 | def __init__(self, opt, module: nn.Module): 14 | super(BaseModel, self).__init__() 15 | self.opt = opt 16 | self.model = module 17 | 18 | def forward(self, *inputs, **kwargs): 19 | outputs = self.model(*inputs, **kwargs) 20 | return outputs 21 | 22 | def save_pretrained(self, save_path): 23 | torch.save(self.model.state_dict(), save_path) 24 | 25 | def from_pretrained(self, load_dir): 26 | state_dict = torch.load(load_dir, map_location='cpu') 27 | if 'model' in state_dict: 28 | state_dict=state_dict['model'] 29 | state_dict={k[6:]:v for k,v in state_dict.items() if k.startswith('model.')} 30 | # for k in self.model.state_dict(): 31 | # if k not in state_dict: 32 | # assert k[:-2] in state_dict 33 | # state_dict[k]=state_dict.pop(k[:-2]) 34 | state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict) 35 | self.model.load_state_dict(state_dict, strict=False) 36 | return self -------------------------------------------------------------------------------- /promptable_segmentation/model/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .architectures import build_model 6 | from .build_semantic_sam import prepare_image, plot_results, build_semantic_sam, SemanticSamAutomaticMaskGenerator, SemanticSAMPredictor, plot_multi_results -------------------------------------------------------------------------------- /promptable_segmentation/model/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_mask_dino import * 2 | from semantic_sam.architectures.build import build_model -------------------------------------------------------------------------------- /promptable_segmentation/model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_sam.backbone.build import build_backbone 2 | 3 | from semantic_sam.backbone.focal import * 4 | from semantic_sam.backbone.focal_dw import * 5 | from .swin import * 6 | from semantic_sam.backbone.backbone import * -------------------------------------------------------------------------------- /promptable_segmentation/model/body/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_semantic_sam_head -------------------------------------------------------------------------------- /promptable_segmentation/model/body/build.py: -------------------------------------------------------------------------------- 1 | from semantic_sam.body.registry import model_entrypoints 2 | from semantic_sam.body.registry import is_model 3 | from .general_head import * 4 | 5 | 6 | def build_semantic_sam_head(config, *args, **kwargs): 7 | model_name = config['MODEL']['HEAD'] 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | body = model_entrypoints(model_name)(config, *args, **kwargs) 12 | return body -------------------------------------------------------------------------------- /promptable_segmentation/model/body/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_decoder 2 | from .interactive_mask_dino import * -------------------------------------------------------------------------------- /promptable_segmentation/model/body/decoder/build.py: -------------------------------------------------------------------------------- 1 | from semantic_sam.body.decoder.registry import model_entrypoints 2 | from semantic_sam.body.decoder.registry import is_model 3 | 4 | 5 | def build_decoder(config, *args, **kwargs): 6 | model_name = config['MODEL']['DECODER']['NAME'] 7 | 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /promptable_segmentation/model/body/decoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_sam.body.decoder.utils import * -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_encoder -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/build.py: -------------------------------------------------------------------------------- 1 | from semantic_sam.body.encoder.registry import model_entrypoints 2 | from semantic_sam.body.encoder.registry import is_model 3 | 4 | from semantic_sam.body.encoder.transformer_encoder_fpn import * 5 | from .encoder_deform import * 6 | 7 | def build_encoder(config, *args, **kwargs): 8 | model_name = config['MODEL']['ENCODER']['NAME'] 9 | 10 | if not is_model(model_name): 11 | raise ValueError(f'Unkown model: {model_name}') 12 | 13 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd mask2former/modeling/pixel_decoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install --user 14 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/encoder/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 | print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /promptable_segmentation/model/body/general_head.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) MicroSoft, Inc. and its affiliates. 3 | # Modified from DINO https://github.com/IDEA-Research/MaskDINO by Feng Li. 4 | # ------------------------------------------------------------------------ 5 | import logging 6 | from typing import Callable, Dict, List, Optional, Tuple, Union 7 | 8 | from torch import nn 9 | 10 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 11 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 12 | 13 | from semantic_sam.body.registry import register_body 14 | from .encoder import build_encoder 15 | from .decoder import build_decoder 16 | from ..utils import configurable 17 | 18 | 19 | class IMaskDINOHead(nn.Module): 20 | @configurable 21 | def __init__( 22 | self, 23 | input_shape: Dict[str, ShapeSpec], 24 | *, 25 | num_classes: int, 26 | pixel_decoder: nn.Module, 27 | loss_weight: float = 1.0, 28 | ignore_value: int = -1, 29 | transformer_predictor: nn.Module, 30 | ): 31 | """ 32 | Args: 33 | input_shape: shapes (channels and stride) of the input features 34 | num_classes: number of classes to predict 35 | pixel_decoder: the pixel decoder module 36 | loss_weight: loss weight 37 | ignore_value: category id to be ignored during training. 38 | transformer_predictor: the transformer decoder that makes prediction 39 | transformer_in_feature: input feature name to the transformer_predictor 40 | """ 41 | super().__init__() 42 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 43 | self.in_features = [k for k, v in input_shape] 44 | self.ignore_value = ignore_value 45 | self.common_stride = 4 46 | self.loss_weight = loss_weight 47 | 48 | self.pixel_decoder = pixel_decoder 49 | self.predictor = transformer_predictor 50 | 51 | self.num_classes = num_classes 52 | # store processed features 53 | self.processed_features = None 54 | 55 | @classmethod 56 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict): 57 | enc_cfg = cfg['MODEL']['ENCODER'] 58 | dec_cfg = cfg['MODEL']['DECODER'] 59 | transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] 60 | 61 | return { 62 | "input_shape": { 63 | k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] 64 | }, 65 | "ignore_value": enc_cfg['IGNORE_VALUE'], 66 | "num_classes": enc_cfg.get('NUM_CLASSES', None), 67 | "pixel_decoder": build_encoder(cfg, input_shape), 68 | "loss_weight": enc_cfg['LOSS_WEIGHT'], 69 | "transformer_predictor": build_decoder( 70 | cfg, 71 | transformer_predictor_in_channels, 72 | lang_encoder, 73 | mask_classification=True, 74 | extra=extra, 75 | ), 76 | } 77 | 78 | def forward_encoder(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}): 79 | mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( 80 | features, mask) 81 | self.processed_features = (mask_features, transformer_encoder_features, multi_scale_features) 82 | 83 | def forward_decoder(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}): 84 | assert self.processed_features is not None, "need to precess features first" 85 | mask_features, transformer_encoder_features, multi_scale_features = self.processed_features 86 | if task == 'teacher': 87 | predictions = self.predictor.forward_teacher(multi_scale_features, mask_features, mask, targets=targets, 88 | target_queries=target_queries, target_vlp=target_vlp, 89 | task=task, extra=extra) 90 | else: 91 | predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets, 92 | target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) 93 | return predictions 94 | 95 | def forward(self, features, mask=None, targets=None, target_queries=None, target_vlp=None, task='seg', extra={}): 96 | return self.layers(features, mask, targets=targets, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) 97 | 98 | def layers(self, features, mask=None,targets=None, target_queries=None, target_vlp=None, prediction_switch=None, task='seg', extra={}): 99 | mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features, mask) 100 | predictions = self.predictor(multi_scale_features, mask_features, mask, targets=targets, 101 | target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) 102 | return predictions 103 | 104 | 105 | @register_body 106 | def get_interactive_maskdino_head(cfg, input_shape, lang_encoder, extra): 107 | return IMaskDINOHead(cfg, input_shape, lang_encoder, extra) -------------------------------------------------------------------------------- /promptable_segmentation/model/build_semantic_sam.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Feng Li (fliay@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import matplotlib.pyplot as plt 9 | from PIL import Image 10 | import numpy as np 11 | from torchvision import transforms 12 | import torch 13 | import os 14 | 15 | from utils.arguments import load_opt_from_config_file 16 | from model.BaseModel import BaseModel 17 | from model import build_model 18 | from tasks.automatic_mask_generator import SemanticSamAutomaticMaskGenerator 19 | from tasks.interactive_idino_m2m_auto import show_anns 20 | from tasks.interactive_predictor import SemanticSAMPredictor 21 | 22 | 23 | def prepare_image(image_pth, img_size=640): 24 | """ 25 | apply transformation to the image. crop the image ot 640 short edge by default 26 | """ 27 | image = Image.open(image_pth).convert('RGB') 28 | t = [] 29 | t.append(transforms.Resize(img_size, interpolation=Image.BICUBIC)) 30 | transform1 = transforms.Compose(t) 31 | image_ori = transform1(image) 32 | 33 | image_ori = np.asarray(image_ori) 34 | images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda() 35 | 36 | return image_ori, images 37 | 38 | 39 | def build_semantic_sam(model_type, ckpt): 40 | """ 41 | build model 42 | """ 43 | cfgs={'T':"configs/semantic_sam_only_sa-1b_swinT.yaml", 44 | 'L':"configs/semantic_sam_only_sa-1b_swinL.yaml"} 45 | 46 | sam_cfg=cfgs[model_type] 47 | opt = load_opt_from_config_file(sam_cfg) 48 | model_semantic_sam = BaseModel(opt, build_model(opt)).from_pretrained(ckpt).eval().cuda() 49 | return model_semantic_sam 50 | 51 | 52 | def plot_results(outputs, image_ori, save_path='../vis/'): 53 | """ 54 | plot input image and its reuslts 55 | """ 56 | if os.path.isdir(save_path): 57 | image_ori_name = 'input.png' 58 | im_name = 'example.png' 59 | else: 60 | image_ori_name = os.path.basename(save_path).split('.')[0] + '_input.png' 61 | im_name = os.path.basename(save_path).split('.')[0]+ '_example.png' 62 | save_path = os.path.dirname(save_path) 63 | 64 | if not os.path.exists(save_path): 65 | os.mkdir(save_path) 66 | 67 | fig = plt.figure() 68 | plt.imshow(image_ori) 69 | plt.savefig(os.path.join(save_path, image_ori_name)) 70 | show_anns(outputs) 71 | fig.canvas.draw() 72 | im = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) 73 | plt.savefig(os.path.join(save_path, im_name)) 74 | return im 75 | 76 | def plot_multi_results(iou_sort_masks, area_sort_masks, image_ori, save_path='../vis/'): 77 | """ 78 | plot input image and its reuslts 79 | """ 80 | if not os.path.exists(save_path): 81 | os.mkdir(save_path) 82 | plt.imshow(image_ori) 83 | plt.savefig('../vis/input.png') 84 | def create_long_image(masks): 85 | ims = [] 86 | for img in masks: 87 | ims.append(img) 88 | width, height = ims[0].size 89 | result = Image.new(ims[0].mode, (width * len(ims), height)) 90 | for i, im in enumerate(ims): 91 | result.paste(im, box=(i * width, 0)) 92 | return result 93 | create_long_image(iou_sort_masks).save('../vis/all_results_sort_by_iou.png') 94 | create_long_image(area_sort_masks).save('../vis/all_results_sort_by_areas.png') -------------------------------------------------------------------------------- /promptable_segmentation/model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_sam.utils.config import * 2 | from semantic_sam.utils.misc import * 3 | # from .dist import * -------------------------------------------------------------------------------- /promptable_segmentation/model/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | def box_xywh_to_xyxy(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [x0, y0, (x0 + x1), (y0 + y1)] 25 | return torch.stack(b, dim=-1) 26 | 27 | 28 | # modified from torchvision to also return the union 29 | def box_iou(boxes1, boxes2): 30 | area1 = box_area(boxes1) 31 | area2 = box_area(boxes2) 32 | 33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clamp(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | iou = inter / (union+1e-6) 42 | return iou, union 43 | 44 | 45 | def generalized_box_iou(boxes1, boxes2): 46 | """ 47 | Generalized IoU from https://giou.stanford.edu/ 48 | 49 | The boxes should be in [x0, y0, x1, y1] format 50 | 51 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 52 | and M = len(boxes2) 53 | """ 54 | # degenerate boxes gives inf / nan results 55 | # so do an early check 56 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 57 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 58 | iou, union = box_iou(boxes1, boxes2) 59 | 60 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 61 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 62 | 63 | wh = (rb - lt).clamp(min=0) # [N,M,2] 64 | area = wh[:, :, 0] * wh[:, :, 1] 65 | 66 | return iou - (area - union) / (area+1e-6) 67 | 68 | def generalized_box_iou_padded(boxes1, boxes2): 69 | """ 70 | Generalized IoU from https://giou.stanford.edu/ 71 | 72 | The boxes should be in [x0, y0, x1, y1] format 73 | 74 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 75 | and M = len(boxes2) 76 | """ 77 | # degenerate boxes gives inf / nan results 78 | # so do an early check 79 | # assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 80 | # assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 81 | iou, union = box_iou(boxes1, boxes2) 82 | 83 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 84 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 85 | 86 | wh = (rb - lt).clamp(min=0) # [N,M,2] 87 | area = wh[:, :, 0] * wh[:, :, 1] 88 | 89 | return iou - (area - union) / (area+1e-6) 90 | 91 | 92 | def masks_to_boxes(masks): 93 | """Compute the bounding boxes around the provided masks 94 | 95 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 96 | 97 | Returns a [N, 4] tensors, with the boxes in xyxy format 98 | """ 99 | if masks.numel() == 0: 100 | return torch.zeros((0, 4), device=masks.device) 101 | 102 | h, w = masks.shape[-2:] 103 | 104 | y = torch.arange(0, h, dtype=torch.float) 105 | x = torch.arange(0, w, dtype=torch.float) 106 | y, x = torch.meshgrid(y, x) 107 | 108 | x_mask = (masks * x.unsqueeze(0)) 109 | x_max = x_mask.flatten(1).max(-1)[0] 110 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 111 | 112 | y_mask = (masks * y.unsqueeze(0)) 113 | y_max = y_mask.flatten(1).max(-1)[0] 114 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 115 | 116 | return torch.stack([x_min, y_min, x_max, y_max], 1) -------------------------------------------------------------------------------- /promptable_segmentation/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m 2 | from .interactive_idino_m2m_auto import interactive_infer_image as interactive_infer_image_idino_m2m_auto 3 | from .automatic_mask_generator import prompt_switch 4 | from .interactive_predictor import SemanticSAMPredictor -------------------------------------------------------------------------------- /promptable_segmentation/tasks/interactive_idino_m2m.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Hao Zhang (hzhangcx@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import numpy as np 10 | from torchvision import transforms 11 | from utils.visualizer import Visualizer 12 | from typing import Tuple 13 | from PIL import Image 14 | from detectron2.data import MetadataCatalog 15 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 16 | 17 | def interactive_infer_image(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, device='gpu',refimg=None, reftxt=None, audio_pth=None, video_pth=None): 18 | t = [] 19 | t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) 20 | transform1 = transforms.Compose(t) 21 | image_ori = transform1(image['image']) 22 | mask_ori = transform1(image['mask']) 23 | width = image_ori.size[0] 24 | height = image_ori.size[1] 25 | image_ori = np.asarray(image_ori) 26 | if device == 'cpu': 27 | images = torch.from_numpy(image_ori.copy()).permute(2,0,1) 28 | else: 29 | images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() 30 | all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':') 31 | 32 | 33 | data = {"image": images, "height": height, "width": width} 34 | 35 | mask_ori = np.asarray(mask_ori)[:,:,0:1].copy() 36 | mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0] 37 | points=mask_ori.nonzero().float().to(images.device) 38 | if len(points)==0: 39 | point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]]) 40 | else: 41 | point_=points.mean(0)[None] 42 | point=point_.clone() 43 | point[0, 0] = point_[0, 0] / mask_ori.shape[0] 44 | point[0, 1] = point_[0, 1] / mask_ori.shape[1] 45 | point = point[:, [1, 0]] 46 | point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1) 47 | data['targets'] = [dict()] 48 | data['targets'][0]['points']=point 49 | data['targets'][0]['pb']=point.new_tensor([0.]) 50 | 51 | 52 | batch_inputs = [data] 53 | masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts) 54 | 55 | pred_masks_poses = masks 56 | reses=[] 57 | ious=ious[0,0] 58 | ids=torch.argsort(ious,descending=True) 59 | 60 | text_res='' 61 | try: 62 | thresh=float(thresh) 63 | except Exception: 64 | thresh=0.0 65 | mask_ls=[] 66 | ious_res=[] 67 | areas=[] 68 | for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])): 69 | iou=round(float(iou),2) 70 | texts=f'{iou}' 71 | mask=(pred_masks_pos>0.0).cpu().numpy() 72 | area=mask.sum() 73 | conti=False 74 | if iou0.95: 78 | conti=True 79 | break 80 | if i == len(pred_masks_poses[ids])-1 and mask_ls==[]: 81 | conti=False 82 | if conti: 83 | continue 84 | ious_res.append(iou) 85 | mask_ls.append(mask) 86 | areas.append(area) 87 | mask,_=remove_small_regions(mask,int(hole_scale),mode="holes") 88 | mask,_=remove_small_regions(mask,int(island_scale),mode="islands") 89 | mask=(mask).astype(float) 90 | out_txt = texts 91 | visual = Visualizer(image_ori, metadata=metadata) 92 | color=[0.,0.,1.0] 93 | demo = visual.draw_binary_mask(mask, color=color, text=texts) 94 | res = demo.get_image() 95 | point_x0=max(0,int(point_[0, 1])-3) 96 | point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3) 97 | point_y0 = max(0, int(point_[0, 0]) - 3) 98 | point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3) 99 | res[point_y0:point_y1,point_x0:point_x1,0]=255 100 | res[point_y0:point_y1,point_x0:point_x1,1]=0 101 | res[point_y0:point_y1,point_x0:point_x1,2]=0 102 | reses.append(Image.fromarray(res)) 103 | text_res=text_res+';'+out_txt 104 | ids=list(torch.argsort(torch.tensor(areas),descending=False)) 105 | ids = [int(i) for i in ids] 106 | 107 | torch.cuda.empty_cache() 108 | 109 | return reses,[reses[i] for i in ids] 110 | 111 | def remove_small_regions( 112 | mask: np.ndarray, area_thresh: float, mode: str 113 | ) -> Tuple[np.ndarray, bool]: 114 | """ 115 | Removes small disconnected regions and holes in a mask. Returns the 116 | mask and an indicator of if the mask has been modified. 117 | """ 118 | import cv2 # type: ignore 119 | 120 | assert mode in ["holes", "islands"] 121 | correct_holes = mode == "holes" 122 | working_mask = (correct_holes ^ mask).astype(np.uint8) 123 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 124 | sizes = stats[:, -1][1:] # Row 0 is background label 125 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 126 | if len(small_regions) == 0: 127 | return mask, False 128 | fill_labels = [0] + small_regions 129 | if not correct_holes: 130 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 131 | # If every region is below threshold, keep largest 132 | if len(fill_labels) == 0: 133 | fill_labels = [int(np.argmax(sizes)) + 1] 134 | mask = np.isin(regions, fill_labels) 135 | return mask, True -------------------------------------------------------------------------------- /promptable_segmentation/tasks/interactive_idino_m2m_auto.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Hao Zhang (hzhangcx@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import numpy as np 10 | from torchvision import transforms 11 | from utils.visualizer import Visualizer 12 | from typing import Tuple 13 | from PIL import Image 14 | from detectron2.data import MetadataCatalog 15 | import matplotlib.pyplot as plt 16 | import cv2 17 | import io 18 | from .automatic_mask_generator import SemanticSamAutomaticMaskGenerator 19 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 20 | 21 | def interactive_infer_image(model, image,level,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None): 22 | t = [] 23 | t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) 24 | transform1 = transforms.Compose(t) 25 | image_ori = transform1(image) 26 | 27 | image_ori = np.asarray(image_ori) 28 | images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() 29 | 30 | mask_generator = SemanticSamAutomaticMaskGenerator(model,points_per_side=80, # 32 31 | pred_iou_thresh=0.7, # 0.88, eval-0.35 32 | stability_score_thresh=0.7, # 0.92, eval-0.05 33 | min_mask_region_area=15, # 10, eval-15 34 | level=level, 35 | ) 36 | 37 | outputs = mask_generator.generate(images) 38 | 39 | fig=plt.figure(figsize=(10, 10)) 40 | plt.imshow(image_ori) 41 | show_anns(outputs) 42 | fig.canvas.draw() 43 | im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) 44 | return im 45 | 46 | 47 | def remove_small_regions( 48 | mask: np.ndarray, area_thresh: float, mode: str 49 | ) -> Tuple[np.ndarray, bool]: 50 | """ 51 | Removes small disconnected regions and holes in a mask. Returns the 52 | mask and an indicator of if the mask has been modified. 53 | """ 54 | import cv2 # type: ignore 55 | 56 | assert mode in ["holes", "islands"] 57 | correct_holes = mode == "holes" 58 | working_mask = (correct_holes ^ mask).astype(np.uint8) 59 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 60 | sizes = stats[:, -1][1:] # Row 0 is background label 61 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 62 | if len(small_regions) == 0: 63 | return mask, False 64 | fill_labels = [0] + small_regions 65 | if not correct_holes: 66 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 67 | # If every region is below threshold, keep largest 68 | if len(fill_labels) == 0: 69 | fill_labels = [int(np.argmax(sizes)) + 1] 70 | mask = np.isin(regions, fill_labels) 71 | return mask, True 72 | 73 | def show_anns(anns): 74 | if len(anns) == 0: 75 | return 76 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 77 | ax = plt.gca() 78 | ax.set_autoscale_on(False) 79 | polygons = [] 80 | color = [] 81 | for ann in sorted_anns: 82 | m = ann['segmentation'] 83 | img = np.ones((m.shape[0], m.shape[1], 3)) 84 | color_mask = np.random.random((1, 3)).tolist()[0] 85 | for i in range(3): 86 | img[:,:,i] = color_mask[i] 87 | ax.imshow(np.dstack((img, m*0.35))) -------------------------------------------------------------------------------- /promptable_segmentation/tasks/interactive_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision import transforms 4 | from utils.visualizer import Visualizer 5 | from typing import Tuple 6 | from PIL import Image 7 | from detectron2.data import MetadataCatalog 8 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 9 | 10 | 11 | class SemanticSAMPredictor: 12 | def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100): 13 | """ 14 | thresh: iou thresh to filter low confidence objects 15 | text_size: resize the input image short edge for the model to process 16 | hole_scale: fill in small holes as in SAM 17 | island_scale: remove small regions as in SAM 18 | """ 19 | self.model = model 20 | self.thresh = thresh 21 | self.text_size = hole_scale 22 | self.hole_scale = hole_scale 23 | self.island_scale = island_scale 24 | self.point = None 25 | 26 | def predict(self, image_ori, image, point=None): 27 | """ 28 | produce up to 6 prediction results for each click 29 | """ 30 | width = image_ori.shape[1] 31 | height = image_ori.shape[0] 32 | 33 | data = {"image": image, "height": height, "width": width} 34 | # import ipdb; ipdb.set_trace() 35 | if point is None: 36 | point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda() 37 | else: 38 | point = torch.tensor(point).cuda() 39 | point_ = point 40 | point = point_.clone() 41 | point[0, 0] = point_[0, 0] 42 | point[0, 1] = point_[0, 1] 43 | # point = point[:, [1, 0]] 44 | point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1) 45 | 46 | self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point)) 47 | 48 | data['targets'] = [dict()] 49 | data['targets'][0]['points'] = point 50 | data['targets'][0]['pb'] = point.new_tensor([0.]) 51 | 52 | batch_inputs = [data] 53 | masks, ious = self.model.model.evaluate_demo(batch_inputs) 54 | 55 | return masks, ious 56 | 57 | def process_multi_mask(self, masks, ious, image_ori): 58 | pred_masks_poses = masks 59 | reses = [] 60 | ious = ious[0, 0] 61 | ids = torch.argsort(ious, descending=True) 62 | 63 | text_res = '' 64 | mask_ls = [] 65 | ious_res = [] 66 | areas = [] 67 | for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])): 68 | iou = round(float(iou), 2) 69 | texts = f'{iou}' 70 | mask = (pred_masks_pos > 0.0).cpu().numpy() 71 | area = mask.sum() 72 | conti = False 73 | if iou < self.thresh: 74 | conti = True 75 | for m in mask_ls: 76 | if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95: 77 | conti = True 78 | break 79 | if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []: 80 | conti = False 81 | if conti: 82 | continue 83 | ious_res.append(iou) 84 | mask_ls.append(mask) 85 | areas.append(area) 86 | mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes") 87 | mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands") 88 | mask = (mask).astype(float) 89 | out_txt = texts 90 | visual = Visualizer(image_ori, metadata=metadata) 91 | color = [0., 0., 1.0] 92 | demo = visual.draw_binary_mask(mask, color=color, text=texts) 93 | res = demo.get_image() 94 | point_x0 = max(0, int(self.point[0, 0]) - 3) 95 | point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3) 96 | point_y0 = max(0, int(self.point[0, 1]) - 3) 97 | point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3) 98 | res[point_y0:point_y1, point_x0:point_x1, 0] = 255 99 | res[point_y0:point_y1, point_x0:point_x1, 1] = 0 100 | res[point_y0:point_y1, point_x0:point_x1, 2] = 0 101 | reses.append(Image.fromarray(res)) 102 | text_res = text_res + ';' + out_txt 103 | ids = list(torch.argsort(torch.tensor(areas), descending=False)) 104 | ids = [int(i) for i in ids] 105 | 106 | torch.cuda.empty_cache() 107 | 108 | return reses, [reses[i] for i in ids] 109 | 110 | def predict_masks(self, image_ori, image, point=None): 111 | masks, ious = self.predict(image_ori, image, point) 112 | return self.process_multi_mask(masks, ious, image_ori) 113 | 114 | @staticmethod 115 | def remove_small_regions( 116 | mask: np.ndarray, area_thresh: float, mode: str 117 | ) -> Tuple[np.ndarray, bool]: 118 | """ 119 | Removes small disconnected regions and holes in a mask. Returns the 120 | mask and an indicator of if the mask has been modified. 121 | """ 122 | import cv2 # type: ignore 123 | 124 | assert mode in ["holes", "islands"] 125 | correct_holes = mode == "holes" 126 | working_mask = (correct_holes ^ mask).astype(np.uint8) 127 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 128 | sizes = stats[:, -1][1:] # Row 0 is background label 129 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 130 | if len(small_regions) == 0: 131 | return mask, False 132 | fill_labels = [0] + small_regions 133 | if not correct_holes: 134 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 135 | # If every region is below threshold, keep largest 136 | if len(fill_labels) == 0: 137 | fill_labels = [int(np.argmax(sizes)) + 1] 138 | mask = np.isin(regions, fill_labels) 139 | return mask, True 140 | -------------------------------------------------------------------------------- /promptable_segmentation/train.sh: -------------------------------------------------------------------------------- 1 | export TRAIN_DATASETS=/scratch/one_month/2024_05/xudongw/SAM_4perc/gt_merged_threshold0.02 2 | 3 | run_name="gpu4-bs4-lr1e-4-iter400k-42jsons-gt" 4 | ngpus=4 5 | CUDA_VISIBLE_DEVICES=0,1,2,3, python train_net.py \ 6 | --num-gpus ${ngpus} \ 7 | --resume \ 8 | --config-file configs/semantic_sam_only_sa-1b_swinT.yaml \ 9 | SOLVER.BASE_LR=1e-4 \ 10 | COCO.TEST.BATCH_SIZE_TOTAL=${ngpus} \ 11 | SAM.TEST.BATCH_SIZE_TOTAL=${ngpus} \ 12 | SAM.TRAIN.BATCH_SIZE_TOTAL=${ngpus} \ 13 | TEST.EVAL_PERIOD=400000 \ 14 | OUTPUT_DIR=data/output/${run_name} \ 15 | 16 | -------------------------------------------------------------------------------- /promptable_segmentation/utils/Config.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode as _CfgNode 2 | 3 | class CfgNode(_CfgNode): 4 | """ 5 | The same as `fvcore.common.config.CfgNode`, but different in: 6 | 7 | 1. Use unsafe yaml loading by default. 8 | Note that this may lead to arbitrary code execution: you must not 9 | load a config file from untrusted sources before manually inspecting 10 | the content of the file. 11 | 2. Support config versioning. 12 | When attempting to merge an old config, it will convert the old config automatically. 13 | 14 | .. automethod:: clone 15 | .. automethod:: freeze 16 | .. automethod:: defrost 17 | .. automethod:: is_frozen 18 | .. automethod:: load_yaml_with_base 19 | .. automethod:: merge_from_list 20 | .. automethod:: merge_from_other_cfg 21 | """ 22 | 23 | def merge_from_dict(self, dict): 24 | pass 25 | 26 | node = CfgNode() -------------------------------------------------------------------------------- /promptable_segmentation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist import * -------------------------------------------------------------------------------- /promptable_segmentation/utils/arguments.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import argparse 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def load_config_dict_to_opt(opt, config_dict): 10 | """ 11 | Load the key, value pairs from config_dict to opt, overriding existing values in opt 12 | if there is any. 13 | """ 14 | if not isinstance(config_dict, dict): 15 | raise TypeError("Config must be a Python dictionary") 16 | for k, v in config_dict.items(): 17 | k_parts = k.split('.') 18 | pointer = opt 19 | for k_part in k_parts[:-1]: 20 | if k_part not in pointer: 21 | pointer[k_part] = {} 22 | pointer = pointer[k_part] 23 | assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict." 24 | ori_value = pointer.get(k_parts[-1]) 25 | pointer[k_parts[-1]] = v 26 | if ori_value: 27 | logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}") 28 | 29 | def load_opt_from_config_file(conf_file): 30 | """ 31 | Load opt from the config files, settings in later files can override those in previous files. 32 | 33 | Args: 34 | conf_files: config file path 35 | 36 | Returns: 37 | dict: a dictionary of opt settings 38 | """ 39 | opt = {} 40 | with open(conf_file, encoding='utf-8') as f: 41 | config_dict = yaml.safe_load(f) 42 | 43 | load_config_dict_to_opt(opt, config_dict) 44 | 45 | return opt 46 | 47 | def load_opt_from_config_files(conf_files): 48 | """ 49 | Load opt from the config files, settings in later files can override those in previous files. 50 | 51 | Args: 52 | conf_files (list): a list of config file paths 53 | 54 | Returns: 55 | dict: a dictionary of opt settings 56 | """ 57 | opt = {} 58 | for conf_file in conf_files: 59 | with open(conf_file, encoding='utf-8') as f: 60 | config_dict = yaml.safe_load(f) 61 | 62 | load_config_dict_to_opt(opt, config_dict) 63 | 64 | return opt 65 | 66 | 67 | def load_opt_command(args): 68 | parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.') 69 | parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate') 70 | parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).') 71 | parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.') 72 | parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.') 73 | parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER) 74 | 75 | cmdline_args = parser.parse_args() if not args else parser.parse_args(args) 76 | 77 | opt = load_opt_from_config_files(cmdline_args.conf_files) 78 | 79 | if cmdline_args.config_overrides: 80 | config_overrides_string = ' '.join(cmdline_args.config_overrides) 81 | logger.warning(f"Command line config overrides: {config_overrides_string}") 82 | config_dict = json.loads(config_overrides_string) 83 | load_config_dict_to_opt(opt, config_dict) 84 | 85 | if cmdline_args.overrides: 86 | assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments is not paired, required: key value" 87 | keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)] 88 | vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)] 89 | vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals] 90 | 91 | types = [] 92 | for key in keys: 93 | key = key.split('.') 94 | ele = opt.copy() 95 | while len(key) > 0: 96 | ele = ele[key.pop(0)] 97 | types.append(type(ele)) 98 | 99 | config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)} 100 | load_config_dict_to_opt(opt, config_dict) 101 | 102 | # combine cmdline_args into opt dictionary 103 | for key, val in cmdline_args.__dict__.items(): 104 | if val is not None: 105 | opt[key] = val 106 | 107 | return opt, cmdline_args 108 | 109 | 110 | def save_opt_to_json(opt, conf_file): 111 | with open(conf_file, 'w', encoding='utf-8') as f: 112 | json.dump(opt, f, indent=4) 113 | 114 | 115 | def save_opt_to_yaml(opt, conf_file): 116 | with open(conf_file, 'w', encoding='utf-8') as f: 117 | yaml.dump(opt, f) 118 | -------------------------------------------------------------------------------- /promptable_segmentation/utils/dist.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import os 4 | import random 5 | import subprocess 6 | import time 7 | from collections import OrderedDict, defaultdict, deque 8 | import datetime 9 | import pickle 10 | from typing import Optional, List 11 | 12 | import json, time 13 | import numpy as np 14 | import torch 15 | import torch.distributed as dist 16 | from torch import Tensor 17 | 18 | import colorsys 19 | def init_distributed_mode(args): 20 | if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and 21 | args.rank = int(os.environ["RANK"]) 22 | args.world_size = int(os.environ['WORLD_SIZE']) 23 | args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) 24 | 25 | # launch by torch.distributed.launch 26 | # Single node 27 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ... 28 | # Multi nodes 29 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 30 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 31 | # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK')) 32 | # local_world_size = int(os.environ['GPU_PER_NODE_COUNT']) 33 | # args.world_size = args.world_size * local_world_size 34 | # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) 35 | # args.rank = args.rank * local_world_size + args.local_rank 36 | print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank)) 37 | print(json.dumps(dict(os.environ), indent=2)) 38 | elif 'SLURM_PROCID' in os.environ: 39 | args.rank = int(os.environ['SLURM_PROCID']) 40 | args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID']) 41 | args.world_size = int(os.environ['SLURM_NPROCS']) 42 | 43 | if os.environ.get('HAND_DEFINE_DIST_URL', 0) == '1': 44 | pass 45 | else: 46 | import util.hostlist as uh 47 | nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST']) 48 | gpu_ids = [int(node[3:]) for node in nodenames] 49 | fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0)) 50 | # fixid += random.randint(0, 300) 51 | port = str(3137 + int(min(gpu_ids)) + fixid) 52 | args.dist_url = "tcp://{ip}:{port}".format(ip=uh.nodename_to_ip(nodenames[0]), port=port) 53 | 54 | print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count())) 55 | 56 | 57 | else: 58 | print('Not using distributed mode') 59 | args.distributed = False 60 | args.world_size = 1 61 | args.rank = 0 62 | args.local_rank = 0 63 | return 64 | 65 | print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) 66 | args.distributed = True 67 | torch.cuda.set_device(args.local_rank) 68 | args.dist_backend = 'nccl' 69 | print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) 70 | 71 | torch.distributed.init_process_group( 72 | backend=args.dist_backend, 73 | world_size=args.world_size, 74 | rank=args.rank, 75 | init_method=args.dist_url, 76 | ) 77 | 78 | print("Before torch.distributed.barrier()") 79 | torch.distributed.barrier() 80 | print("End torch.distributed.barrier()") -------------------------------------------------------------------------------- /promptable_segmentation/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import time 3 | # import torch 4 | # import pickle 5 | # import subprocess 6 | 7 | # from mpi4py import MPI 8 | # import torch.distributed as dist 9 | 10 | 11 | # def apply_distributed(opt): 12 | # if opt['rank'] == 0: 13 | # hostname_cmd = ["hostname -I"] 14 | # result = subprocess.check_output(hostname_cmd, shell=True) 15 | # master_address = result.decode('utf-8').split()[0] 16 | # master_port = opt['PORT'] 17 | # else: 18 | # master_address = None 19 | # master_port = None 20 | 21 | # master_address = MPI.COMM_WORLD.bcast(master_address, root=0) 22 | # master_port = MPI.COMM_WORLD.bcast(master_port, root=0) 23 | 24 | # if torch.distributed.is_available() and opt['world_size'] > 1: 25 | # init_method_url = 'tcp://{}:{}'.format(master_address, master_port) 26 | # backend = 'nccl' 27 | # world_size = opt['world_size'] 28 | # rank = opt['rank'] 29 | # torch.distributed.init_process_group(backend=backend, 30 | # init_method=init_method_url, 31 | # world_size=world_size, 32 | # rank=rank) 33 | 34 | # def init_distributed(opt): 35 | # opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available() 36 | # if 'OMPI_COMM_WORLD_SIZE' not in os.environ: 37 | # # application was started without MPI 38 | # # default to single node with single process 39 | # opt['env_info'] = 'no MPI' 40 | # opt['world_size'] = 1 41 | # opt['local_size'] = 1 42 | # opt['rank'] = 0 43 | # opt['local_rank'] = 0 44 | # opt['master_address'] = '127.0.0.1' 45 | # opt['master_port'] = '8673' 46 | # else: 47 | # # application was started with MPI 48 | # # get MPI parameters 49 | # opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE']) 50 | # opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) 51 | # opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK']) 52 | # opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 53 | 54 | # # set up device 55 | # if not opt['CUDA']: 56 | # assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend' 57 | # opt['device'] = torch.device("cpu") 58 | # else: 59 | # torch.cuda.set_device(opt['local_rank']) 60 | # opt['device'] = torch.device("cuda", opt['local_rank']) 61 | 62 | # apply_distributed(opt) 63 | # return opt 64 | 65 | # def is_main_process(): 66 | # rank = 0 67 | # if 'OMPI_COMM_WORLD_SIZE' in os.environ: 68 | # rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 69 | 70 | # return rank == 0 71 | 72 | # def get_world_size(): 73 | # if not dist.is_available(): 74 | # return 1 75 | # if not dist.is_initialized(): 76 | # return 1 77 | # return dist.get_world_size() 78 | 79 | # def get_rank(): 80 | # if not dist.is_available(): 81 | # return 0 82 | # if not dist.is_initialized(): 83 | # return 0 84 | # return dist.get_rank() 85 | 86 | 87 | # def synchronize(): 88 | # """ 89 | # Helper function to synchronize (barrier) among all processes when 90 | # using distributed training 91 | # """ 92 | # if not dist.is_available(): 93 | # return 94 | # if not dist.is_initialized(): 95 | # return 96 | # world_size = dist.get_world_size() 97 | # rank = dist.get_rank() 98 | # if world_size == 1: 99 | # return 100 | 101 | # def _send_and_wait(r): 102 | # if rank == r: 103 | # tensor = torch.tensor(0, device="cuda") 104 | # else: 105 | # tensor = torch.tensor(1, device="cuda") 106 | # dist.broadcast(tensor, r) 107 | # while tensor.item() == 1: 108 | # time.sleep(1) 109 | 110 | # _send_and_wait(0) 111 | # # now sync on the main process 112 | # _send_and_wait(1) -------------------------------------------------------------------------------- /promptable_segmentation/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import math 8 | 9 | 10 | # HACK for evalution 11 | def hook_metadata(metadata, name): 12 | if name == 'cityscapes_fine_sem_seg_val': 13 | metadata.__setattr__("keep_sem_bgd", False) 14 | return metadata 15 | 16 | def hook_opt(model, name): 17 | if name in ['cityscapes_fine_panoptic_val', 'ade20k_panoptic_val', 'bdd10k_40_panoptic_val', 'cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val']: 18 | model.model.object_mask_threshold = 0.4 19 | else: 20 | model.model.object_mask_threshold = 0.8 21 | 22 | # HACK for evalution 23 | def hook_switcher(model, name): 24 | mappings = {} 25 | if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']: 26 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False} 27 | elif name in ['cityscapes_fine_instance_seg_val', 'pascal_part_val_interactive', 'pascal_part_val', 'pascal_part_train'] or 'seginw' in name: 28 | mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False} 29 | elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']: 30 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True} 31 | elif 'coco_2017_val_panoptic_with_sem_seg' in name or name in ['ade20k_panoptic_val', 'coco_2017_test-dev', 'sam_val', 'sam_minival']: 32 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True} 33 | else: 34 | if name not in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd"]: 35 | assert False, "dataset switcher is not defined" 36 | for key, value in mappings.items(): 37 | if key == 'SEMANTIC_ON': 38 | model.model.semantic_on = value 39 | if key == 'INSTANCE_ON': 40 | model.model.instance_on = value 41 | if key == 'PANOPTIC_ON': 42 | model.model.panoptic_on = value 43 | 44 | class AverageMeter(object): 45 | """Computes and stores the average and current value.""" 46 | def __init__(self): 47 | self.reset() 48 | 49 | def reset(self): 50 | self.val = 0 51 | self.avg = 0 52 | self.sum = 0 53 | self.count = 0 54 | 55 | def update(self, val, n=1, decay=0): 56 | self.val = val 57 | if decay: 58 | alpha = math.exp(-n / decay) # exponential decay over 100 updates 59 | self.sum = alpha * self.sum + (1 - alpha) * val * n 60 | self.count = alpha * self.count + (1 - alpha) * n 61 | else: 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | -------------------------------------------------------------------------------- /promptable_segmentation/utils/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import pickle 5 | import torch 6 | from detectron2.utils.comm import is_main_process 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | NORM_MODULES = [ 12 | torch.nn.BatchNorm1d, 13 | torch.nn.BatchNorm2d, 14 | torch.nn.BatchNorm3d, 15 | torch.nn.SyncBatchNorm, 16 | # NaiveSyncBatchNorm inherits from BatchNorm2d 17 | torch.nn.GroupNorm, 18 | torch.nn.InstanceNorm1d, 19 | torch.nn.InstanceNorm2d, 20 | torch.nn.InstanceNorm3d, 21 | torch.nn.LayerNorm, 22 | torch.nn.LocalResponseNorm, 23 | ] 24 | 25 | def register_norm_module(cls): 26 | NORM_MODULES.append(cls) 27 | return cls 28 | 29 | def align_and_update_state_dicts(model_state_dict, ckpt_state_dict): 30 | model_keys = sorted(model_state_dict.keys()) 31 | ckpt_keys = sorted(ckpt_state_dict.keys()) 32 | result_dicts = {} 33 | matched_log = [] 34 | unmatched_log = [] 35 | unloaded_log = [] 36 | for model_key in model_keys: 37 | model_weight = model_state_dict[model_key] 38 | if model_key in ckpt_keys: 39 | ckpt_weight = ckpt_state_dict[model_key] 40 | if model_weight.shape == ckpt_weight.shape: 41 | result_dicts[model_key] = ckpt_weight 42 | ckpt_keys.pop(ckpt_keys.index(model_key)) 43 | matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 44 | else: 45 | unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 46 | else: 47 | unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape)) 48 | 49 | if is_main_process(): 50 | for info in matched_log: 51 | logger.info(info) 52 | for info in unloaded_log: 53 | logger.warning(info) 54 | for key in ckpt_keys: 55 | logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape)) 56 | for info in unmatched_log: 57 | logger.warning(info) 58 | return result_dicts -------------------------------------------------------------------------------- /promptable_segmentation/utils/prompt_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_prompt_templates(): 5 | prompt_templates = [ 6 | '{}.', 7 | 'a photo of a {}.', 8 | 'a bad photo of a {}.', 9 | 'a photo of many {}.', 10 | 'a sculpture of a {}.', 11 | 'a photo of the hard to see {}.', 12 | 'a low resolution photo of the {}.', 13 | 'a rendering of a {}.', 14 | 'graffiti of a {}.', 15 | 'a bad photo of the {}.', 16 | 'a cropped photo of the {}.', 17 | 'a tattoo of a {}.', 18 | 'the embroidered {}.', 19 | 'a photo of a hard to see {}.', 20 | 'a bright photo of a {}.', 21 | 'a photo of a clean {}.', 22 | 'a photo of a dirty {}.', 23 | 'a dark photo of the {}.', 24 | 'a drawing of a {}.', 25 | 'a photo of my {}.', 26 | 'the plastic {}.', 27 | 'a photo of the cool {}.', 28 | 'a close-up photo of a {}.', 29 | 'a black and white photo of the {}.', 30 | 'a painting of the {}.', 31 | 'a painting of a {}.', 32 | 'a pixelated photo of the {}.', 33 | 'a sculpture of the {}.', 34 | 'a bright photo of the {}.', 35 | 'a cropped photo of a {}.', 36 | 'a plastic {}.', 37 | 'a photo of the dirty {}.', 38 | 'a jpeg corrupted photo of a {}.', 39 | 'a blurry photo of the {}.', 40 | 'a photo of the {}.', 41 | 'a good photo of the {}.', 42 | 'a rendering of the {}.', 43 | 'a {} in a video game.', 44 | 'a photo of one {}.', 45 | 'a doodle of a {}.', 46 | 'a close-up photo of the {}.', 47 | 'the origami {}.', 48 | 'the {} in a video game.', 49 | 'a sketch of a {}.', 50 | 'a doodle of the {}.', 51 | 'a origami {}.', 52 | 'a low resolution photo of a {}.', 53 | 'the toy {}.', 54 | 'a rendition of the {}.', 55 | 'a photo of the clean {}.', 56 | 'a photo of a large {}.', 57 | 'a rendition of a {}.', 58 | 'a photo of a nice {}.', 59 | 'a photo of a weird {}.', 60 | 'a blurry photo of a {}.', 61 | 'a cartoon {}.', 62 | 'art of a {}.', 63 | 'a sketch of the {}.', 64 | 'a embroidered {}.', 65 | 'a pixelated photo of a {}.', 66 | 'itap of the {}.', 67 | 'a jpeg corrupted photo of the {}.', 68 | 'a good photo of a {}.', 69 | 'a plushie {}.', 70 | 'a photo of the nice {}.', 71 | 'a photo of the small {}.', 72 | 'a photo of the weird {}.', 73 | 'the cartoon {}.', 74 | 'art of the {}.', 75 | 'a drawing of the {}.', 76 | 'a photo of the large {}.', 77 | 'a black and white photo of a {}.', 78 | 'the plushie {}.', 79 | 'a dark photo of a {}.', 80 | 'itap of a {}.', 81 | 'graffiti of the {}.', 82 | 'a toy {}.', 83 | 'itap of my {}.', 84 | 'a photo of a cool {}.', 85 | 'a photo of a small {}.', 86 | 'a tattoo of the {}.', 87 | ] 88 | return prompt_templates 89 | 90 | def prompt_engineering(classnames, topk=1, suffix='.'): 91 | prompt_templates = get_prompt_templates() 92 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 93 | 94 | if isinstance(classnames, list): 95 | classname = random.choice(classnames) 96 | else: 97 | classname = classnames 98 | 99 | return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' ')) -------------------------------------------------------------------------------- /promptable_segmentation/utils/sam_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /promptable_segmentation/utils/sam_utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /promptable_segmentation/utils/sam_utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | segmentation-refinement 3 | colored 4 | progressbar 5 | h5py 6 | submitit 7 | timm 8 | pillow==9.5.0 9 | -------------------------------------------------------------------------------- /tools/promptable_eval.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=/PATH/TO/YOUR/DATASETS 2 | CUDA_VISIBLE_DEVICES=4, python promptable_segmentation/train_net.py \ 3 | --eval_only \ 4 | --num-gpus 1 \ 5 | --config-file promptable_segmentation/configs/semantic_sam_only_sa-1b_swinT.yaml \ 6 | COCO.TEST.BATCH_SIZE_TOTAL=1 \ 7 | MODEL.WEIGHTS=/home/xudongw/UnSAM-Semantic/data/output/gpu4-bs4-lr1e-4-iter100k-12jsons-SSL-AllAnnos-NMask120-Thresh0.02/10k_0099999.pth \ 8 | OUTPUT_DIR=eval_output \ -------------------------------------------------------------------------------- /tools/whole_image_eval.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=/PATH/TO/YOUR/DATASETS 2 | export TRAIN_DATASETS=/PATH/TO/YOUR/SA-1B 3 | CUDA_VISIBLE_DEVICES=5,6,7,8, python whole_image_segmentation/train_net.py \ 4 | --num-gpus 4 \ 5 | --config-file whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml \ 6 | --eval-only \ 7 | MODEL.WEIGHTS /home/xudongw/mask2former/output/bs16_lr5e-5_rn50_41json_500masks_2000q_DINO/model_0199999.pth \ 8 | SOLVER.IMS_PER_BATCH 4 \ 9 | DATALOADER.NUM_WORKERS 1 \ 10 | OUTPUT_DIR eval_output \ -------------------------------------------------------------------------------- /whole_image_segmentation/configs/Base-COCO-InstanceSegmentation.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl #"detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("unsam_sa1b_train",) 18 | TEST: ("unsam_sa1b_val",) 19 | SELF_TRAIN: false 20 | SELF_TRAIN_NUMBER: 1 21 | SELF_TRAIN_THRESH_LOW: 0.0 22 | SELF_TRAIN_THRESH_HIGH: 1.0 23 | SOLVER: 24 | DROPLOSS_IOU_THRESH: 0.01 25 | USE_DROP_LOSSES: false 26 | IMS_PER_BATCH: 4 27 | BASE_LR: 1e-6 #5e-5 28 | STEPS: (177776, 192592) #(327778, 355092) 29 | MAX_ITER: 200000 #368750 30 | WARMUP_FACTOR: 1.0 31 | WARMUP_ITERS: 10 32 | WEIGHT_DECAY: 0.05 33 | OPTIMIZER: "ADAMW" 34 | BACKBONE_MULTIPLIER: 0.1 35 | CLIP_GRADIENTS: 36 | ENABLED: True 37 | CLIP_TYPE: "full_model" 38 | CLIP_VALUE: 0.01 39 | NORM_TYPE: 2.0 40 | AMP: 41 | ENABLED: True 42 | INPUT: 43 | IMAGE_SIZE: 1024 44 | MIN_SCALE: 1.0 #0.1 45 | MAX_SCALE: 1.0 #2.0 46 | MIN_SIZE_TEST: 1024 47 | MAX_SIZE_TEST: 2048 48 | FORMAT: "RGB" 49 | DATASET_MAPPER_NAME: "sam_instance_tsv" 50 | TEST: 51 | EVAL_PERIOD: 0 #5000 52 | DETECTIONS_PER_IMAGE: 1000 53 | DATALOADER: 54 | FILTER_EMPTY_ANNOTATIONS: True 55 | NUM_WORKERS: 1 56 | VERSION: 2 57 | -------------------------------------------------------------------------------- /whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-COCO-InstanceSegmentation.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "MaskFormerHead" 6 | IGNORE_VALUE: 255 7 | NUM_CLASSES: 1 #80 8 | LOSS_WEIGHT: 1.0 9 | CONVS_DIM: 256 10 | MASK_DIM: 256 11 | NORM: "GN" 12 | # pixel decoder 13 | PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder" 14 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 15 | DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"] 16 | COMMON_STRIDE: 4 17 | TRANSFORMER_ENC_LAYERS: 6 18 | MASK_FORMER: 19 | TRANSFORMER_DECODER_NAME: "MultiScaleMaskedTransformerDecoder" 20 | TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" 21 | DEEP_SUPERVISION: True 22 | NO_OBJECT_WEIGHT: 0.1 23 | CLASS_WEIGHT: 2.0 24 | MASK_WEIGHT: 5.0 25 | DICE_WEIGHT: 5.0 26 | HIDDEN_DIM: 256 27 | NUM_OBJECT_QUERIES: 2000 #100 28 | NHEADS: 8 29 | DROPOUT: 0.0 30 | DIM_FEEDFORWARD: 2048 31 | ENC_LAYERS: 0 32 | PRE_NORM: False 33 | ENFORCE_INPUT_PROJ: False 34 | SIZE_DIVISIBILITY: 32 35 | DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query 36 | TRAIN_NUM_POINTS: 12544 37 | OVERSAMPLE_RATIO: 3.0 38 | IMPORTANCE_SAMPLE_RATIO: 0.75 39 | TEST: 40 | SEMANTIC_ON: False 41 | INSTANCE_ON: True 42 | PANOPTIC_ON: False 43 | OVERLAP_THRESHOLD: 0.8 44 | OBJECT_MASK_THRESHOLD: 0.8 45 | -------------------------------------------------------------------------------- /whole_image_segmentation/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .build import ( 3 | build_batch_data_loader, 4 | build_detection_test_loader, 5 | build_detection_train_loader, 6 | load_proposals_into_dataset, 7 | print_instances_class_histogram, 8 | ) 9 | from detectron2.data.catalog import DatasetCatalog, MetadataCatalog, Metadata 10 | from detectron2.data.common import DatasetFromList, MapDataset, ToIterableDataset 11 | from detectron2.data.dataset_mapper import DatasetMapper 12 | 13 | # ensure the builtin datasets are registered 14 | from . import datasets 15 | from detectron2.data import samplers # isort:skip 16 | 17 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 18 | 19 | -------------------------------------------------------------------------------- /whole_image_segmentation/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . import builtin as _builtin # ensure the builtin datasets are registered 3 | 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /whole_image_segmentation/data/datasets/builtin_meta.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | """ 5 | Note: 6 | For your custom dataset, there is no need to hard-code metadata anywhere in the code. 7 | For example, for COCO-format dataset, metadata will be obtained automatically 8 | when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways 9 | during loading. 10 | 11 | However, we hard-coded metadata for a few common dataset here. 12 | The only goal is to allow users who don't have these dataset to use pre-trained models. 13 | Users don't have to download a COCO json (which contains metadata), in order to visualize a 14 | COCO model (with correct class names and colors). 15 | """ 16 | SA1B_CATEGORIES = [ 17 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "fg"}, 18 | ] 19 | 20 | def _get_sa1b_instances_meta(): 21 | thing_ids = [k["id"] for k in SA1B_CATEGORIES if k["isthing"] == 1] 22 | thing_colors = [k["color"] for k in SA1B_CATEGORIES if k["isthing"] == 1] 23 | assert len(thing_ids) == 1, len(thing_ids) 24 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 25 | thing_classes = [k["name"] for k in SA1B_CATEGORIES if k["isthing"] == 1] 26 | ret = { 27 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 28 | "thing_classes": thing_classes, 29 | "thing_colors": thing_colors, 30 | "class_image_count": [{'id': 1, 'image_count': 116986}] 31 | } 32 | return ret 33 | 34 | 35 | def _get_builtin_metadata(dataset_name): 36 | return _get_sa1b_instances_meta() 37 | -------------------------------------------------------------------------------- /whole_image_segmentation/demo.sh: -------------------------------------------------------------------------------- 1 | python demo_whole_image.py \ 2 | --input whole_image_segmentation/examples/sa_628955.jpg \ 3 | --output demo.jpg \ 4 | --opts \ 5 | MODEL.WEIGHTS /home/xudongw/mask2former/output/bs16_lr5e-5_rn50_41json_500masks_2000q_DINO/model_0199999.pth \ 6 | MODEL.DEVICE cpu -------------------------------------------------------------------------------- /whole_image_segmentation/demo_whole_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import argparse 3 | from detectron2.engine import DefaultPredictor, default_setup 4 | from detectron2.projects.deeplab import add_deeplab_config 5 | from detectron2.config import get_cfg 6 | from detectron2.utils.colormap import random_color 7 | from mask2former import add_maskformer2_config 8 | import cv2 9 | import os 10 | from tqdm import tqdm 11 | import PIL.Image as Image 12 | import numpy as np 13 | 14 | def setup(args): 15 | """ 16 | Create configs and perform basic setups. 17 | """ 18 | cfg = get_cfg() 19 | cfg.set_new_allowed(True) 20 | # for poly lr schedule 21 | add_deeplab_config(cfg) 22 | add_maskformer2_config(cfg) 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | cfg.freeze() 26 | default_setup(cfg, args) 27 | return cfg 28 | 29 | def area(mask): 30 | if mask.size == 0: return 0 31 | return np.count_nonzero(mask) / mask.size 32 | 33 | def vis_mask(input, mask, mask_color) : 34 | fg = mask > 0.5 35 | rgb = np.copy(input) 36 | rgb[fg] = (rgb[fg] * 0.5 + np.array(mask_color) * 0.5).astype(np.uint8) 37 | return Image.fromarray(rgb) 38 | 39 | def save_image(I, pool, output_path): 40 | # the visualization strategy is small masks on top of large masks 41 | already_painted = np.zeros(np.array(I).shape[:2]) 42 | input = I.copy() 43 | i = 0 44 | for mask in tqdm(pool): 45 | already_painted += mask.astype(np.uint8) 46 | overlap = (already_painted == 2) 47 | if np.sum(overlap) != 0: 48 | input = Image.fromarray(overlap[:, :, np.newaxis] * np.copy(I) + np.logical_not(overlap)[:, :, np.newaxis] * np.copy(input)) 49 | already_painted -= overlap 50 | input = vis_mask(input, mask, random_color(rgb=True)) 51 | input.save(output_path) 52 | 53 | def main(): 54 | parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") 55 | parser.add_argument( 56 | "--config-file", 57 | default="configs/maskformer2_R50_bs16_50ep.yaml", 58 | metavar="FILE", 59 | help="path to config file", 60 | ) 61 | parser.add_argument("--input", type=str, help="path of input image") 62 | parser.add_argument("--output", type=str, help="path to save output image") 63 | parser.add_argument("--confidence_thresh", type=float, default=0.5, help="path to save output image") 64 | parser.add_argument( 65 | "--opts", 66 | help="Modify config options using the command-line 'KEY VALUE' pairs", 67 | default=[], 68 | nargs=argparse.REMAINDER, 69 | ) 70 | args = parser.parse_args() 71 | 72 | pred = DefaultPredictor(setup(args)) 73 | inputs = cv2.imread(args.input) 74 | pred.input_format = "BGR" 75 | 76 | outputs = pred(inputs)['instances'] 77 | masks = [] 78 | for score, mask in zip(outputs.scores, outputs.pred_masks): 79 | if score < args.confidence_thresh: continue 80 | masks.append(mask.cpu().numpy()) 81 | sorted_masks = sorted(masks, key=lambda m: area(m), reverse=True) 82 | print(f"You have {len(sorted_masks)} masks for this image") 83 | 84 | save_image(inputs, sorted_masks, args.output) 85 | 86 | if __name__ == "__main__": 87 | main() -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . import data # register all new datasets 3 | from . import modeling 4 | 5 | # config 6 | from .config import add_maskformer2_config 7 | 8 | # dataset loading 9 | from .data.dataset_mappers.sam_instance_tsv_dataset_mapper import SamInstanceTSVDatasetMapper 10 | from .data.dataset_mappers.sam_instance_tsv_self_train_dataset_mapper import SamSelfTrainTSVDatasetMapper 11 | 12 | # models 13 | from .maskformer_model import MaskFormer 14 | from .test_time_augmentation import SemanticSegmentorWithTTA 15 | 16 | # evaluation 17 | from .evaluation.coco_evaluation import COCOEvaluator 18 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | from detectron2.config import CfgNode as CN 4 | 5 | 6 | def add_maskformer2_config(cfg): 7 | """ 8 | Add config for MASK_FORMER. 9 | """ 10 | # NOTE: configs from original maskformer 11 | # data config 12 | # select the dataset mapper 13 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" 14 | # Color augmentation 15 | cfg.INPUT.COLOR_AUG_SSD = False 16 | # We retry random cropping until no single category in semantic segmentation GT occupies more 17 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 18 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 19 | # Pad image and segmentation GT in dataset mapper. 20 | cfg.INPUT.SIZE_DIVISIBILITY = -1 21 | 22 | # solver config 23 | # weight decay on embedding 24 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 25 | # optimizer 26 | cfg.SOLVER.OPTIMIZER = "ADAMW" 27 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 28 | 29 | # mask_former model config 30 | cfg.MODEL.MASK_FORMER = CN() 31 | 32 | # loss 33 | cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True 34 | cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 35 | cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = 1.0 36 | cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 37 | cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 38 | 39 | # transformer config 40 | cfg.MODEL.MASK_FORMER.NHEADS = 8 41 | cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 42 | cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 43 | cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 44 | cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 45 | cfg.MODEL.MASK_FORMER.PRE_NORM = False 46 | 47 | cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 48 | cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 49 | 50 | cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" 51 | cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False 52 | 53 | # mask_former inference config 54 | cfg.MODEL.MASK_FORMER.TEST = CN() 55 | cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True 56 | cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False 57 | cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False 58 | cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 59 | cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 60 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False 61 | 62 | # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet) 63 | # you can use this config to override 64 | cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 65 | 66 | # pixel decoder config 67 | cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 68 | # adding transformer in pixel decoder 69 | cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 70 | # pixel decoder 71 | cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" 72 | 73 | # swin transformer backbone 74 | cfg.MODEL.SWIN = CN() 75 | cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 76 | cfg.MODEL.SWIN.PATCH_SIZE = 4 77 | cfg.MODEL.SWIN.EMBED_DIM = 96 78 | cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 79 | cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 80 | cfg.MODEL.SWIN.WINDOW_SIZE = 7 81 | cfg.MODEL.SWIN.MLP_RATIO = 4.0 82 | cfg.MODEL.SWIN.QKV_BIAS = True 83 | cfg.MODEL.SWIN.QK_SCALE = None 84 | cfg.MODEL.SWIN.DROP_RATE = 0.0 85 | cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 86 | cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 87 | cfg.MODEL.SWIN.APE = False 88 | cfg.MODEL.SWIN.PATCH_NORM = True 89 | cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] 90 | cfg.MODEL.SWIN.USE_CHECKPOINT = False 91 | 92 | # NOTE: maskformer2 extra configs 93 | # transformer module 94 | cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder" 95 | 96 | # LSJ aug 97 | cfg.INPUT.IMAGE_SIZE = 1024 98 | cfg.INPUT.MIN_SCALE = 0.1 99 | cfg.INPUT.MAX_SCALE = 2.0 100 | 101 | # MSDeformAttn encoder configs 102 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"] 103 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4 104 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8 105 | 106 | # point loss configs 107 | # Number of points sampled during training for a mask point head. 108 | cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = 112 * 112 109 | # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the 110 | # original paper. 111 | cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0 112 | # Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in 113 | # the original paper. 114 | cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75 115 | 116 | # additional configs for DropLoss 117 | cfg.SOLVER.USE_DROP_LOSSES = False 118 | cfg.SOLVER.DROPLOSS_IOU_THRESH = 1.0 119 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/data/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-xwang/UnSAM/463a93c90a4841dcc510eb198defa631ba428637/whole_image_segmentation/mask2former/evaluation/__init__.py -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .backbone.swin import D2SwinTransformer 3 | from .pixel_decoder.fpn import BasePixelDecoder 4 | from .pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder 5 | from .meta_arch.mask_former_head import MaskFormerHead 6 | from .meta_arch.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead 7 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/meta_arch/mask_former_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | from copy import deepcopy 4 | from typing import Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import fvcore.nn.weight_init as weight_init 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from detectron2.config import configurable 11 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 12 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 13 | 14 | from ..transformer_decoder.maskformer_transformer_decoder import build_transformer_decoder 15 | from ..pixel_decoder.fpn import build_pixel_decoder 16 | 17 | 18 | @SEM_SEG_HEADS_REGISTRY.register() 19 | class MaskFormerHead(nn.Module): 20 | 21 | _version = 2 22 | 23 | def _load_from_state_dict( 24 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 25 | ): 26 | version = local_metadata.get("version", None) 27 | if version is None or version < 2: 28 | # Do not warn if train from scratch 29 | scratch = True 30 | logger = logging.getLogger(__name__) 31 | for k in list(state_dict.keys()): 32 | newk = k 33 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 34 | newk = k.replace(prefix, prefix + "pixel_decoder.") 35 | # logger.debug(f"{k} ==> {newk}") 36 | if newk != k: 37 | state_dict[newk] = state_dict[k] 38 | del state_dict[k] 39 | scratch = False 40 | 41 | if not scratch: 42 | logger.warning( 43 | f"Weight format of {self.__class__.__name__} have changed! " 44 | "Please upgrade your models. Applying automatic conversion now ..." 45 | ) 46 | 47 | @configurable 48 | def __init__( 49 | self, 50 | input_shape: Dict[str, ShapeSpec], 51 | *, 52 | num_classes: int, 53 | pixel_decoder: nn.Module, 54 | loss_weight: float = 1.0, 55 | ignore_value: int = -1, 56 | # extra parameters 57 | transformer_predictor: nn.Module, 58 | transformer_in_feature: str, 59 | ): 60 | """ 61 | NOTE: this interface is experimental. 62 | Args: 63 | input_shape: shapes (channels and stride) of the input features 64 | num_classes: number of classes to predict 65 | pixel_decoder: the pixel decoder module 66 | loss_weight: loss weight 67 | ignore_value: category id to be ignored during training. 68 | transformer_predictor: the transformer decoder that makes prediction 69 | transformer_in_feature: input feature name to the transformer_predictor 70 | """ 71 | super().__init__() 72 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 73 | self.in_features = [k for k, v in input_shape] 74 | feature_strides = [v.stride for k, v in input_shape] 75 | feature_channels = [v.channels for k, v in input_shape] 76 | 77 | self.ignore_value = ignore_value 78 | self.common_stride = 4 79 | self.loss_weight = loss_weight 80 | 81 | self.pixel_decoder = pixel_decoder 82 | self.predictor = transformer_predictor 83 | self.transformer_in_feature = transformer_in_feature 84 | 85 | self.num_classes = num_classes 86 | 87 | @classmethod 88 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 89 | # figure out in_channels to transformer predictor 90 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder": 91 | transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 92 | elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "pixel_embedding": 93 | transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM 94 | elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "multi_scale_pixel_decoder": # for maskformer2 95 | transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 96 | else: 97 | transformer_predictor_in_channels = input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels 98 | 99 | return { 100 | "input_shape": { 101 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 102 | }, 103 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 104 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 105 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 106 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 107 | "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, 108 | "transformer_predictor": build_transformer_decoder( 109 | cfg, 110 | transformer_predictor_in_channels, 111 | mask_classification=True, 112 | ), 113 | } 114 | 115 | def forward(self, features, mask=None): 116 | return self.layers(features, mask) 117 | 118 | def layers(self, features, mask=None): 119 | mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features) 120 | if self.transformer_in_feature == "multi_scale_pixel_decoder": 121 | predictions = self.predictor(multi_scale_features, mask_features, mask) 122 | else: 123 | if self.transformer_in_feature == "transformer_encoder": 124 | assert ( 125 | transformer_encoder_features is not None 126 | ), "Please use the TransformerEncoderPixelDecoder." 127 | predictions = self.predictor(transformer_encoder_features, mask_features, mask) 128 | elif self.transformer_in_feature == "pixel_embedding": 129 | predictions = self.predictor(mask_features, mask_features, mask) 130 | else: 131 | predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask) 132 | return predictions 133 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd mask2former/modeling/pixel_decoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | if torch.cuda.is_available(): 30 | raise ModuleNotFoundError(info_string) 31 | else: 32 | print("if you are running on GPU", info_string) 33 | 34 | 35 | class MSDeformAttnFunction(Function): 36 | @staticmethod 37 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 38 | ctx.im2col_step = im2col_step 39 | output = MSDA.ms_deform_attn_forward( 40 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 41 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 42 | return output 43 | 44 | @staticmethod 45 | @once_differentiable 46 | def backward(ctx, grad_output): 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 48 | grad_value, grad_sampling_loc, grad_attn_weight = \ 49 | MSDA.ms_deform_attn_backward( 50 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 51 | 52 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 53 | 54 | 55 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 56 | # for debug and test only, 57 | # need to use cuda version instead 58 | N_, S_, M_, D_ = value.shape 59 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 60 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 61 | sampling_grids = 2 * sampling_locations - 1 62 | sampling_value_list = [] 63 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 64 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 65 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 66 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 67 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 68 | # N_*M_, D_, Lq_, P_ 69 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 70 | mode='bilinear', padding_mode='zeros', align_corners=False) 71 | sampling_value_list.append(sampling_value_l_) 72 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 73 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 74 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 75 | return output.transpose(1, 2).contiguous() 76 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install 14 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 | print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/transformer_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .maskformer_transformer_decoder import StandardTransformerDecoder 3 | from .mask2former_transformer_decoder import MultiScaleMaskedTransformerDecoder 4 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/modeling/transformer_decoder/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | def __repr__(self, _repr_indent=4): 55 | head = "Positional encoding " + self.__class__.__name__ 56 | body = [ 57 | "num_pos_feats: {}".format(self.num_pos_feats), 58 | "temperature: {}".format(self.temperature), 59 | "normalize: {}".format(self.normalize), 60 | "scale: {}".format(self.scale), 61 | ] 62 | # _repr_indent = 4 63 | lines = [head] + [" " * _repr_indent + line for line in body] 64 | return "\n".join(lines) 65 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/test_time_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | from itertools import count 5 | 6 | import numpy as np 7 | import torch 8 | from fvcore.transforms import HFlipTransform 9 | from torch import nn 10 | from torch.nn.parallel import DistributedDataParallel 11 | 12 | from detectron2.data.detection_utils import read_image 13 | from detectron2.modeling import DatasetMapperTTA 14 | 15 | 16 | __all__ = [ 17 | "SemanticSegmentorWithTTA", 18 | ] 19 | 20 | 21 | class SemanticSegmentorWithTTA(nn.Module): 22 | """ 23 | A SemanticSegmentor with test-time augmentation enabled. 24 | Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. 25 | """ 26 | 27 | def __init__(self, cfg, model, tta_mapper=None, batch_size=1): 28 | """ 29 | Args: 30 | cfg (CfgNode): 31 | model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. 32 | tta_mapper (callable): takes a dataset dict and returns a list of 33 | augmented versions of the dataset dict. Defaults to 34 | `DatasetMapperTTA(cfg)`. 35 | batch_size (int): batch the augmented images into this batch size for inference. 36 | """ 37 | super().__init__() 38 | if isinstance(model, DistributedDataParallel): 39 | model = model.module 40 | self.cfg = cfg.clone() 41 | 42 | self.model = model 43 | 44 | if tta_mapper is None: 45 | tta_mapper = DatasetMapperTTA(cfg) 46 | self.tta_mapper = tta_mapper 47 | self.batch_size = batch_size 48 | 49 | def __call__(self, batched_inputs): 50 | """ 51 | Same input/output format as :meth:`SemanticSegmentor.forward` 52 | """ 53 | 54 | def _maybe_read_image(dataset_dict): 55 | ret = copy.copy(dataset_dict) 56 | if "image" not in ret: 57 | image = read_image(ret.pop("file_name"), self.model.input_format) 58 | image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW 59 | ret["image"] = image 60 | if "height" not in ret and "width" not in ret: 61 | ret["height"] = image.shape[1] 62 | ret["width"] = image.shape[2] 63 | return ret 64 | 65 | processed_results = [] 66 | for x in batched_inputs: 67 | result = self._inference_one_image(_maybe_read_image(x)) 68 | processed_results.append(result) 69 | return processed_results 70 | 71 | def _inference_one_image(self, input): 72 | """ 73 | Args: 74 | input (dict): one dataset dict with "image" field being a CHW tensor 75 | Returns: 76 | dict: one output dict 77 | """ 78 | orig_shape = (input["height"], input["width"]) 79 | augmented_inputs, tfms = self._get_augmented_inputs(input) 80 | 81 | final_predictions = None 82 | count_predictions = 0 83 | for input, tfm in zip(augmented_inputs, tfms): 84 | count_predictions += 1 85 | with torch.no_grad(): 86 | if final_predictions is None: 87 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 88 | final_predictions = self.model([input])[0].pop("sem_seg").flip(dims=[2]) 89 | else: 90 | final_predictions = self.model([input])[0].pop("sem_seg") 91 | else: 92 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 93 | final_predictions += self.model([input])[0].pop("sem_seg").flip(dims=[2]) 94 | else: 95 | final_predictions += self.model([input])[0].pop("sem_seg") 96 | 97 | final_predictions = final_predictions / count_predictions 98 | return {"sem_seg": final_predictions} 99 | 100 | def _get_augmented_inputs(self, input): 101 | augmented_inputs = self.tta_mapper(input) 102 | tfms = [x.pop("transforms") for x in augmented_inputs] 103 | return augmented_inputs, tfms 104 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | from typing import List, Optional 9 | 10 | import torch 11 | import torch.distributed as dist 12 | import torchvision 13 | from torch import Tensor 14 | 15 | 16 | def _max_by_axis(the_list): 17 | # type: (List[List[int]]) -> List[int] 18 | maxes = the_list[0] 19 | for sublist in the_list[1:]: 20 | for index, item in enumerate(sublist): 21 | maxes[index] = max(maxes[index], item) 22 | return maxes 23 | 24 | 25 | class NestedTensor(object): 26 | def __init__(self, tensors, mask: Optional[Tensor]): 27 | self.tensors = tensors 28 | self.mask = mask 29 | 30 | def to(self, device): 31 | # type: (Device) -> NestedTensor # noqa 32 | cast_tensor = self.tensors.to(device) 33 | mask = self.mask 34 | if mask is not None: 35 | assert mask is not None 36 | cast_mask = mask.to(device) 37 | else: 38 | cast_mask = None 39 | return NestedTensor(cast_tensor, cast_mask) 40 | 41 | def decompose(self): 42 | return self.tensors, self.mask 43 | 44 | def __repr__(self): 45 | return str(self.tensors) 46 | 47 | 48 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 49 | # TODO make this more general 50 | if tensor_list[0].ndim == 3: 51 | if torchvision._is_tracing(): 52 | # nested_tensor_from_tensor_list() does not export well to ONNX 53 | # call _onnx_nested_tensor_from_tensor_list() instead 54 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 55 | 56 | # TODO make it support different-sized images 57 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 58 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 59 | batch_shape = [len(tensor_list)] + max_size 60 | b, c, h, w = batch_shape 61 | dtype = tensor_list[0].dtype 62 | device = tensor_list[0].device 63 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 64 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 65 | for img, pad_img, m in zip(tensor_list, tensor, mask): 66 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 67 | m[: img.shape[1], : img.shape[2]] = False 68 | else: 69 | raise ValueError("not supported") 70 | return NestedTensor(tensor, mask) 71 | 72 | 73 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 74 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 75 | @torch.jit.unused 76 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 77 | max_size = [] 78 | for i in range(tensor_list[0].dim()): 79 | max_size_i = torch.max( 80 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 81 | ).to(torch.int64) 82 | max_size.append(max_size_i) 83 | max_size = tuple(max_size) 84 | 85 | # work around for 86 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 87 | # m[: img.shape[1], :img.shape[2]] = False 88 | # which is not yet supported in onnx 89 | padded_imgs = [] 90 | padded_masks = [] 91 | for img in tensor_list: 92 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 93 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 94 | padded_imgs.append(padded_img) 95 | 96 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 97 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 98 | padded_masks.append(padded_mask.to(torch.bool)) 99 | 100 | tensor = torch.stack(padded_imgs) 101 | mask = torch.stack(padded_masks) 102 | 103 | return NestedTensor(tensor, mask=mask) 104 | 105 | 106 | def is_dist_avail_and_initialized(): 107 | if not dist.is_available(): 108 | return False 109 | if not dist.is_initialized(): 110 | return False 111 | return True 112 | -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/utils/tsv/__init__.py: -------------------------------------------------------------------------------- 1 | from .io_common import FileProgressingbar, img_from_base64, generate_lineidx 2 | from .tsv_io import TSVFile 3 | 4 | __all__ = [ 5 | 'FileProgressingbar', 'img_from_base64', 'generate_lineidx', 'TSVFile' 6 | ] -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/utils/tsv/io_common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-13 14:35:27 4 | # @Last Modified by: Yihao Chen 5 | # @Last Modified time: 2022-04-24 11:38:58 6 | 7 | import os 8 | import base64 9 | from io import BytesIO 10 | from PIL import Image 11 | 12 | import cv2 13 | import yaml 14 | import progressbar 15 | import numpy as np 16 | import torchvision.transforms as T 17 | 18 | class FileProgressingbar: 19 | fileobj = None 20 | pbar = None 21 | def __init__(self, fileobj, msg): 22 | fileobj.seek(0, os.SEEK_END) 23 | flen = fileobj.tell() 24 | fileobj.seek(0, os.SEEK_SET) 25 | self.fileobj = fileobj 26 | widgets = [msg, progressbar.AnimatedMarker(), ' ', progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()] 27 | self.pbar = progressbar.ProgressBar(widgets=widgets, maxval=flen).start() 28 | 29 | def update(self): 30 | self.pbar.update(self.fileobj.tell()) 31 | 32 | 33 | def img_from_base64(imagestring): 34 | jpgbytestring = base64.b64decode(imagestring) 35 | image = BytesIO(jpgbytestring) 36 | image = Image.open(image).convert("RGB") 37 | return image 38 | 39 | # jpgbytestring = base64.b64decode(imagestring) 40 | # nparr = np.frombuffer(jpgbytestring, np.uint8) 41 | # try: 42 | # r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 43 | # # r = cv2.cvtColor(r, cv2.COLOR_BGR2RGB) 44 | # return r 45 | # except: 46 | # return None 47 | 48 | 49 | def generate_lineidx(filein, idxout): 50 | assert not os.path.isfile(idxout) 51 | with open(filein, 'r') as tsvin, open(idxout, 'w') as tsvout: 52 | bar = FileProgressingbar(tsvin, 'Generating lineidx {0}: '.format(idxout)) 53 | fsize = os.fstat(tsvin.fileno()).st_size 54 | fpos = 0 55 | while fpos != fsize: 56 | tsvout.write(str(fpos)+"\n") 57 | tsvin.readline() 58 | fpos = tsvin.tell() 59 | bar.update() -------------------------------------------------------------------------------- /whole_image_segmentation/mask2former/utils/tsv/tsv_io.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import os.path as op 4 | from mask2former.utils.tsv.io_common import generate_lineidx, FileProgressingbar 5 | 6 | 7 | class TSVFile(object): 8 | def __init__(self, tsv_file, silence=True): 9 | self.tsv_file = tsv_file 10 | self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' 11 | 12 | self.label_file = op.splitext(tsv_file)[0] + '.label' 13 | self.label_lineidx = op.splitext(tsv_file)[0] + '.label.lineidx' 14 | 15 | if os.path.exists(self.label_file): 16 | self.split_label = True 17 | else: 18 | self.split_label = False 19 | 20 | self._fp = None 21 | self._lineidx = None 22 | 23 | self._label_fp = None 24 | self._label_lineidx = None 25 | 26 | self.pid = None 27 | self.silence = silence 28 | self._ensure_lineidx_loaded() 29 | 30 | def num_rows(self): 31 | return len(self._lineidx) 32 | 33 | def seek(self, idx): 34 | self._ensure_tsv_opened() 35 | pos = self._lineidx[idx] 36 | self._fp.seek(pos) 37 | tsv_info = [s.strip() for s in self._fp.readline().split('\t')] 38 | 39 | if self.split_label: 40 | label_pos = self._label_lineidx[idx] 41 | self._label_fp.seek(label_pos) 42 | label_info = [s.strip() for s in self._label_fp.readline().split('\t')] 43 | 44 | assert tsv_info[0] == label_info[0] 45 | tsv_info = [tsv_info[0], label_info[-1], tsv_info[-1]] 46 | 47 | return tsv_info 48 | 49 | def close(self): 50 | if self._fp is not None: 51 | self._fp.close() 52 | del self._fp 53 | del self._lineidx 54 | 55 | self._fp = None 56 | self._lineidx = None 57 | 58 | def _ensure_lineidx_loaded(self): 59 | if not op.isfile(self.lineidx) and not op.islink(self.lineidx): 60 | generate_lineidx(self.tsv_file, self.lineidx) 61 | 62 | if self._lineidx is None: 63 | with open(self.lineidx, 'r') as fp: 64 | lines = fp.readlines() 65 | self._lineidx = [int(i.strip().split()[0]) for i in lines] 66 | 67 | if self.split_label: 68 | with open(self.label_lineidx, 'r') as fp: 69 | lines = fp.readlines() 70 | self._label_lineidx = [int(i.strip().split()[0]) for i in lines] 71 | 72 | 73 | def _ensure_tsv_opened(self): 74 | self._ensure_lineidx_loaded() 75 | if self._fp is None: 76 | self._fp = open(self.tsv_file, 'r') 77 | self.pid = os.getpid() 78 | 79 | if self.split_label: 80 | self._label_fp = open(self.label_file, 'r') 81 | 82 | if self.pid != os.getpid(): 83 | print('re-open {} because the process id changed'.format(self.tsv_file)) 84 | self._fp = open(self.tsv_file, 'r') 85 | self.pid = os.getpid() 86 | 87 | if self.split_label: 88 | self._label_fp = open(self.label_file, 'r') -------------------------------------------------------------------------------- /whole_image_segmentation/train.sh: -------------------------------------------------------------------------------- 1 | export DETECTRON2_DATASETS=/PATH/TO/YOUR/DATASETS 2 | export TRAIN_DATASETS=/PATH/TO/YOUR/SA-1B 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py \ 4 | --num-gpus 4 \ 5 | --config-file configs/maskformer2_R50_bs16_50ep.yaml \ 6 | SOLVER.IMS_PER_BATCH 4 \ 7 | DATALOADER.NUM_WORKERS 1 \ 8 | OUTPUT_DIR ex_train \ --------------------------------------------------------------------------------