├── .gitignore ├── README.md ├── configs ├── Base.yaml └── S1_point_seg.yaml ├── demo ├── inference.py ├── predictor.py ├── uniovseg_s1_point │ ├── config.yaml │ ├── log.txt │ └── vis │ │ ├── A bear wearing sunglasses and hosting a talk show.jpg │ │ ├── A big moon rises on top of Toronto city.jpg │ │ ├── A bigfoot walking in the snowstorm.jpg │ │ ├── A teddy bear washing dishes.jpg │ │ ├── ADE_val_00001211.jpg │ │ ├── ADE_val_00001404.jpg │ │ ├── ADE_val_00001422.jpg │ │ ├── ADE_val_00001433.jpg │ │ ├── ADE_val_00001579.jpg │ │ ├── ADE_val_00001589.jpg │ │ ├── ADE_val_00001632.jpg │ │ ├── ADE_val_00001812.jpg │ │ ├── ADE_val_00001832.jpg │ │ ├── ADE_val_00001909.jpg │ │ ├── sa_1272468.jpg │ │ ├── sa_140340.jpg │ │ ├── sa_1603668.jpg │ │ ├── sa_1749626.jpg │ │ ├── sa_2191622.jpg │ │ └── sa_576544.jpg └── utils.py ├── images ├── A bear wearing sunglasses and hosting a talk show.jpg ├── A big moon rises on top of Toronto city.jpg ├── A bigfoot walking in the snowstorm.jpg ├── A teddy bear washing dishes.jpg ├── ADE_val_00001211.jpg ├── ADE_val_00001404.jpg ├── ADE_val_00001433.jpg ├── ADE_val_00001579.jpg ├── ADE_val_00001589.jpg ├── ADE_val_00001632.jpg ├── ADE_val_00001812.jpg ├── ADE_val_00001832.jpg ├── ADE_val_00001909.jpg ├── sa_1272468.jpg ├── sa_140340.jpg ├── sa_1603668.jpg ├── sa_1749626.jpg ├── sa_2191622.jpg └── sa_576544.jpg ├── install.sh ├── lib ├── models │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ └── clip.py │ ├── framework │ │ ├── __init__.py │ │ └── uni_ovseg.py │ ├── pixel_decoder │ │ ├── __init__.py │ │ ├── msdeformattn.py │ │ └── ops │ │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ │ ├── make.sh │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ │ ├── setup.py │ │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ │ └── test.py │ ├── prompt_decoder │ │ ├── __init__.py │ │ └── encoder.py │ ├── transformer_decoder │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ └── position_encoding.py │ └── utils.py └── utils │ ├── __init__.py │ ├── config.py │ ├── debug.py │ ├── misc.py │ ├── post_process.py │ ├── prompt.py │ └── test_time_augmentation.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | instant_test_output 4 | inference_test_output 5 | 6 | 7 | *.png 8 | *.json 9 | *.diff 10 | # *.jpg 11 | # !/projects/DensePose/doc/images/*.jpg 12 | 13 | # compilation and distribution 14 | __pycache__ 15 | _ext 16 | *.pyc 17 | *.pyd 18 | *.so 19 | *.dll 20 | *.egg-info/ 21 | build/ 22 | dist/ 23 | wheels/ 24 | 25 | # pytorch/python/numpy formats 26 | *.pth 27 | *.pkl 28 | *.npy 29 | *.ts 30 | *.parquet 31 | model_ts*.txt 32 | rank*.txt 33 | 34 | # ipython/jupyter notebooks 35 | *.ipynb 36 | **/.ipynb_checkpoints/ 37 | 38 | # Editor temporaries 39 | *.swn 40 | *.swo 41 | *.swp 42 | *~ 43 | 44 | # editor settings 45 | .idea 46 | .vscode 47 | _darcs 48 | 49 | # project dirs 50 | /detectron2/model_zoo/configs 51 | /datasets/* 52 | !/datasets/*.* 53 | /projects/*/datasets 54 | /models 55 | /snippet 56 | /detectron2 57 | /results 58 | /all_datasets 59 | /tools/sa1b_* 60 | /tools/*.parquet -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Unpair-Seg: Open-Vocabulary Segmentation with Unpaired Mask-Text Supervision 2 | 3 | This repo contains the code for our paper [UnpairSeg](https://derrickwang005.github.io/Unpair-Seg.pytorch/). 4 | It is a weakly supervised open-vocabulary segmentation framework that leverages unpaired mask-text supervision. 5 | 6 | 7 | **Now, we release the inference code and checkpoints for stage one training.** 8 | 9 | 10 | ## Installation 11 | - Linux with Python ≥ 3.10 12 | - PyTorch ≥ 2.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. 13 | Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note: please check 14 | that the PyTorch version matches the one required by Detectron2. 15 | - Detectron2: follow the [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). 16 | - OpenCV is optional but needed by the demo and visualization. 17 | - Please check `install.sh` for other dependencies. 18 | 19 | 20 | ## Inference 21 | This section provides a brief introduction to the usage of Unpair-Seg. 22 | Please download the [checkpoint](https://drive.google.com/file/d/1LefU25dxFtuPQ5_oA-18_qwKbCQ8wiF9/view?usp=sharing) from stage one training. 23 | We provide `./demo/inference.py` for point-promptable segmentation. 24 | Run it with: 25 | 26 | ``` 27 | cd demo/ 28 | python inference.py \ 29 | -c ../configs/S1_point_seg.yaml \ 30 | -i ../images/*.jpg \ 31 | --opts MODEL.WEIGHTS stage1.pth 32 | ``` 33 | 34 | We also provide some test images under `./images/`. 35 | 36 | ---- 37 | If you use this codebase, or otherwise find our work valuable, please cite: 38 | ``` 39 | @article{wang2024open, 40 | title={Open-Vocabulary Segmentation with Unpaired Mask-Text Supervision}, 41 | author={Wang, Zhaoqing and Xia, Xiaobo and Chen, Ziye and He, Xiao and Guo, Yandong and Gong, Mingming and Liu, Tongliang}, 42 | journal={arXiv preprint arXiv:2402.08960}, 43 | year={2024} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /configs/Base.yaml: -------------------------------------------------------------------------------- 1 | CUDNN_BENCHMARK: true 2 | SEED: 42 3 | OUTPUT_DIR: "" 4 | 5 | GLOBAL: 6 | HACK: 1.0 7 | 8 | DATALOADER: 9 | ASPECT_RATIO_GROUPING: true 10 | FILTER_EMPTY_ANNOTATIONS: true 11 | NUM_WORKERS: 4 12 | REPEAT_THRESHOLD: 0.0 13 | SAMPLER_TRAIN: TrainingSampler 14 | 15 | DATASETS: 16 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000 17 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000 18 | PROPOSAL_FILES_TEST: [] 19 | PROPOSAL_FILES_TRAIN: [] 20 | TRAIN: 21 | - openvocab_coco_2017_train_panoptic_with_sem_seg 22 | TEST: 23 | - openvocab_ade20k_panoptic_val 24 | 25 | INPUT: 26 | DATASET_MAPPER_NAME: sa1b 27 | FEW_SHOT_JSON: 28 | - "" 29 | IMG_SIZE: 1024 30 | CROP_SIZE: 1024 31 | MIN_SCALE: 0.8 32 | MAX_SCALE: 1.2 33 | MIN_AREA_RATIO: 0.001 34 | MAX_AREA_RATIO: 0.8 35 | COLOR_AUG_SSD: true 36 | CROP: 37 | ENABLED: false 38 | SINGLE_CATEGORY_MAX_AREA: 1.0 39 | SIZE: 40 | - 0.9 41 | - 0.9 42 | TYPE: relative_range 43 | FORMAT: RGB 44 | MASK_FORMAT: polygon 45 | MAX_SIZE_TEST: 1333 46 | MAX_SIZE_TRAIN: 1333 47 | MIN_SIZE_TEST: 800 48 | MIN_SIZE_TRAIN: 49 | - 800 50 | MIN_SIZE_TRAIN_SAMPLING: choice 51 | RANDOM_FLIP: horizontal 52 | SIZE_DIVISIBILITY: -1 53 | 54 | MODEL: 55 | META_ARCHITECTURE: UniOVSeg_S1 56 | BACKBONE: 57 | FREEZE_AT: 0 58 | NAME: CLIP 59 | DEVICE: cuda 60 | KEYPOINT_ON: false 61 | LOAD_PROPOSALS: false
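  # NOTE: the PIXEL_MEAN / PIXEL_STD values below are the standard CLIP RGB
  # normalization statistics (mean 0.4815/0.4578/0.4082, std 0.2686/0.2613/0.2758)
  # scaled to the 0-255 range expected by Detectron2, matching the frozen CLIP backbone.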
62 | MASK_ON: false 63 | PIXEL_MEAN: 64 | - 122.7709383 65 | - 116.7460125 66 | - 104.09373615 67 | PIXEL_STD: 68 | - 68.5005327 69 | - 66.6321579 70 | - 70.32316305 71 | OVSEG: 72 | CLIP_MODEL_NAME: convnext_large_d_320 73 | CLIP_PRETRAINED_WEIGHTS: null 74 | PROMPT_ENCODER_NAME: PromptEncoder 75 | PIXEL_DECODER_NAME: MSDeformAttnPixelDecoder 76 | TRANSFORMER_ENC_LAYERS: 6 77 | COMMON_STRIDE: 4 78 | TRANSFORMER_DECODER_NAME: MultiScaleMaskDecoder 79 | TRANSFORMER_IN_FEATURE: multi_scale_pixel_decoder 80 | DEC_LAYERS: 10 81 | MASK_DIM: 256 82 | CONVS_DIM: 256 83 | EMBED_DIM: 256 84 | CLIP_DIM: 1536 85 | NHEADS: 8 86 | DIM_FEEDFORWARD: 2048 87 | DROPOUT: 0.0 88 | PRE_NORM: false 89 | NORM: GN 90 | ENFORCE_INPUT_PROJ: false 91 | NUM_MASKS: 4 92 | CRITERION_SEG: Many2ManySetCriterion 93 | CRITERION_ALIGN: MaskTextAlignCriterion 94 | MASK_WEIGHT: 2.0 95 | DICE_WEIGHT: 1.0 96 | IOU_WEIGHT: 1.0 97 | ALIGN_WEIGHT: 1.0 98 | MATCHER_NUM_POINTS: 6000 99 | TRAIN_NUM_POINTS: 12544 100 | OVERSAMPLE_RATIO: 3.0 101 | IMPORTANCE_SAMPLE_RATIO: 0.75 102 | LOSS_TOPK: 1.0 103 | SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: true 104 | PTS_PER_SIDE: 105 | - 10 106 | - 10 107 | SIZE_DIVISIBILITY: 32 108 | DEEP_SUPERVISION: true 109 | TEST: 110 | PTS_PER_SIDE: 111 | - 20 112 | - 20 113 | SEMANTIC_ON: false 114 | INSTANCE_ON: false 115 | PANOPTIC_ON: false 116 | MASKCLS_ON: false 117 | AUTOLABEL_ON: true 118 | AUTOLABEL_SAVE: false 119 | AUTOLABEL_TYPE: panoptic-point 120 | OBJECT_MASK_THRESHOLD: 0.7 121 | OVERLAP_THRESHOLD: 0.4 122 | WEIGHTS: '' 123 | 124 | SOLVER: 125 | IMS_PER_BATCH: 56 126 | BASE_LR: 0.0001 127 | WARMUP_FACTOR: 1.0 128 | WARMUP_ITERS: 10 129 | WARMUP_METHOD: linear 130 | MAX_ITER: 110000 131 | STEPS: 132 | - 93500 133 | - 104500 134 | CHECKPOINT_PERIOD: 5000 135 | LR_SCHEDULER_NAME: WarmupMultiStepLR 136 | WEIGHT_DECAY: 0.05 137 | WEIGHT_DECAY_BIAS: null 138 | WEIGHT_DECAY_EMBED: 0.0 139 | WEIGHT_DECAY_NORM: 0.0 140 | AMP: 141 | ENABLED: true 142 | BACKBONE_MULTIPLIER: 0.1 143 | BASE_LR_END: 0.0 144 | BIAS_LR_FACTOR: 1.0 145 | CLIP_GRADIENTS: 146 | CLIP_TYPE: full_model 147 | CLIP_VALUE: 1.0 148 | ENABLED: true 149 | NORM_TYPE: 2.0 150 | GAMMA: 0.1 151 | MOMENTUM: 0.9 152 | NESTEROV: false 153 | NUM_DECAYS: 3 154 | OPTIMIZER: ADAMW 155 | POLY_LR_CONSTANT_ENDING: 0.0 156 | POLY_LR_POWER: 0.9 157 | REFERENCE_WORLD_SIZE: 0 158 | RESCALE_INTERVAL: false 159 | 160 | VERSION: 2 161 | VIS_PERIOD: 0 -------------------------------------------------------------------------------- /configs/S1_point_seg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ./Base.yaml 2 | 3 | OUTPUT_DIR: ./uniovseg_s1_point 4 | 5 | INPUT: 6 | DATASET_MAPPER_NAME: sa1b 7 | IMG_SIZE: 1024 8 | CROP_SIZE: 1024 9 | COLOR_AUG_SSD: true 10 | MIN_AREA_RATIO: 0.0001 11 | MAX_AREA_RATIO: 0.8 12 | 13 | MODEL: 14 | WEIGHTS: '' 15 | META_ARCHITECTURE: UniOVSeg_S1 16 | # backbone part. 
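  # The backbone is a frozen OpenCLIP ConvNeXt-Large ("convnext_large_d_320"); see
  # lib/models/backbone/clip.py. CLIP_PRETRAINED_WEIGHTS is left empty here, so the
  # backbone weights are presumably restored from the stage-one checkpoint supplied
  # via MODEL.WEIGHTS (e.g. stage1.pth from the README).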
17 | BACKBONE: 18 | NAME: "CLIP" 19 | OVSEG: 20 | CLIP_MODEL_NAME: convnext_large_d_320 21 | CLIP_PRETRAINED_WEIGHTS: "" 22 | TRANSFORMER_DECODER_NAME: MultiScaleMaskDecoder 23 | PROMPT_ENCODER_NAME: PromptEncoder 24 | TRANSFORMER_ENC_LAYERS: 6 25 | DEC_LAYERS: 10 26 | MASK_WEIGHT: 2.0 27 | DICE_WEIGHT: 1.0 28 | IOU_WEIGHT: 1.0 29 | PTS_PER_SIDE: 30 | - 10 31 | - 10 32 | TEST: 33 | PTS_PER_SIDE: 34 | - 20 35 | - 20 36 | SEMANTIC_ON: false 37 | INSTANCE_ON: false 38 | PANOPTIC_ON: false 39 | MASKCLS_ON: false 40 | AUTOLABEL_ON: true 41 | AUTOLABEL_SAVE: false 42 | AUTOLABEL_TYPE: panoptic-point 43 | OBJECT_MASK_THRESHOLD: 0.7 44 | OVERLAP_THRESHOLD: 0.4 45 | -------------------------------------------------------------------------------- /demo/inference.py: -------------------------------------------------------------------------------- 1 | try: 2 | # ignore ShapelyDeprecationWarning from fvcore 3 | import warnings 4 | from shapely.errors import ShapelyDeprecationWarning 5 | 6 | warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) 7 | except: 8 | pass 9 | import argparse 10 | from time import time 11 | from glob import glob 12 | import os 13 | import torch 14 | import torch.nn.functional as F 15 | import detectron2.utils.comm as comm 16 | from detectron2.config import get_cfg 17 | from detectron2.engine import default_setup 18 | from detectron2.utils.logger import setup_logger 19 | from detectron2.utils.visualizer import Visualizer 20 | import lib.models 21 | from lib.utils import add_ovseg_config 22 | from predictor import StageOnePredictor 23 | from utils import calculate_stability_score, remove_small_regions 24 | 25 | 26 | def get_parser(): 27 | parser = argparse.ArgumentParser(description="Uni-OVSeg inference demo") 28 | parser.add_argument( 29 | "-c", 30 | "--config-file", 31 | metavar="FILE", 32 | help="path to config file", 33 | ) 34 | parser.add_argument( 35 | "-i", 36 | "--input", 37 | nargs="+", 38 | help="A list of space separated input images; " 39 | "or a single glob pattern such as 'directory/*.jpg'", 40 | ) 41 | parser.add_argument( 42 | "--opts", 43 | help="Modify config options using the command-line 'KEY VALUE' pairs", 44 | default=[], 45 | nargs=argparse.REMAINDER, 46 | ) 47 | return parser 48 | 49 | 50 | def setup(args): 51 | """ 52 | Create configs and perform basic setups. 
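    Merges the defaults (plus the Uni-OVSeg keys from add_ovseg_config) with the
    YAML file given by --config-file and any KEY VALUE overrides from --opts,
    freezes the result, and runs detectron2's default_setup together with a
    dedicated "uniovseg" logger.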
53 | """ 54 | cfg = get_cfg() 55 | add_ovseg_config(cfg) 56 | cfg.merge_from_file(args.config_file) 57 | cfg.merge_from_list(args.opts) 58 | cfg.freeze() 59 | default_setup(cfg, args) 60 | # Setup logger for "uniovseg" module 61 | logger = setup_logger(output=cfg.OUTPUT_DIR, name="uniovseg") 62 | return cfg, logger 63 | 64 | 65 | if __name__ == "__main__": 66 | args = get_parser().parse_args() 67 | cfg, logger = setup(args) 68 | logger.info("Arguments: " + str(args)) 69 | 70 | predictor = StageOnePredictor(cfg, torch.bfloat16, "cuda") 71 | 72 | if len(args.input) == 1: 73 | args.input = glob(os.path.expanduser(args.input[0])) 74 | assert args.input, "The input path(s) was not found" 75 | 76 | args.input.sort() 77 | os.makedirs(os.path.join(cfg.OUTPUT_DIR, "vis"), exist_ok=True) 78 | for path in args.input: 79 | start_time = time() 80 | result = predictor(path) 81 | image = result.pop("image") 82 | prediction = result.pop("prediction") 83 | logger.info( 84 | "{}: segmented {} instances in {:.2f}s".format( 85 | path, 86 | len(prediction), 87 | time() - start_time, 88 | ) 89 | ) 90 | # filter - stablility 91 | stable_score = calculate_stability_score(prediction, 0.5, 0.1) 92 | keep = stable_score > 0.92 93 | prediction = prediction[keep] 94 | # filter - small disconnected regions and holes 95 | prediction = prediction.sigmoid().ge(0.5).cpu().numpy().astype(int) 96 | tmp_masks, scores = [], [] 97 | for pred in prediction: 98 | mask, changed = remove_small_regions(pred, 15, mode="holes") 99 | unchanged = not changed 100 | mask, changed = remove_small_regions(mask, 15, mode="islands") 101 | unchanged = unchanged and not changed 102 | tmp_masks.append(torch.as_tensor(mask).unsqueeze(0)) 103 | # Give score=0 to changed masks and score=1 to unchanged masks 104 | # so NMS will prefer ones that didn't need postprocessing 105 | scores.append(float(unchanged)) 106 | prediction = torch.cat(tmp_masks, dim=0) 107 | 108 | # visualize 109 | max_side_len = max(prediction.shape[-2:]) 110 | image = F.interpolate( 111 | image.unsqueeze(0).float(), 112 | size=(max_side_len, max_side_len), 113 | mode="bilinear", 114 | align_corners=False, 115 | ).squeeze(0) 116 | image = ( 117 | image[:, : prediction.shape[1], : prediction.shape[2]] 118 | .permute(1, 2, 0) 119 | .cpu() 120 | .numpy() 121 | ) 122 | vis = Visualizer(image, metadata=None) 123 | vis.overlay_instances( 124 | masks=prediction, 125 | alpha=0.5, 126 | ) 127 | vis = vis.get_output() 128 | vis.save(os.path.join(cfg.OUTPUT_DIR, "vis", os.path.basename(path))) 129 | del vis 130 | -------------------------------------------------------------------------------- /demo/predictor.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import logging 3 | from functools import partial 4 | import numpy as np 5 | import torch 6 | from detectron2.checkpoint import DetectionCheckpointer 7 | from detectron2.data import detection_utils as utils 8 | from detectron2.data import transforms as T 9 | from detectron2.modeling import build_model 10 | 11 | 12 | def inference_transform(img_path, transform): 13 | output = dict() 14 | 15 | # load image 16 | image = utils.read_image(img_path, format="RGB") 17 | ori_shape = image.shape[:2] 18 | output["height"] = ori_shape[0] 19 | output["width"] = ori_shape[1] 20 | 21 | # preprocess 22 | image, _ = T.apply_transform_gens(transform, image) 23 | output["image"] = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) 24 | 25 | return output 26 | 27 | 28 | class 
StageOnePredictor: 29 | def __init__(self, cfg, dtype=torch.bfloat16, device="cuda"): 30 | self.cfg = cfg.clone() # cfg can be modified by model 31 | self.model = build_model(self.cfg) 32 | self.model.eval() 33 | print("Model:\n{}".format(self.model)) 34 | 35 | checkpointer = DetectionCheckpointer(self.model) 36 | msg = checkpointer.load(cfg.MODEL.WEIGHTS) 37 | print(msg) 38 | 39 | augments = [ 40 | T.ResizeShortestEdge(cfg.INPUT.CROP_SIZE, cfg.INPUT.CROP_SIZE), 41 | T.FixedSizeCrop( 42 | crop_size=(cfg.INPUT.CROP_SIZE, cfg.INPUT.CROP_SIZE), seg_pad_value=0 43 | ), 44 | ] 45 | self.transform = partial(inference_transform, transform=augments) 46 | 47 | self.dtype = dtype 48 | self.device = device 49 | 50 | @torch.inference_mode() 51 | def __call__(self, image_path): 52 | image = self.transform(image_path) 53 | image["image"] = image["image"].to(self.device) 54 | with torch.cuda.amp.autocast( 55 | enabled=self.dtype == torch.bfloat16, dtype=self.dtype 56 | ): 57 | prediction = self.model([image])[0]["proposal"] 58 | image["prediction"] = prediction 59 | return image 60 | -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/config.yaml: -------------------------------------------------------------------------------- 1 | CUDNN_BENCHMARK: true 2 | DATALOADER: 3 | ASPECT_RATIO_GROUPING: true 4 | FILTER_EMPTY_ANNOTATIONS: true 5 | NUM_WORKERS: 4 6 | REPEAT_THRESHOLD: 0.0 7 | SAMPLER_TRAIN: TrainingSampler 8 | DATASETS: 9 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000 10 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000 11 | PROPOSAL_FILES_TEST: [] 12 | PROPOSAL_FILES_TRAIN: [] 13 | TEST: 14 | - openvocab_ade20k_panoptic_val 15 | TRAIN: 16 | - openvocab_coco_2017_train_panoptic_with_sem_seg 17 | GLOBAL: 18 | HACK: 1.0 19 | INPUT: 20 | COLOR_AUG_SSD: true 21 | CROP: 22 | ENABLED: false 23 | SINGLE_CATEGORY_MAX_AREA: 1.0 24 | SIZE: 25 | - 0.9 26 | - 0.9 27 | TYPE: relative_range 28 | CROP_SIZE: 1024 29 | DATASET_JSON: /vepfs/home/wangzhaoqing/uni-ovseg/sa1b.json 30 | DATASET_MAPPER_NAME: sa1b 31 | DATASET_ROOT: /datasets/sharegpt4v 32 | DATASET_URL: 33 | - - /datasets/SA-1B/split1-2m 34 | - datadict_0p5.parquet 35 | FEW_SHOT_JSON: 36 | - '' 37 | FORMAT: RGB 38 | IMG_SIZE: 1024 39 | MASK_FORMAT: polygon 40 | MAX_AREA_RATIO: 0.8 41 | MAX_INSTANCE: 40 42 | MAX_SCALE: 1.2 43 | MAX_SIZE_TEST: 1333 44 | MAX_SIZE_TRAIN: 1333 45 | MIN_AREA_RATIO: 0.0001 46 | MIN_SCALE: 0.8 47 | MIN_SIZE_TEST: 800 48 | MIN_SIZE_TRAIN: 49 | - 800 50 | MIN_SIZE_TRAIN_SAMPLING: choice 51 | RANDOM_FLIP: horizontal 52 | SIZE_DIVISIBILITY: -1 53 | MODEL: 54 | ANCHOR_GENERATOR: 55 | ANGLES: 56 | - - -90 57 | - 0 58 | - 90 59 | ASPECT_RATIOS: 60 | - - 0.5 61 | - 1.0 62 | - 2.0 63 | NAME: DefaultAnchorGenerator 64 | OFFSET: 0.0 65 | SIZES: 66 | - - 32 67 | - 64 68 | - 128 69 | - 256 70 | - 512 71 | BACKBONE: 72 | FREEZE_AT: 0 73 | NAME: CLIP 74 | DEVICE: cuda 75 | FPN: 76 | FUSE_TYPE: sum 77 | IN_FEATURES: [] 78 | NORM: '' 79 | OUT_CHANNELS: 256 80 | KEYPOINT_ON: false 81 | LOAD_PROPOSALS: false 82 | MASK_ON: false 83 | META_ARCHITECTURE: UniOVSeg_S1 84 | OVSEG: 85 | ALIGN_WEIGHT: 1.0 86 | AUX_MODEL_NAME: convnext_xxlarge 87 | AUX_PRETRAINED_WEIGHTS: /workspace/pretrains/convnext_xxlarge.laion2B-s34B-b82K-augreg-soup.pth 88 | CLIP_DIM: 1536 89 | CLIP_MODEL_NAME: convnext_large_d_320 90 | CLIP_PRETRAINED_WEIGHTS: '' 91 | COMMON_STRIDE: 4 92 | CONVS_DIM: 256 93 | CRITERION_ALIGN: MaskTextAlignCriterion 94 | CRITERION_SEG: Many2ManySetCriterion 95 | DEC_LAYERS: 10 96 | DEEP_SUPERVISION: true 97 | 
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: 98 | - res3 99 | - res4 100 | - res5 101 | DICE_WEIGHT: 1.0 102 | DIM_FEEDFORWARD: 2048 103 | DROPOUT: 0.0 104 | EMBED_DIM: 256 105 | ENFORCE_INPUT_PROJ: false 106 | IMPORTANCE_SAMPLE_RATIO: 0.75 107 | INPUT_SIZES: 108 | - 896 109 | - 1024 110 | IN_FEATURES: 111 | - res2 112 | - res3 113 | - res4 114 | - res5 115 | IOU_WEIGHT: 1.0 116 | LORA_INIT: false 117 | LOSS_TOPK: 1.0 118 | MASK_DIM: 256 119 | MASK_WEIGHT: 2.0 120 | MATCHER_NUM_POINTS: 6000 121 | MATCHER_THRES_POS: 0.7 122 | NHEADS: 8 123 | NORM: GN 124 | NUM_MASKS: 4 125 | OVERSAMPLE_RATIO: 3.0 126 | PIXEL_DECODER_NAME: MSDeformAttnPixelDecoder 127 | PRE_NORM: false 128 | PROMPT_ENCODER_NAME: PromptEncoder 129 | PTS_PER_SIDE: 130 | - 10 131 | - 10 132 | RANK: 8 133 | SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: true 134 | SIZE_DIVISIBILITY: 32 135 | TEST: 136 | AUTOLABEL_ON: true 137 | AUTOLABEL_SAVE: false 138 | AUTOLABEL_TYPE: panoptic-point 139 | INSTANCE_ON: false 140 | MASKCLS_ON: false 141 | OBJECT_MASK_THRESHOLD: 0.7 142 | OVERLAP_THRESHOLD: 0.4 143 | PANOPTIC_ON: false 144 | PTS_PER_SIDE: 145 | - 20 146 | - 20 147 | SEMANTIC_ON: false 148 | TRAIN_NUM_POINTS: 12544 149 | TRANSFORMER_DECODER_NAME: MultiScaleMaskDecoder 150 | TRANSFORMER_ENC_LAYERS: 6 151 | TRANSFORMER_IN_FEATURE: multi_scale_pixel_decoder 152 | PANOPTIC_FPN: 153 | COMBINE: 154 | ENABLED: true 155 | INSTANCES_CONFIDENCE_THRESH: 0.5 156 | OVERLAP_THRESH: 0.5 157 | STUFF_AREA_LIMIT: 4096 158 | INSTANCE_LOSS_WEIGHT: 1.0 159 | PIXEL_MEAN: 160 | - 122.7709383 161 | - 116.7460125 162 | - 104.09373615 163 | PIXEL_STD: 164 | - 68.5005327 165 | - 66.6321579 166 | - 70.32316305 167 | PROPOSAL_GENERATOR: 168 | MIN_SIZE: 0 169 | NAME: RPN 170 | RESNETS: 171 | DEFORM_MODULATED: false 172 | DEFORM_NUM_GROUPS: 1 173 | DEFORM_ON_PER_STAGE: 174 | - false 175 | - false 176 | - false 177 | - false 178 | DEPTH: 50 179 | NORM: FrozenBN 180 | NUM_GROUPS: 1 181 | OUT_FEATURES: 182 | - res4 183 | RES2_OUT_CHANNELS: 256 184 | RES5_DILATION: 1 185 | STEM_OUT_CHANNELS: 64 186 | STRIDE_IN_1X1: true 187 | WIDTH_PER_GROUP: 64 188 | RETINANET: 189 | BBOX_REG_LOSS_TYPE: smooth_l1 190 | BBOX_REG_WEIGHTS: &id002 191 | - 1.0 192 | - 1.0 193 | - 1.0 194 | - 1.0 195 | FOCAL_LOSS_ALPHA: 0.25 196 | FOCAL_LOSS_GAMMA: 2.0 197 | IN_FEATURES: 198 | - p3 199 | - p4 200 | - p5 201 | - p6 202 | - p7 203 | IOU_LABELS: 204 | - 0 205 | - -1 206 | - 1 207 | IOU_THRESHOLDS: 208 | - 0.4 209 | - 0.5 210 | NMS_THRESH_TEST: 0.5 211 | NORM: '' 212 | NUM_CLASSES: 80 213 | NUM_CONVS: 4 214 | PRIOR_PROB: 0.01 215 | SCORE_THRESH_TEST: 0.05 216 | SMOOTH_L1_LOSS_BETA: 0.1 217 | TOPK_CANDIDATES_TEST: 1000 218 | ROI_BOX_CASCADE_HEAD: 219 | BBOX_REG_WEIGHTS: 220 | - &id001 221 | - 10.0 222 | - 10.0 223 | - 5.0 224 | - 5.0 225 | - - 20.0 226 | - 20.0 227 | - 10.0 228 | - 10.0 229 | - - 30.0 230 | - 30.0 231 | - 15.0 232 | - 15.0 233 | IOUS: 234 | - 0.5 235 | - 0.6 236 | - 0.7 237 | ROI_BOX_HEAD: 238 | BBOX_REG_LOSS_TYPE: smooth_l1 239 | BBOX_REG_LOSS_WEIGHT: 1.0 240 | BBOX_REG_WEIGHTS: *id001 241 | CLS_AGNOSTIC_BBOX_REG: false 242 | CONV_DIM: 256 243 | FC_DIM: 1024 244 | FED_LOSS_FREQ_WEIGHT_POWER: 0.5 245 | FED_LOSS_NUM_CLASSES: 50 246 | NAME: '' 247 | NORM: '' 248 | NUM_CONV: 0 249 | NUM_FC: 0 250 | POOLER_RESOLUTION: 14 251 | POOLER_SAMPLING_RATIO: 0 252 | POOLER_TYPE: ROIAlignV2 253 | SMOOTH_L1_BETA: 0.0 254 | TRAIN_ON_PRED_BOXES: false 255 | USE_FED_LOSS: false 256 | USE_SIGMOID_CE: false 257 | ROI_HEADS: 258 | BATCH_SIZE_PER_IMAGE: 512 259 | IN_FEATURES: 260 | - res4 261 | 
IOU_LABELS: 262 | - 0 263 | - 1 264 | IOU_THRESHOLDS: 265 | - 0.5 266 | NAME: Res5ROIHeads 267 | NMS_THRESH_TEST: 0.5 268 | NUM_CLASSES: 80 269 | POSITIVE_FRACTION: 0.25 270 | PROPOSAL_APPEND_GT: true 271 | SCORE_THRESH_TEST: 0.05 272 | ROI_KEYPOINT_HEAD: 273 | CONV_DIMS: 274 | - 512 275 | - 512 276 | - 512 277 | - 512 278 | - 512 279 | - 512 280 | - 512 281 | - 512 282 | LOSS_WEIGHT: 1.0 283 | MIN_KEYPOINTS_PER_IMAGE: 1 284 | NAME: KRCNNConvDeconvUpsampleHead 285 | NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS: true 286 | NUM_KEYPOINTS: 17 287 | POOLER_RESOLUTION: 14 288 | POOLER_SAMPLING_RATIO: 0 289 | POOLER_TYPE: ROIAlignV2 290 | ROI_MASK_HEAD: 291 | CLS_AGNOSTIC_MASK: false 292 | CONV_DIM: 256 293 | NAME: MaskRCNNConvUpsampleHead 294 | NORM: '' 295 | NUM_CONV: 0 296 | POOLER_RESOLUTION: 14 297 | POOLER_SAMPLING_RATIO: 0 298 | POOLER_TYPE: ROIAlignV2 299 | RPN: 300 | BATCH_SIZE_PER_IMAGE: 256 301 | BBOX_REG_LOSS_TYPE: smooth_l1 302 | BBOX_REG_LOSS_WEIGHT: 1.0 303 | BBOX_REG_WEIGHTS: *id002 304 | BOUNDARY_THRESH: -1 305 | CONV_DIMS: 306 | - -1 307 | HEAD_NAME: StandardRPNHead 308 | IN_FEATURES: 309 | - res4 310 | IOU_LABELS: 311 | - 0 312 | - -1 313 | - 1 314 | IOU_THRESHOLDS: 315 | - 0.3 316 | - 0.7 317 | LOSS_WEIGHT: 1.0 318 | NMS_THRESH: 0.7 319 | POSITIVE_FRACTION: 0.5 320 | POST_NMS_TOPK_TEST: 1000 321 | POST_NMS_TOPK_TRAIN: 2000 322 | PRE_NMS_TOPK_TEST: 6000 323 | PRE_NMS_TOPK_TRAIN: 12000 324 | SMOOTH_L1_BETA: 0.0 325 | SEM_SEG_HEAD: 326 | COMMON_STRIDE: 4 327 | CONVS_DIM: 128 328 | IGNORE_VALUE: 255 329 | IN_FEATURES: 330 | - p2 331 | - p3 332 | - p4 333 | - p5 334 | LOSS_WEIGHT: 1.0 335 | NAME: SemSegFPNHead 336 | NORM: GN 337 | NUM_CLASSES: 54 338 | WEIGHTS: stage1.pth 339 | OUTPUT_DIR: ./uniovseg_s1_point 340 | SEED: 42 341 | SOLVER: 342 | AMP: 343 | ENABLED: true 344 | BACKBONE_MULTIPLIER: 0.1 345 | BASE_LR: 0.0001 346 | BASE_LR_END: 0.0 347 | BIAS_LR_FACTOR: 1.0 348 | CHECKPOINT_PERIOD: 5000 349 | CLIP_GRADIENTS: 350 | CLIP_TYPE: full_model 351 | CLIP_VALUE: 1.0 352 | ENABLED: true 353 | NORM_TYPE: 2.0 354 | GAMMA: 0.1 355 | IMS_PER_BATCH: 56 356 | LR_SCHEDULER_NAME: WarmupMultiStepLR 357 | MAX_ITER: 110000 358 | MOMENTUM: 0.9 359 | NESTEROV: false 360 | NUM_DECAYS: 3 361 | OPTIMIZER: ADAMW 362 | POLY_LR_CONSTANT_ENDING: 0.0 363 | POLY_LR_POWER: 0.9 364 | REFERENCE_WORLD_SIZE: 0 365 | RESCALE_INTERVAL: false 366 | STEPS: 367 | - 93500 368 | - 104500 369 | WARMUP_FACTOR: 1.0 370 | WARMUP_ITERS: 10 371 | WARMUP_METHOD: linear 372 | WEIGHT_DECAY: 0.05 373 | WEIGHT_DECAY_BIAS: null 374 | WEIGHT_DECAY_EMBED: 0.0 375 | WEIGHT_DECAY_NORM: 0.0 376 | TEST: 377 | AUG: 378 | ENABLED: false 379 | FLIP: true 380 | MAX_SIZE: 4000 381 | MIN_SIZES: 382 | - 400 383 | - 500 384 | - 600 385 | - 700 386 | - 800 387 | - 900 388 | - 1000 389 | - 1100 390 | - 1200 391 | DETECTIONS_PER_IMAGE: 100 392 | EVAL_PERIOD: 0 393 | EXPECTED_RESULTS: [] 394 | KEYPOINT_OKS_SIGMAS: [] 395 | PRECISE_BN: 396 | ENABLED: false 397 | NUM_ITER: 200 398 | VERSION: 2 399 | VIS_PERIOD: 0 400 | -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/A bear wearing sunglasses and hosting a talk show.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/A bear wearing sunglasses and hosting a talk show.jpg -------------------------------------------------------------------------------- 
/demo/uniovseg_s1_point/vis/A big moon rises on top of Toronto city.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/A big moon rises on top of Toronto city.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/A bigfoot walking in the snowstorm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/A bigfoot walking in the snowstorm.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/A teddy bear washing dishes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/A teddy bear washing dishes.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001211.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001211.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001404.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001404.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001422.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001422.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001433.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001433.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001579.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001579.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001589.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001589.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001632.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001632.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001812.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001812.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001832.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001832.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/ADE_val_00001909.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/ADE_val_00001909.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/sa_1272468.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/sa_1272468.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/sa_140340.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/sa_140340.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/sa_1603668.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/sa_1603668.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/sa_1749626.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/sa_1749626.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/sa_2191622.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/sa_2191622.jpg -------------------------------------------------------------------------------- /demo/uniovseg_s1_point/vis/sa_576544.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/demo/uniovseg_s1_point/vis/sa_576544.jpg 
-------------------------------------------------------------------------------- /demo/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | 6 | def calculate_stability_score( 7 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 8 | ) -> torch.Tensor: 9 | """ 10 | Computes the stability score for a batch of masks. The stability 11 | score is the IoU between the binary masks obtained by thresholding 12 | the predicted mask logits at high and low values. 13 | """ 14 | # One mask is always contained inside the other. 15 | # Save memory by preventing unnecessary cast to torch.int64 16 | intersections = ( 17 | (masks > (mask_threshold + threshold_offset)) 18 | .sum(-1, dtype=torch.int16) 19 | .sum(-1, dtype=torch.int32) 20 | ) 21 | unions = ( 22 | (masks > (mask_threshold - threshold_offset)) 23 | .sum(-1, dtype=torch.int16) 24 | .sum(-1, dtype=torch.int32) 25 | ) 26 | return intersections / unions 27 | 28 | 29 | def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str): 30 | """ 31 | Removes small disconnected regions and holes in a mask. Returns the 32 | mask and an indicator of if the mask has been modified. 33 | """ 34 | assert mode in ["holes", "islands"] 35 | correct_holes = mode == "holes" 36 | working_mask = (correct_holes ^ mask).astype(np.uint8) 37 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 38 | sizes = stats[:, -1][1:] # Row 0 is background label 39 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 40 | if len(small_regions) == 0: 41 | return mask, False 42 | fill_labels = [0] + small_regions 43 | if not correct_holes: 44 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 45 | # If every region is below threshold, keep largest 46 | if len(fill_labels) == 0: 47 | fill_labels = [int(np.argmax(sizes)) + 1] 48 | mask = np.isin(regions, fill_labels) 49 | return mask, True 50 | -------------------------------------------------------------------------------- /images/A bear wearing sunglasses and hosting a talk show.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/A bear wearing sunglasses and hosting a talk show.jpg -------------------------------------------------------------------------------- /images/A big moon rises on top of Toronto city.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/A big moon rises on top of Toronto city.jpg -------------------------------------------------------------------------------- /images/A bigfoot walking in the snowstorm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/A bigfoot walking in the snowstorm.jpg -------------------------------------------------------------------------------- /images/A teddy bear washing dishes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/A teddy bear washing dishes.jpg 
-------------------------------------------------------------------------------- /images/ADE_val_00001211.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001211.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001404.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001404.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001433.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001433.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001579.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001579.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001589.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001589.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001632.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001632.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001812.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001812.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001832.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001832.jpg -------------------------------------------------------------------------------- /images/ADE_val_00001909.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/ADE_val_00001909.jpg -------------------------------------------------------------------------------- /images/sa_1272468.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/sa_1272468.jpg -------------------------------------------------------------------------------- /images/sa_140340.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/sa_140340.jpg -------------------------------------------------------------------------------- /images/sa_1603668.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/sa_1603668.jpg -------------------------------------------------------------------------------- /images/sa_1749626.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/sa_1749626.jpg -------------------------------------------------------------------------------- /images/sa_2191622.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/sa_2191622.jpg -------------------------------------------------------------------------------- /images/sa_576544.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickWang005/Unpair-Seg.pytorch/0a3fcfcf08eac14a47172f7d86311a5158b77979/images/sa_576544.jpg -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | cd lib/models/pixel_decoder/ops 3 | sh make.sh 4 | cd ../../../.. 5 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import CLIP 2 | from .framework import UniOVSeg_S1 3 | from .pixel_decoder import build_pixel_decoder 4 | from .prompt_decoder import build_prompt_encoder 5 | from .transformer_decoder import build_transformer_decoder -------------------------------------------------------------------------------- /lib/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import CLIP, build_main_backbone -------------------------------------------------------------------------------- /lib/models/backbone/clip.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec 6 | from detectron2.utils import comm 7 | from detectron2.config import configurable 8 | 9 | 10 | def build_main_backbone(cfg): 11 | """ 12 | Build the main CLIP backbone registered under `cfg.MODEL.BACKBONE.NAME`.
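    The returned module wraps a frozen open_clip ConvNeXt model and exposes
    multi-scale features ("stem", "res2"-"res5") plus a "clip_embedding" latent
    intended for mask-text alignment; the concrete variant is selected by
    `cfg.MODEL.OVSEG.CLIP_MODEL_NAME` and `cfg.MODEL.OVSEG.CLIP_PRETRAINED_WEIGHTS`.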
13 | """ 14 | model_name = cfg.MODEL.OVSEG.CLIP_MODEL_NAME 15 | pretrained = cfg.MODEL.OVSEG.CLIP_PRETRAINED_WEIGHTS 16 | backbone_name = cfg.MODEL.BACKBONE.NAME 17 | backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, model_name, pretrained) 18 | assert isinstance(backbone, Backbone) 19 | return backbone 20 | 21 | 22 | @BACKBONE_REGISTRY.register() 23 | class CLIP(Backbone): 24 | @configurable 25 | def __init__(self, model_name, pretrained): 26 | super().__init__() 27 | # download on local rank 0 first 28 | if comm.get_local_rank() == 0: 29 | open_clip.create_model_and_transforms(model_name, pretrained=pretrained) 30 | comm.synchronize() 31 | 32 | self.model_name = model_name 33 | self.pretrained = pretrained 34 | self.clip_model, _, _ = open_clip.create_model_and_transforms( 35 | model_name, pretrained=pretrained 36 | ) 37 | self.text_tokenizer = open_clip.get_tokenizer(model_name) 38 | 39 | model_name = model_name.lower() 40 | assert "convnext_" in model_name, "Only convnext models are supported" 41 | self.model_type = "convnext" 42 | if "_base" in model_name: 43 | self.output_channels = [128, 128, 256, 512, 1024] 44 | elif "_large" in model_name: 45 | self.output_channels = [192, 192, 384, 768, 1536] 46 | elif "_xxlarge" in model_name: 47 | self.output_channels = [384, 384, 768, 1536, 3072] 48 | else: 49 | raise ValueError(f"Unknown model name: {model_name}") 50 | 51 | self._out_feature_strides = { 52 | "stem": 2, 53 | "res2": 4, 54 | "res3": 8, 55 | "res4": 16, 56 | "res5": 32, 57 | "clip_embedding": -1, 58 | } 59 | self._out_feature_channels = { 60 | "stem": self.output_channels[0], 61 | "res2": self.output_channels[1], 62 | "res3": self.output_channels[2], 63 | "res4": self.output_channels[3], 64 | "res5": self.output_channels[4], 65 | "clip_embedding": self.dim_latent, 66 | } 67 | self.freeze_everything() 68 | 69 | @classmethod 70 | def from_config(cls, cfg, model_name, pretrained): 71 | ret = {} 72 | ret["model_name"] = model_name 73 | ret["pretrained"] = pretrained 74 | 75 | return ret 76 | 77 | def freeze_everything(self): 78 | self.eval() 79 | for param in self.parameters(): 80 | param.requires_grad = False 81 | 82 | def encode_text(self, text, normalize: bool = False): 83 | cast_dtype = self.clip_model.transformer.get_cast_dtype() 84 | 85 | x = self.clip_model.token_embedding(text).to( 86 | cast_dtype 87 | ) # [batch_size, n_ctx, d_model] 88 | 89 | x = x + self.clip_model.positional_embedding.to(cast_dtype) 90 | x = x.permute(1, 0, 2) # NLD -> LND 91 | x = self.clip_model.transformer(x, attn_mask=self.clip_model.attn_mask) 92 | x = x.permute(1, 0, 2) # LND -> NLD 93 | x = self.clip_model.ln_final(x) # [batch_size, n_ctx, transformer.width] 94 | # take features from the eot embedding (eot_token is the highest number in each sequence) 95 | x = ( 96 | x[torch.arange(x.shape[0]), text.argmax(dim=-1)] 97 | @ self.clip_model.text_projection 98 | ) 99 | return F.normalize(x, dim=-1) if normalize else x 100 | 101 | def tokenize_text(self, text): 102 | return self.text_tokenizer(text) 103 | 104 | def get_text_projection(self, x): 105 | return x @ self.clip_model.text_projection 106 | 107 | def extract_features(self, x): 108 | out = {} 109 | x = self.clip_model.visual.trunk.stem(x) 110 | out["stem"] = x.contiguous() # os4 111 | for i in range(4): 112 | x = self.clip_model.visual.trunk.stages[i](x) 113 | out[f"res{i+2}"] = ( 114 | x.contiguous() 115 | ) # res 2 (os4), 3 (os8), 4 (os16), 5 (os32) 116 | 117 | x = self.clip_model.visual.trunk.norm_pre(x) 118 | out["clip_vis_dense"] = 
x.contiguous() 119 | return out 120 | 121 | def visual_prediction_forward(self, x, masks=None): 122 | batch, num_query, channel = x.shape 123 | x = x.reshape(batch * num_query, channel, 1, 1) # fake 2D input 124 | x = self.clip_model.visual.trunk.head(x) 125 | x = self.clip_model.visual.head(x) 126 | return x.view(batch, num_query, x.shape[-1]) # B x num_queries x 640 127 | 128 | @torch.no_grad() 129 | def get_text_classifier(self, text_list, device): 130 | self.eval() 131 | # reference for templates: https://github.com/mlfoundations/open_clip/blob/91f6cce16b7bee90b3b5d38ca305b5b3b67cc200/src/training/imagenet_zeroshot_data.py 132 | text_tokens = self.tokenize_text(text_list) 133 | text_tokens = text_tokens.to(device) 134 | # we return un-normalized text feature. 135 | text_features = self.encode_text(text_tokens, normalize=False) 136 | return text_features 137 | 138 | @torch.no_grad() 139 | def forward(self, x): 140 | self.eval() 141 | return self.extract_features(x) 142 | 143 | @property 144 | def dim_latent(self): 145 | return self.clip_model.text_projection.shape[-1] 146 | 147 | def output_shape(self): 148 | return { 149 | name: ShapeSpec( 150 | channels=self._out_feature_channels[name], 151 | stride=self._out_feature_strides[name], 152 | ) 153 | for name in ["stem", "res2", "res3", "res4", "res5", "clip_embedding"] 154 | } 155 | 156 | @property 157 | def size_divisibility(self): 158 | return -1 159 | -------------------------------------------------------------------------------- /lib/models/framework/__init__.py: -------------------------------------------------------------------------------- 1 | from .uni_ovseg import UniOVSeg_S1 2 | -------------------------------------------------------------------------------- /lib/models/framework/uni_ovseg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | from detectron2.config import configurable 6 | from detectron2.modeling import META_ARCH_REGISTRY 7 | from detectron2.modeling.backbone import Backbone 8 | from detectron2.structures import ImageList 9 | from detectron2.utils.memory import retry_if_cuda_oom 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from ..backbone.clip import build_main_backbone 14 | from ...utils import ( 15 | mask_nms, 16 | sem_seg_postprocess, 17 | ) 18 | from ..pixel_decoder import build_pixel_decoder 19 | from ..transformer_decoder import build_transformer_decoder 20 | from ..utils import MaskPooling 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | @META_ARCH_REGISTRY.register() 26 | class UniOVSeg_S1(nn.Module): 27 | @configurable 28 | def __init__( 29 | self, 30 | *, 31 | backbone: Backbone, 32 | pixel_decoder: nn.Module, 33 | mask_decoder: nn.Module, 34 | size_divisibility: int, 35 | pixel_mean: Tuple[float], 36 | pixel_std: Tuple[float], 37 | pts_per_side_test: int, 38 | input_size: Tuple[int], 39 | autolabel_type: str = "panoptic-point", 40 | sem_seg_postprocess_before_inference: bool, 41 | ): 42 | super().__init__() 43 | # architecture 44 | self.backbone = backbone 45 | self.sem_seg_head = nn.ModuleDict( 46 | { 47 | "pixel_decoder": pixel_decoder, 48 | "predictor": mask_decoder, 49 | } 50 | ) 51 | self.mask_pooling = MaskPooling() 52 | 53 | # utils 54 | if size_divisibility < 0: 55 | # use backbone size_divisibility if not set 56 | size_divisibility = self.backbone.size_divisibility 57 | self.size_divisibility = size_divisibility 58 | 
self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference 59 | 60 | # test settings 61 | autolabel_type = autolabel_type.split("-") 62 | self.autolabel_type = autolabel_type[0] 63 | self.prompt_type = autolabel_type[1] 64 | 65 | # image statistics 66 | self.register_buffer( 67 | name="pixel_mean", 68 | tensor=torch.Tensor(pixel_mean).view(-1, 1, 1), 69 | persistent=False, 70 | ) 71 | self.register_buffer( 72 | name="pixel_std", 73 | tensor=torch.Tensor(pixel_std).view(-1, 1, 1), 74 | persistent=False, 75 | ) 76 | 77 | # point prompts 78 | self.pts_per_side_test = pts_per_side_test 79 | self.input_size = input_size 80 | 81 | @classmethod 82 | def from_config(cls, cfg): 83 | backbone = build_main_backbone(cfg) 84 | pixel_decoder = build_pixel_decoder(cfg, backbone.output_shape()) 85 | mask_decoder = build_transformer_decoder(cfg) 86 | return { 87 | "backbone": backbone, 88 | "pixel_decoder": pixel_decoder, 89 | "mask_decoder": mask_decoder, 90 | "size_divisibility": cfg.MODEL.OVSEG.SIZE_DIVISIBILITY, 91 | "pixel_mean": cfg.MODEL.PIXEL_MEAN, 92 | "pixel_std": cfg.MODEL.PIXEL_STD, 93 | # inference 94 | "pts_per_side_test": cfg.MODEL.OVSEG.TEST.PTS_PER_SIDE, 95 | "input_size": (cfg.INPUT.CROP_SIZE, cfg.INPUT.CROP_SIZE), 96 | "autolabel_type": cfg.MODEL.OVSEG.TEST.AUTOLABEL_TYPE, 97 | "sem_seg_postprocess_before_inference": cfg.MODEL.OVSEG.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE, 98 | } 99 | 100 | @property 101 | def device(self): 102 | return self.pixel_mean.device 103 | 104 | def preprosses_image(self, batched_inputs): 105 | images = [x["image"].to(self.device) for x in batched_inputs] 106 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 107 | images = ImageList.from_tensors(images, self.size_divisibility) 108 | sizes = torch.tensor( 109 | [ 110 | [ 111 | x.get("width", self.input_size[0]) - 1, 112 | x.get("height", self.input_size[1]) - 1, 113 | ] 114 | for x in batched_inputs 115 | ], 116 | device=self.device, 117 | ) 118 | aspect_ratio = (self.input_size[0] - 1.0) / sizes.amax(dim=-1) 119 | aspect_ratio = aspect_ratio.unsqueeze(1) 120 | scaled_sizes = sizes * aspect_ratio 121 | sizes = sizes + 1 122 | return images, sizes, scaled_sizes 123 | 124 | @staticmethod 125 | def prepare_points( 126 | pts_per_side: Tuple[int, int], scaled_size: torch.Tensor, device: torch.device 127 | ): 128 | pts_side_x, pts_side_y = pts_per_side 129 | offset_x = 1 / (2 * pts_side_x) 130 | offset_y = 1 / (2 * pts_side_y) 131 | pts_x_side = torch.linspace(offset_x, 1 - offset_x, pts_side_x) 132 | pts_y_side = torch.linspace(offset_y, 1 - offset_y, pts_side_y) 133 | pts_x, pts_y = torch.meshgrid(pts_x_side, pts_y_side, indexing="xy") 134 | pts_grid = torch.stack([pts_x, pts_y], dim=-1).reshape(-1, 2) 135 | pts_grid = pts_grid.to(device) 136 | # scale to image size 137 | pts_grid = pts_grid.unsqueeze(0) * scaled_size.unsqueeze(1) 138 | pts_grid = torch.cat([pts_grid - 2.0, pts_grid + 2.0], dim=-1) 139 | return pts_grid 140 | 141 | def pts_test_forward( 142 | self, 143 | features: Dict, 144 | scaled_sizes: torch.Tensor, 145 | ): 146 | 147 | ( 148 | mask_features, 149 | transformer_encoder_features, 150 | multi_scale_features, 151 | ) = self.sem_seg_head["pixel_decoder"].forward_features(features) 152 | 153 | # slice forward for memory-efficient inference 154 | points = self.prepare_points(self.pts_per_side_test, scaled_sizes, self.device) 155 | points = torch.split(points.squeeze(0), 100, dim=0) 156 | mask_pred_results, iou_pred_results = [], [] 157 | for point in points: 158 | 
output = self.sem_seg_head["predictor"]( 159 | multi_scale_features, 160 | mask_features, 161 | points=[point], 162 | boxes=None, 163 | points_multi=None, 164 | ) 165 | mask_pred_results.append(output["pred_masks"].flatten(1, 2)) 166 | iou_pred_results.append(output["pred_ious"].flatten(1, 2)) 167 | del points, point, output 168 | mask_pred_results = torch.cat(mask_pred_results, dim=1) 169 | iou_pred_results = torch.cat(iou_pred_results, dim=1) 170 | 171 | return mask_pred_results, iou_pred_results 172 | 173 | def box_test_forward( 174 | self, 175 | batched_inputs: List, 176 | features: Dict, 177 | ): 178 | boxes = [x["instances"].gt_boxes.to(self.device) for x in batched_inputs] 179 | ( 180 | mask_features, 181 | transformer_encoder_features, 182 | multi_scale_features, 183 | ) = self.sem_seg_head["pixel_decoder"].forward_features(features) 184 | outputs = self.sem_seg_head["predictor"]( 185 | multi_scale_features, 186 | mask_features, 187 | points=None, 188 | boxes=boxes, 189 | points_multi=None, 190 | ) 191 | mask_pred_results = outputs["pred_masks"].flatten(1, 2) # N, QK, H, W 192 | iou_pred_results = outputs["pred_ious"].flatten(1, 2) # N, QK 193 | 194 | return mask_pred_results, iou_pred_results 195 | 196 | def forward(self, batched_inputs): 197 | """ 198 | Args: 199 | batched_inputs: a list, batched outputs of :class:`DatasetMapper`. 200 | Each item in the list contains the inputs for one image. 201 | For now, each item in the list is a dict that contains: 202 | * "image": Tensor, image in (C, H, W) format. 203 | * "instances": per-region ground truth 204 | * Other information that's included in the original dicts, such as: 205 | "height", "width" (int): the output resolution of the model (may be different 206 | from input resolution), used in inference. 207 | Returns: 208 | list[dict]: 209 | each dict has the results for one image. The dict contains the following keys: 210 | 211 | * "sem_seg": 212 | A Tensor that represents the 213 | per-pixel segmentation prediced by the head. 214 | The prediction has shape KxHxW that represents the logits of 215 | each class for each pixel. 216 | * "panoptic_seg": 217 | A tuple that represent panoptic output 218 | panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. 219 | segments_info (list[dict]): Describe each segment in `panoptic_seg`. 220 | Each dict contains keys "id", "category_id", "isthing". 
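                * "proposal":
                    In this stage-one, prompt-driven configuration the keys above are
                    not produced; each returned dict instead carries a "proposal"
                    tensor of class-agnostic mask logits (the prompt-generated masks
                    that survive the IoU and NMS filtering), resized to the requested
                    output resolution. A rough usage sketch (variable names are
                    illustrative):
                        >>> outputs = model([{"image": img, "height": h, "width": w}])
                        >>> binary_masks = outputs[0]["proposal"].sigmoid() > 0.5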
221 | """ 222 | # prepare image 223 | images, output_sizes, scaled_sizes = self.preprosses_image(batched_inputs) 224 | # feature extraction via backbone 225 | features = self.backbone(images.tensor) 226 | # mask prediction 227 | if self.prompt_type == "point": 228 | mask_pred_results, iou_pred_results = self.pts_test_forward( 229 | features, scaled_sizes 230 | ) 231 | elif self.prompt_type == "box": 232 | mask_pred_results, iou_pred_results = self.box_test_forward( 233 | batched_inputs, features 234 | ) 235 | else: 236 | raise NotImplementedError( 237 | f"Visual prompt type {self.prompt_type} is not supported" 238 | ) 239 | # post-process 240 | iou_pred_results = iou_pred_results.sigmoid() 241 | # upsample masks 242 | mask_pred_results = F.interpolate( 243 | mask_pred_results, 244 | size=self.input_size, 245 | mode="bilinear", 246 | align_corners=False, 247 | ) 248 | processed_results = [] 249 | for mask_pred_result, iou_pred_result, output_size in zip( 250 | mask_pred_results, iou_pred_results, output_sizes 251 | ): 252 | processed_results.append({}) 253 | # drop low iou predictions 254 | keep = iou_pred_result.ge(0.65) 255 | iou_pred_result = iou_pred_result[keep] 256 | mask_pred_result = mask_pred_result[keep] 257 | # drop redundant mask via mask nms 258 | keep = mask_nms( 259 | masks=mask_pred_result, 260 | scores=iou_pred_result, 261 | iou_threshold=0.5, 262 | inner_threshold=0.7, 263 | nms_type="inner-nms", 264 | downsample=0.5, 265 | ) 266 | iou_pred_result = iou_pred_result[keep] 267 | mask_pred_result = mask_pred_result[keep] 268 | # semseg post-processing 269 | mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( 270 | mask_pred_result, self.input_size, output_size 271 | ) 272 | processed_results[-1]["proposal"] = mask_pred_result 273 | torch.cuda.empty_cache() 274 | return processed_results 275 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .msdeformattn import build_pixel_decoder -------------------------------------------------------------------------------- /lib/models/pixel_decoder/msdeformattn.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Callable, Dict, List, Optional, Tuple, Union 3 | 4 | import fvcore.nn.weight_init as weight_init 5 | import numpy as np 6 | import torch 7 | from detectron2.config import configurable 8 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 9 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import functional as F 13 | from torch.nn.init import normal_ 14 | 15 | from ..transformer_decoder.position_encoding import ( 16 | PositionEmbeddingRandom, 17 | PositionEmbeddingSine, 18 | ) 19 | from .ops.modules import MSDeformAttn 20 | from ..utils import LayerNorm2d 21 | 22 | 23 | def _get_clones(module, N): 24 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 25 | 26 | 27 | def _get_activation_fn(activation): 28 | """Return an activation function given a string""" 29 | if activation == "relu": 30 | return F.relu 31 | if activation == "gelu": 32 | return F.gelu 33 | if activation == "glu": 34 | return F.glu 35 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 36 | 37 | 38 | def build_pixel_decoder(cfg, input_shape): 39 | """ 40 | Build a pixel decoder from 
`cfg.MODEL.ONE_FORMER.PIXEL_DECODER_NAME`. 41 | """ 42 | name = cfg.MODEL.OVSEG.PIXEL_DECODER_NAME 43 | model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape) 44 | forward_features = getattr(model, "forward_features", None) 45 | if not callable(forward_features): 46 | raise ValueError( 47 | "Only model with forward_features method can be used as pixel decoder. " 48 | f"Please implement forward_features for {name} to only return mask features." 49 | ) 50 | return model 51 | 52 | 53 | # MSDeformAttn Transformer encoder in deformable detr 54 | class MSDeformAttnTransformerEncoderOnly(nn.Module): 55 | def __init__( 56 | self, 57 | d_model=256, 58 | nhead=8, 59 | num_encoder_layers=6, 60 | dim_feedforward=1024, 61 | dropout=0.1, 62 | activation="relu", 63 | num_feature_levels=4, 64 | enc_n_points=4, 65 | ): 66 | super().__init__() 67 | 68 | self.d_model = d_model 69 | self.nhead = nhead 70 | 71 | encoder_layer = MSDeformAttnTransformerEncoderLayer( 72 | d_model, 73 | dim_feedforward, 74 | dropout, 75 | activation, 76 | num_feature_levels, 77 | nhead, 78 | enc_n_points, 79 | ) 80 | self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers) 81 | 82 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) 83 | 84 | self._reset_parameters() 85 | 86 | def _reset_parameters(self): 87 | for p in self.parameters(): 88 | if p.dim() > 1: 89 | nn.init.xavier_uniform_(p) 90 | for m in self.modules(): 91 | if isinstance(m, MSDeformAttn): 92 | # if isinstance(m, FlashDeformAttn): 93 | m._reset_parameters() 94 | normal_(self.level_embed) 95 | 96 | def get_valid_ratio(self, mask): 97 | _, H, W = mask.shape 98 | valid_H = torch.sum(~mask[:, :, 0], 1) 99 | valid_W = torch.sum(~mask[:, 0, :], 1) 100 | valid_ratio_h = valid_H.float() / H 101 | valid_ratio_w = valid_W.float() / W 102 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 103 | return valid_ratio 104 | 105 | def forward(self, srcs, pos_embeds): 106 | masks = [ 107 | torch.zeros( 108 | (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool 109 | ) 110 | for x in srcs 111 | ] 112 | # prepare input for encoder 113 | src_flatten = [] 114 | mask_flatten = [] 115 | lvl_pos_embed_flatten = [] 116 | spatial_shapes = [] 117 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 118 | bs, c, h, w = src.shape 119 | spatial_shape = (h, w) 120 | spatial_shapes.append(spatial_shape) 121 | src = src.flatten(2).transpose(1, 2) 122 | mask = mask.flatten(1) 123 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 124 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 125 | lvl_pos_embed_flatten.append(lvl_pos_embed) 126 | src_flatten.append(src) 127 | mask_flatten.append(mask) 128 | src_flatten = torch.cat(src_flatten, 1) 129 | mask_flatten = torch.cat(mask_flatten, 1) 130 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 131 | spatial_shapes = torch.as_tensor( 132 | spatial_shapes, dtype=torch.long, device=src_flatten.device 133 | ) 134 | level_start_index = torch.cat( 135 | (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) 136 | ) 137 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 138 | 139 | # encoder 140 | memory = self.encoder( 141 | src_flatten, 142 | spatial_shapes, 143 | level_start_index, 144 | valid_ratios, 145 | lvl_pos_embed_flatten, 146 | mask_flatten, 147 | ) 148 | 149 | return memory, spatial_shapes, level_start_index 150 | 151 | 152 | class MSDeformAttnTransformerEncoderLayer(nn.Module): 
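"""
A single encoder layer of ``MSDeformAttnTransformerEncoderOnly``: multi-scale deformable
self-attention over the flattened feature levels, followed by a two-layer feed-forward
network, each sub-block wrapped with dropout, a residual connection, and post-LayerNorm.
"""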
153 | def __init__( 154 | self, 155 | d_model=256, 156 | d_ffn=1024, 157 | dropout=0.1, 158 | activation="relu", 159 | n_levels=4, 160 | n_heads=8, 161 | n_points=4, 162 | ): 163 | super().__init__() 164 | 165 | # self attention 166 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 167 | self.dropout1 = nn.Dropout(dropout) 168 | self.norm1 = nn.LayerNorm(d_model) 169 | 170 | # ffn 171 | self.linear1 = nn.Linear(d_model, d_ffn) 172 | self.activation = _get_activation_fn(activation) 173 | self.dropout2 = nn.Dropout(dropout) 174 | self.linear2 = nn.Linear(d_ffn, d_model) 175 | self.dropout3 = nn.Dropout(dropout) 176 | self.norm2 = nn.LayerNorm(d_model) 177 | 178 | @staticmethod 179 | def with_pos_embed(tensor, pos): 180 | return tensor if pos is None else tensor + pos 181 | 182 | def forward_ffn(self, src): 183 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 184 | src = src + self.dropout3(src2) 185 | src = self.norm2(src) 186 | return src 187 | 188 | def forward( 189 | self, 190 | src, 191 | pos, 192 | reference_points, 193 | spatial_shapes, 194 | level_start_index, 195 | padding_mask=None, 196 | ): 197 | # self attention 198 | src2 = self.self_attn( 199 | self.with_pos_embed(src, pos), 200 | reference_points, 201 | src, 202 | spatial_shapes, 203 | level_start_index, 204 | padding_mask, 205 | ) 206 | src = src + self.dropout1(src2) 207 | src = self.norm1(src) 208 | 209 | # ffn 210 | src = self.forward_ffn(src) 211 | 212 | return src 213 | 214 | 215 | class MSDeformAttnTransformerEncoder(nn.Module): 216 | def __init__(self, encoder_layer, num_layers): 217 | super().__init__() 218 | self.layers = _get_clones(encoder_layer, num_layers) 219 | self.num_layers = num_layers 220 | 221 | @staticmethod 222 | def get_reference_points(spatial_shapes, valid_ratios, device): 223 | reference_points_list = [] 224 | for lvl, (H_, W_) in enumerate(spatial_shapes): 225 | 226 | ref_y, ref_x = torch.meshgrid( 227 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 228 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), 229 | ) 230 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) 231 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) 232 | ref = torch.stack((ref_x, ref_y), -1) 233 | reference_points_list.append(ref) 234 | reference_points = torch.cat(reference_points_list, 1) 235 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 236 | return reference_points 237 | 238 | def forward( 239 | self, 240 | src, 241 | spatial_shapes, 242 | level_start_index, 243 | valid_ratios, 244 | pos=None, 245 | padding_mask=None, 246 | ): 247 | output = src 248 | reference_points = self.get_reference_points( 249 | spatial_shapes, valid_ratios, device=src.device 250 | ) 251 | for _, layer in enumerate(self.layers): 252 | output = layer( 253 | output, 254 | pos, 255 | reference_points, 256 | spatial_shapes, 257 | level_start_index, 258 | padding_mask, 259 | ) 260 | 261 | return output 262 | 263 | 264 | @SEM_SEG_HEADS_REGISTRY.register() 265 | class MSDeformAttnPixelDecoder(nn.Module): 266 | @configurable 267 | def __init__( 268 | self, 269 | input_shape: Dict[str, ShapeSpec], 270 | *, 271 | transformer_dropout: float, 272 | transformer_nheads: int, 273 | transformer_dim_feedforward: int, 274 | transformer_enc_layers: int, 275 | conv_dim: int, 276 | mask_dim: int, 277 | norm: Optional[Union[str, Callable]] = None, 278 | # deformable transformer encoder args 279 | 
transformer_in_features: List[str], 280 | common_stride: int, 281 | ): 282 | """ 283 | NOTE: this interface is experimental. 284 | Args: 285 | input_shape: shapes (channels and stride) of the input features 286 | transformer_dropout: dropout probability in transformer 287 | transformer_nheads: number of heads in transformer 288 | transformer_dim_feedforward: dimension of feedforward network 289 | transformer_enc_layers: number of transformer encoder layers 290 | conv_dims: number of output channels for the intermediate conv layers. 291 | mask_dim: number of output channels for the final conv layer. 292 | norm (str or callable): normalization for all conv layers 293 | """ 294 | super().__init__() 295 | transformer_input_shape = { 296 | k: v for k, v in input_shape.items() if k in transformer_in_features 297 | } 298 | 299 | # this is the input shape of pixel decoder 300 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 301 | self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" 302 | self.feature_strides = [v.stride for k, v in input_shape] 303 | self.feature_channels = [v.channels for k, v in input_shape] 304 | 305 | # this is the input shape of transformer encoder (could use less features than pixel decoder 306 | transformer_input_shape = sorted( 307 | transformer_input_shape.items(), key=lambda x: x[1].stride 308 | ) 309 | self.transformer_in_features = [ 310 | k for k, v in transformer_input_shape 311 | ] # starting from "res2" to "res5" 312 | transformer_in_channels = [v.channels for k, v in transformer_input_shape] 313 | self.transformer_feature_strides = [ 314 | v.stride for k, v in transformer_input_shape 315 | ] # to decide extra FPN layers 316 | 317 | self.transformer_num_feature_levels = len(self.transformer_in_features) 318 | if self.transformer_num_feature_levels > 1: 319 | input_proj_list = [] 320 | # from low resolution to high resolution (res5 -> res2) 321 | for in_channels in transformer_in_channels[::-1]: 322 | input_proj_list.append( 323 | nn.Sequential( 324 | nn.Conv2d(in_channels, conv_dim, kernel_size=1), 325 | # nn.GroupNorm(32, conv_dim), 326 | LayerNorm2d(conv_dim), 327 | ) 328 | ) 329 | self.input_proj = nn.ModuleList(input_proj_list) 330 | else: 331 | self.input_proj = nn.ModuleList( 332 | [ 333 | nn.Sequential( 334 | nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), 335 | # nn.GroupNorm(32, conv_dim), 336 | LayerNorm2d(conv_dim), 337 | ) 338 | ] 339 | ) 340 | 341 | for proj in self.input_proj: 342 | nn.init.xavier_uniform_(proj[0].weight, gain=1) 343 | nn.init.constant_(proj[0].bias, 0) 344 | 345 | self.transformer = MSDeformAttnTransformerEncoderOnly( 346 | d_model=conv_dim, 347 | dropout=transformer_dropout, 348 | nhead=transformer_nheads, 349 | dim_feedforward=transformer_dim_feedforward, 350 | num_encoder_layers=transformer_enc_layers, 351 | num_feature_levels=self.transformer_num_feature_levels, 352 | ) 353 | N_steps = conv_dim // 2 354 | # self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 355 | self.pe_layer = PositionEmbeddingRandom(N_steps) 356 | 357 | self.mask_dim = mask_dim 358 | # use 1x1 conv instead 359 | self.mask_features = Conv2d( 360 | conv_dim, 361 | mask_dim, 362 | kernel_size=1, 363 | stride=1, 364 | padding=0, 365 | ) 366 | weight_init.c2_xavier_fill(self.mask_features) 367 | 368 | self.maskformer_num_feature_levels = 3 # always use 3 scales 369 | self.common_stride = common_stride 370 | 371 | # extra fpn levels 372 | stride = min(self.transformer_feature_strides) 373 
| self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) 374 | 375 | lateral_convs = [] 376 | output_convs = [] 377 | 378 | use_bias = norm == "" 379 | for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]): 380 | lateral_norm = get_norm(norm, conv_dim) 381 | output_norm = get_norm(norm, conv_dim) 382 | 383 | lateral_conv = Conv2d( 384 | in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm 385 | ) 386 | output_conv = Conv2d( 387 | conv_dim, 388 | conv_dim, 389 | kernel_size=3, 390 | stride=1, 391 | padding=1, 392 | bias=use_bias, 393 | norm=output_norm, 394 | activation=F.relu, 395 | ) 396 | weight_init.c2_xavier_fill(lateral_conv) 397 | weight_init.c2_xavier_fill(output_conv) 398 | self.add_module("adapter_{}".format(idx + 1), lateral_conv) 399 | self.add_module("layer_{}".format(idx + 1), output_conv) 400 | 401 | lateral_convs.append(lateral_conv) 402 | output_convs.append(output_conv) 403 | # Place convs into top-down order (from low to high resolution) 404 | # to make the top-down computation in forward clearer. 405 | self.lateral_convs = lateral_convs[::-1] 406 | self.output_convs = output_convs[::-1] 407 | 408 | def freeze_everything(self): 409 | self.eval() 410 | for param in self.parameters(): 411 | param.requires_grad = False 412 | 413 | @classmethod 414 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 415 | ret = {} 416 | ret["input_shape"] = { 417 | k: v for k, v in input_shape.items() if k in cfg.MODEL.OVSEG.IN_FEATURES 418 | } 419 | ret["conv_dim"] = cfg.MODEL.OVSEG.CONVS_DIM 420 | ret["mask_dim"] = cfg.MODEL.OVSEG.MASK_DIM 421 | ret["norm"] = cfg.MODEL.OVSEG.NORM 422 | ret["transformer_dropout"] = cfg.MODEL.OVSEG.DROPOUT 423 | ret["transformer_nheads"] = cfg.MODEL.OVSEG.NHEADS 424 | ret["transformer_dim_feedforward"] = ( 425 | 1024 # use 1024 for deformable transformer encoder 426 | ) 427 | ret["transformer_enc_layers"] = ( 428 | cfg.MODEL.OVSEG.TRANSFORMER_ENC_LAYERS 429 | ) # a separate config 430 | ret["transformer_in_features"] = ( 431 | cfg.MODEL.OVSEG.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES 432 | ) 433 | ret["common_stride"] = cfg.MODEL.OVSEG.COMMON_STRIDE 434 | return ret 435 | 436 | @autocast(enabled=False) 437 | def forward_features(self, features): 438 | srcs = [] 439 | pos = [] 440 | # Reverse feature maps into top-down order (from low to high resolution) 441 | for idx, f in enumerate(self.transformer_in_features[::-1]): 442 | x = features[f].float() # deformable detr does not support half precision 443 | srcs.append(self.input_proj[idx](x)) 444 | pos.append(self.pe_layer(x)) 445 | 446 | y, spatial_shapes, level_start_index = self.transformer(srcs, pos) 447 | bs = y.shape[0] 448 | 449 | split_size_or_sections = [None] * self.transformer_num_feature_levels 450 | for i in range(self.transformer_num_feature_levels): 451 | if i < self.transformer_num_feature_levels - 1: 452 | split_size_or_sections[i] = ( 453 | level_start_index[i + 1] - level_start_index[i] 454 | ) 455 | else: 456 | split_size_or_sections[i] = y.shape[1] - level_start_index[i] 457 | y = torch.split(y, split_size_or_sections, dim=1) 458 | 459 | out = [] 460 | multi_scale_features = [] 461 | num_cur_levels = 0 462 | for i, z in enumerate(y): 463 | out.append( 464 | z.transpose(1, 2).view( 465 | bs, -1, spatial_shapes[i][0], spatial_shapes[i][1] 466 | ) 467 | ) 468 | 469 | # append `out` with extra FPN levels 470 | # Reverse feature maps into top-down order (from low to high resolution) 471 | for idx, f in 
enumerate(self.in_features[: self.num_fpn_levels][::-1]): 472 | x = features[f].float() 473 | lateral_conv = self.lateral_convs[idx] 474 | output_conv = self.output_convs[idx] 475 | cur_fpn = lateral_conv(x) 476 | # Following FPN implementation, we use nearest upsampling here 477 | y = cur_fpn + F.interpolate( 478 | out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False 479 | ) 480 | y = output_conv(y) 481 | out.append(y) 482 | 483 | for o in out: 484 | if num_cur_levels < self.maskformer_num_feature_levels: 485 | multi_scale_features.append(o) 486 | num_cur_levels += 1 487 | 488 | return self.mask_features(out[-1]), out[0], multi_scale_features 489 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd fcclip/modeling/pixel_decoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python3 setup.py build install 14 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import warnings 17 | import math 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.nn.init import xavier_uniform_, constant_ 23 | 24 | from ..functions import MSDeformAttnFunction 25 | from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch 26 | 27 | 28 | def _is_power_of_2(n): 29 | if (not isinstance(n, int)) or (n < 0): 30 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 31 | return (n & (n-1) == 0) and n != 0 32 | 33 | 34 | class MSDeformAttn(nn.Module): 35 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 36 | """ 37 | Multi-Scale Deformable Attention Module 38 | :param d_model hidden dimension 39 | :param n_levels number of feature levels 40 | :param n_heads number of attention heads 41 | :param n_points number of sampling points per attention head per feature level 42 | """ 43 | super().__init__() 44 | if d_model % n_heads != 0: 45 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 46 | _d_per_head = d_model // n_heads 47 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 48 | if not _is_power_of_2(_d_per_head): 49 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 50 | "which is more efficient in our CUDA implementation.") 51 | 52 | self.im2col_step = 128 53 | 54 | self.d_model = d_model 55 | self.n_levels = n_levels 56 | self.n_heads = n_heads 57 | self.n_points = n_points 58 | 59 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 60 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 61 | self.value_proj = nn.Linear(d_model, d_model) 62 | self.output_proj = nn.Linear(d_model, d_model) 63 | 64 | self._reset_parameters() 65 | 66 | def _reset_parameters(self): 67 | constant_(self.sampling_offsets.weight.data, 0.) 68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 71 | for i in range(self.n_points): 72 | grid_init[:, :, i, :] *= i + 1 73 | with torch.no_grad(): 74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 75 | constant_(self.attention_weights.weight.data, 0.) 76 | constant_(self.attention_weights.bias.data, 0.) 77 | xavier_uniform_(self.value_proj.weight.data) 78 | constant_(self.value_proj.bias.data, 0.) 79 | xavier_uniform_(self.output_proj.weight.data) 80 | constant_(self.output_proj.bias.data, 0.) 
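# Note on ``_reset_parameters`` above: ``sampling_offsets`` starts from zero weights and a
# structured bias so that, at initialization, each attention head samples along a distinct,
# evenly spaced direction, with the k-th sampling point pushed out to a larger radius
# (scaled by k + 1); the attention-weight logits start at zero, i.e. uniform attention.
# This follows the original Deformable DETR initialization.
#
# Rough shape contract for ``forward`` below (a sketch that restates its docstring):
#   query:                   (N, Len_q, d_model)
#   reference_points:        (N, Len_q, n_levels, 2) in [0, 1] (or 4 with box sizes)
#   input_flatten:           (N, sum_l H_l * W_l, d_model)
#   input_spatial_shapes:    (n_levels, 2)
#   input_level_start_index: (n_levels,)
#   returned output:         (N, Len_q, d_model)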
81 | 82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 83 | """ 84 | :param query (N, Length_{query}, C) 85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 91 | 92 | :return output (N, Length_{query}, C) 93 | """ 94 | N, Len_q, _ = query.shape 95 | N, Len_in, _ = input_flatten.shape 96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 97 | 98 | value = self.value_proj(input_flatten) 99 | if input_padding_mask is not None: 100 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 104 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 105 | # N, Len_q, n_heads, n_levels, n_points, 2 106 | if reference_points.shape[-1] == 2: 107 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 108 | sampling_locations = reference_points[:, :, None, :, None, :] \ 109 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 110 | elif reference_points.shape[-1] == 4: 111 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 112 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 113 | else: 114 | raise ValueError( 115 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 116 | try: 117 | output = MSDeformAttnFunction.apply( 118 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 119 | except: 120 | # CPU 121 | output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) 122 | # # For FLOPs calculation only 123 | # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) 124 | output = self.output_proj(output) 125 | return output 126 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | #include "cuda/ms_deform_im2col_cuda.cuh" 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | at::Tensor ms_deform_attn_cuda_forward( 26 | const at::Tensor &value, 27 | const at::Tensor &spatial_shapes, 28 | const at::Tensor &level_start_index, 29 | const at::Tensor &sampling_loc, 30 | const at::Tensor &attn_weight, 31 | const int im2col_step) 32 | { 33 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 34 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 35 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 36 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 37 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 38 | 39 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 40 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 41 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 42 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 43 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 44 | 45 | const int batch = value.size(0); 46 | const int spatial_size = value.size(1); 47 | const int num_heads = value.size(2); 48 | const int channels = value.size(3); 49 | 50 | const int num_levels = spatial_shapes.size(0); 51 | 52 | const int num_query = sampling_loc.size(1); 53 | const int num_point = sampling_loc.size(4); 54 | 55 | const int im2col_step_ = std::min(batch, im2col_step); 56 | 57 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 58 | 59 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 60 | 61 | const int batch_n = im2col_step_; 62 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 63 | auto per_value_size = spatial_size * num_heads * channels; 64 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 65 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 66 | for (int n = 0; n < batch/im2col_step_; ++n) 67 | { 68 | auto columns = output_n.select(0, n); 69 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 70 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 71 | value.data() + n * im2col_step_ * per_value_size, 72 | spatial_shapes.data(), 73 | level_start_index.data(), 74 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 75 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 76 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 77 | columns.data()); 78 | 79 | })); 80 | } 81 | 82 | output = output.view({batch, num_query, num_heads*channels}); 83 | 84 | return output; 85 | } 86 | 87 
| 88 | std::vector ms_deform_attn_cuda_backward( 89 | const at::Tensor &value, 90 | const at::Tensor &spatial_shapes, 91 | const at::Tensor &level_start_index, 92 | const at::Tensor &sampling_loc, 93 | const at::Tensor &attn_weight, 94 | const at::Tensor &grad_output, 95 | const int im2col_step) 96 | { 97 | 98 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 99 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 100 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 101 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 102 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 103 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 104 | 105 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 106 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 107 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 108 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 109 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 110 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 111 | 112 | const int batch = value.size(0); 113 | const int spatial_size = value.size(1); 114 | const int num_heads = value.size(2); 115 | const int channels = value.size(3); 116 | 117 | const int num_levels = spatial_shapes.size(0); 118 | 119 | const int num_query = sampling_loc.size(1); 120 | const int num_point = sampling_loc.size(4); 121 | 122 | const int im2col_step_ = std::min(batch, im2col_step); 123 | 124 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 125 | 126 | auto grad_value = at::zeros_like(value); 127 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 128 | auto grad_attn_weight = at::zeros_like(attn_weight); 129 | 130 | const int batch_n = im2col_step_; 131 | auto per_value_size = spatial_size * num_heads * channels; 132 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 133 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 134 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 135 | 136 | for (int n = 0; n < batch/im2col_step_; ++n) 137 | { 138 | auto grad_output_g = grad_output_n.select(0, n); 139 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 140 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 141 | grad_output_g.data(), 142 | value.data() + n * im2col_step_ * per_value_size, 143 | spatial_shapes.data(), 144 | level_start_index.data(), 145 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 146 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 147 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 148 | grad_value.data() + n * im2col_step_ * per_value_size, 149 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 150 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 151 | 152 | })); 153 | } 154 | 155 | return { 156 | grad_value, grad_sampling_loc, grad_attn_weight 157 | }; 158 | } -------------------------------------------------------------------------------- 
/lib/models/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /lib/models/pixel_decoder/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. 
and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 
| print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /lib/models/prompt_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import build_prompt_encoder -------------------------------------------------------------------------------- /lib/models/prompt_decoder/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | from detectron2.config import configurable 6 | from detectron2.utils.registry import Registry 7 | from ..transformer_decoder.position_encoding import ( 8 | PositionEmbeddingSine, 9 | PositionEmbeddingRandom, 10 | ) 11 | 12 | 13 | PROMPT_ENCODER_REGISTRY = Registry("PROMPT_ENCODER") 14 | PROMPT_ENCODER_REGISTRY.__doc__ = """ 15 | Registry for prompt encoder in Uni-OVSeg. 16 | """ 17 | 18 | 19 | def build_prompt_encoder(cfg): 20 | """ 21 | Build a prompt encoder from `cfg.MODEL.OVSEG.PROMPT_ENCODER_NAME`. 22 | """ 23 | name = cfg.MODEL.OVSEG.PROMPT_ENCODER_NAME 24 | model = PROMPT_ENCODER_REGISTRY.get(name)(cfg) 25 | return model 26 | 27 | 28 | @PROMPT_ENCODER_REGISTRY.register() 29 | class PromptEncoder(nn.Module): 30 | @configurable 31 | def __init__(self, embed_dim: int, image_size: List[int], num_masks: int = 4): 32 | super().__init__() 33 | self.embed_dim = embed_dim 34 | self.image_size = image_size 35 | # position embedding 36 | # self.pos_emb = PositionEmbeddingSine(embed_dim // 2, normalize=True) 37 | self.pos_emb = PositionEmbeddingRandom(embed_dim // 2) 38 | # corner embedding: left top, right bottom 39 | self.corner_emb = nn.Parameter(torch.randn(1, 2, embed_dim)) 40 | nn.init.normal_(self.corner_emb) 41 | # type embedding: point, box 42 | self.point_emb = nn.Parameter(torch.randn(1, 1, embed_dim)) 43 | nn.init.normal_(self.point_emb) 44 | self.box_emb = nn.Parameter(torch.randn(1, 1, embed_dim)) 45 | nn.init.normal_(self.box_emb) 46 | # attribute embedding: positive, negative 47 | self.attr_emb = nn.Embedding(2, embed_dim) 48 | nn.init.normal_(self.attr_emb.weight) 49 | # mask embedding: num_masks mask proposals 50 | self.mask_emb = nn.Parameter(torch.randn(1, num_masks + 1, embed_dim)) 51 | nn.init.normal_(self.mask_emb) 52 | 53 | def freeze_everything(self): 54 | self.eval() 55 | for param in self.parameters(): 56 | param.requires_grad = False 57 | 58 | @classmethod 59 | def from_config(cls, cfg): 60 | ret = {} 61 | ret["embed_dim"] = cfg.MODEL.OVSEG.EMBED_DIM 62 | ret["image_size"] = [cfg.INPUT.CROP_SIZE, cfg.INPUT.CROP_SIZE] 63 | ret["num_masks"] = cfg.MODEL.OVSEG.NUM_MASKS 64 | return ret 65 | 66 | def _embed_point(self, box: torch.Tensor, feat: torch.Tensor): 67 | N = len(box) 68 | if hasattr(box, "tensor"): 69 | box = box.tensor + 0.5 70 | box_embed = self.pos_emb.forward_with_coords( 71 | box.reshape(N, 2, 2), self.image_size 72 | ) 73 | attr_embed = self.attr_emb(torch.ones_like(box_embed[:, :, 0]).long()) 74 | corner_embed = self.corner_emb.clone() 75 | point_embed = self.point_emb.clone() 76 | content_embed = feat.unsqueeze(1) 77 | task_embed = box_embed + corner_embed + point_embed + content_embed + attr_embed 
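# A single point prompt is embedded through the same two-corner path as a box: the 4-d
# coordinate is viewed as two 2-d positions, and the prompt embedding sums the random
# positional encoding, the corner embedding, the point-type embedding, the per-prompt
# content feature (``feats_centers``), and the positive attribute embedding. The learned
# mask tokens appended below serve as output slots for the candidate mask predictions.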
78 | 79 | output_embed = self.mask_emb.repeat(N, 1, 1) # 1 x num_masks x C 80 | task_embed = torch.cat( 81 | [task_embed, output_embed], dim=1 82 | ) # N x (2 + num_masks) x C 83 | return task_embed, task_embed 84 | 85 | def _embed_point2( 86 | self, box: torch.Tensor, indicator: torch.Tensor, feat: torch.Tensor 87 | ): 88 | N, P, _ = box.shape 89 | box = box.reshape(N * P, 2, 2) 90 | box_embed = self.pos_emb.forward_with_coords(box, self.image_size) 91 | box_embed = box_embed.reshape(N, P, 2, -1) 92 | attr_embed = self.attr_emb(indicator.long()).unsqueeze(2) 93 | corner_embed = self.corner_emb.clone().unsqueeze(0) 94 | point_embed = self.point_emb.clone().unsqueeze(0) 95 | content_embed = feat.unsqueeze(2) 96 | task_embed = box_embed + corner_embed + point_embed + content_embed + attr_embed 97 | task_embed = task_embed.reshape(N, P * 2, -1) 98 | 99 | # output_embed = self.mask_emb_single.repeat(N, 1, 1) # N x 1 x C 100 | output_embed = self.mask_emb.repeat(N, 1, 1)[:, :1] # N x 1 x C 101 | task_embed = torch.cat([task_embed, output_embed], dim=1) # N x (2P + 1) x C 102 | return task_embed, task_embed 103 | 104 | def _embed_box(self, box: torch.Tensor, feat: torch.Tensor): 105 | N = len(box) 106 | if hasattr(box, "tensor"): 107 | box = box.tensor + 0.5 108 | box_embed = self.pos_emb.forward_with_coords( 109 | box.reshape(N, 2, 2), self.image_size 110 | ) 111 | corner_embed = self.corner_emb.clone() 112 | point_embed = self.box_emb.clone() 113 | content_embed = feat.unsqueeze(1) 114 | task_embed = box_embed + corner_embed + point_embed + content_embed 115 | 116 | output_embed = self.mask_emb.repeat(N, 1, 1) # 1 x num_masks x C 117 | task_embed = torch.cat( 118 | [task_embed, output_embed], dim=1 119 | ) # N x (2 + num_masks) x C 120 | return task_embed, task_embed 121 | 122 | def forward( 123 | self, 124 | points: List[torch.Tensor], 125 | boxes: List[torch.Tensor], 126 | points_multi: Tuple[List], 127 | feats_centers: torch.Tensor, 128 | ): 129 | """This is a forward function of embedding multi-type prompts 130 | 131 | Args: 132 | points (List[torch.Tensor]): A batch of point coordinates. Each one has a shape of [Q, 4]. 133 | feats_centers (torch.Tensor): A batch of center features. 
It has a shape of [B, Q, C] 134 | Return: 135 | List[torch.Tensor]: Prompt embedding has a shape of [B, Q, K, C] 136 | """ 137 | # embed input prompt into a embedding space 138 | task_emb, pos_emb = [], [] 139 | if points is not None: 140 | for pts, feat in zip(points, feats_centers): 141 | task, pos = self._embed_point(pts, feat) 142 | task_emb.append(task) 143 | pos_emb.append(pos) 144 | if boxes is not None: 145 | for pts, feat in zip(boxes, feats_centers): 146 | task, pos = self._embed_box(pts, feat) 147 | task_emb.append(task) 148 | pos_emb.append(pos) 149 | if points_multi is not None: 150 | for pts, ind, feat in zip(points_multi[0], points_multi[1], feats_centers): 151 | task, pos = self._embed_point2(pts, ind, feat) 152 | task_emb.append(task) 153 | pos_emb.append(pos) 154 | 155 | task_emb = torch.stack(task_emb, dim=0) # [B, Q, K, C] 156 | pos_emb = torch.stack(pos_emb, dim=0) # [B, Q, K, C] 157 | 158 | return task_emb, pos_emb 159 | 160 | 161 | @PROMPT_ENCODER_REGISTRY.register() 162 | class PromptEncoder2(nn.Module): 163 | @configurable 164 | def __init__(self, embed_dim: int, image_size: List[int], num_masks: int = 4): 165 | super().__init__() 166 | self.embed_dim = embed_dim 167 | self.image_size = image_size 168 | # position embedding 169 | self.pos_emb = PositionEmbeddingSine(embed_dim // 2, normalize=True) 170 | # corner embedding: left top, right bottom 171 | self.corner_emb = nn.Parameter(torch.randn(1, 2, embed_dim)) 172 | nn.init.normal_(self.corner_emb) 173 | # type embedding: point, box 174 | self.point_emb = nn.Parameter(torch.randn(1, 1, embed_dim)) 175 | nn.init.normal_(self.point_emb) 176 | self.box_emb = nn.Parameter(torch.randn(1, 1, embed_dim)) 177 | nn.init.normal_(self.box_emb) 178 | # attribute embedding: positive, negative 179 | self.attr_emb = nn.Embedding(2, embed_dim) 180 | nn.init.normal_(self.attr_emb.weight) 181 | # mask embedding: num_masks mask proposals 182 | self.mask_emb = nn.Parameter(torch.randn(1, num_masks + 2, embed_dim)) 183 | nn.init.normal_(self.mask_emb) 184 | 185 | def freeze_everything(self): 186 | self.eval() 187 | for param in self.parameters(): 188 | param.requires_grad = False 189 | 190 | @classmethod 191 | def from_config(cls, cfg): 192 | ret = {} 193 | ret["embed_dim"] = cfg.MODEL.OVSEG.EMBED_DIM 194 | ret["image_size"] = [cfg.INPUT.OVSEG.CROP_SIZE, cfg.INPUT.OVSEG.CROP_SIZE] 195 | ret["num_masks"] = cfg.MODEL.OVSEG.NUM_MASKS 196 | return ret 197 | 198 | def _embed_point(self, point: torch.Tensor, feat: torch.Tensor): 199 | N = len(point) 200 | if hasattr(point, "tensor"): 201 | point = point.tensor + 0.5 202 | point_embed = self.pos_emb.forward_with_coords( 203 | point.reshape(N, 1, 2), self.image_size 204 | ) 205 | attr_embed = self.attr_emb(torch.ones_like(point_embed[:, :, 0]).long()) 206 | type_embed = self.point_emb.clone() 207 | content_embed = feat.unsqueeze(1) 208 | task_embed = point_embed + type_embed + content_embed + attr_embed 209 | 210 | output_embed = self.mask_emb.repeat(N, 1, 1)[:, -2:] # 1 x 4 x C 211 | task_embed = torch.cat([task_embed, output_embed], dim=1) # N x (1 + 4) x C 212 | return task_embed, task_embed 213 | 214 | def _embed_multi_point( 215 | self, points: torch.Tensor, indicator: torch.Tensor, feat: torch.Tensor 216 | ): 217 | N, P, _ = points.shape 218 | points_embed = self.pos_emb.forward_with_coords(points, self.image_size) 219 | attr_embed = self.attr_emb(indicator.long()) 220 | type_embed = self.point_emb.clone() 221 | content_embed = feat.unsqueeze(1) 222 | task_embed = points_embed + 
type_embed + content_embed + attr_embed 223 | 224 | output_embed = self.mask_emb.repeat(N, 1, 1)[:, 1:2] # N x 1 x C 225 | task_embed = torch.cat([task_embed, output_embed], dim=1) # N x (P + 1) x C 226 | return task_embed, task_embed 227 | 228 | def _embed_box(self, box: torch.Tensor, feat: torch.Tensor): 229 | N = len(box) 230 | if hasattr(box, "tensor"): 231 | box = box.tensor + 0.5 232 | box_embed = self.pos_emb.forward_with_coords( 233 | box.reshape(N, 2, 2), self.image_size 234 | ) 235 | corner_embed = self.corner_emb.clone() 236 | point_embed = self.box_emb.clone() 237 | content_embed = feat.unsqueeze(1) 238 | task_embed = box_embed + corner_embed + point_embed + content_embed 239 | 240 | output_embed = self.mask_emb.repeat(N, 1, 1)[:, :1] # N x 1 x C 241 | task_embed = torch.cat([task_embed, output_embed], dim=1) # N x (2 + 1) x C 242 | return task_embed, task_embed 243 | 244 | def forward( 245 | self, 246 | points: List[torch.Tensor], 247 | boxes: List[torch.Tensor], 248 | points_multi: Tuple[List], 249 | feats_centers: torch.Tensor, 250 | ): 251 | """Forward function that embeds multi-type prompts (points, boxes, and multi-point sets). 252 | 253 | Args: 254 | points (List[torch.Tensor]): A batch of point coordinates. Each one has a shape of [Q, 4]. 255 | feats_centers (torch.Tensor): A batch of center features. It has a shape of [B, Q, C] 256 | Returns: 257 | List[torch.Tensor]: Prompt embedding has a shape of [B, Q, K, C] 258 | """ 259 | # embed input prompts into an embedding space 260 | task_emb, pos_emb = [], [] 261 | if points is not None: 262 | for pts, feat in zip(points, feats_centers): 263 | task, pos = self._embed_point(pts, feat) 264 | task_emb.append(task) 265 | pos_emb.append(pos) 266 | if boxes is not None: 267 | for pts, feat in zip(boxes, feats_centers): 268 | task, pos = self._embed_box(pts, feat) 269 | task_emb.append(task) 270 | pos_emb.append(pos) 271 | if points_multi is not None: 272 | for pts, ind, feat in zip(points_multi[0], points_multi[1], feats_centers): 273 | task, pos = self._embed_multi_point(pts, ind, feat) 274 | task_emb.append(task) 275 | pos_emb.append(pos) 276 | 277 | task_emb = torch.stack(task_emb, dim=0) # [B, Q, K, C] 278 | pos_emb = torch.stack(pos_emb, dim=0) # [B, Q, K, C] 279 | 280 | return task_emb, pos_emb 281 | -------------------------------------------------------------------------------- /lib/models/transformer_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .mask_decoder import build_transformer_decoder 2 | from ..utils import MaskPooling -------------------------------------------------------------------------------- /lib/models/transformer_decoder/mask_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from detectron2.config import configurable 5 | from detectron2.layers import Conv2d 6 | from detectron2.utils.registry import Registry 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from ...utils import point_sample 11 | from ..prompt_decoder import build_prompt_encoder 12 | from .position_encoding import PositionEmbeddingRandom 13 | from ..utils import ( 14 | MLP, 15 | CrossAttentionLayer, 16 | FFNLayer, 17 | SelfAttentionLayer, 18 | ) 19 | 20 | 21 | TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE") 22 | TRANSFORMER_DECODER_REGISTRY.__doc__ = """ 23 | Registry for transformer module in MaskFormer.
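Decoders are added with the ``@TRANSFORMER_DECODER_REGISTRY.register()`` decorator and are looked up by name in ``build_transformer_decoder`` via ``cfg.MODEL.OVSEG.TRANSFORMER_DECODER_NAME`` (see below).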
24 | """ 25 | 26 | 27 | def build_transformer_decoder(cfg): 28 | """ 29 | Build a instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`. 30 | """ 31 | name = cfg.MODEL.OVSEG.TRANSFORMER_DECODER_NAME 32 | in_channels = cfg.MODEL.OVSEG.CONVS_DIM 33 | return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels) 34 | 35 | 36 | @TRANSFORMER_DECODER_REGISTRY.register() 37 | class MultiScaleMaskDecoder(nn.Module): 38 | @configurable 39 | def __init__( 40 | self, 41 | in_channels, 42 | *, 43 | embed_dim: int, 44 | prompt_encoder: nn.Module, 45 | nheads: int, 46 | dim_feedforward: int, 47 | dec_layers: int, 48 | pre_norm: bool, 49 | mask_dim: int, 50 | enforce_input_project: bool, 51 | ): 52 | """ 53 | NOTE: this interface is experimental. 54 | Args: 55 | in_channels: channels of the input features 56 | mask_classification: whether to add mask classifier or not 57 | num_classes: number of classes 58 | hidden_dim: Transformer feature dimension 59 | num_queries: number of queries 60 | nheads: number of heads 61 | dim_feedforward: feature dimension in feedforward network 62 | enc_layers: number of Transformer encoder layers 63 | dec_layers: number of Transformer decoder layers 64 | pre_norm: whether to use pre-LayerNorm or not 65 | mask_dim: mask feature dimension 66 | enforce_input_project: add input project 1x1 conv even if input 67 | channels and hidden dim is identical 68 | """ 69 | super().__init__() 70 | # positional encoding 71 | N_steps = embed_dim // 2 72 | # self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 73 | self.pe_layer = PositionEmbeddingRandom(N_steps) 74 | 75 | # define Transformer decoder here 76 | self.mask_dim = mask_dim 77 | self.num_heads = nheads 78 | self.num_layers = dec_layers 79 | self.tgt_mask = None 80 | self.transformer_self_attention_layers = nn.ModuleList() 81 | self.transformer_cross_attention_layers = nn.ModuleList() 82 | self.transformer_ffn_layers = nn.ModuleList() 83 | for _ in range(self.num_layers): 84 | self.transformer_self_attention_layers.append( 85 | SelfAttentionLayer( 86 | d_model=embed_dim, 87 | nhead=nheads, 88 | dropout=0.0, 89 | normalize_before=pre_norm, 90 | ) 91 | ) 92 | 93 | self.transformer_cross_attention_layers.append( 94 | CrossAttentionLayer( 95 | d_model=embed_dim, 96 | nhead=nheads, 97 | dropout=0.0, 98 | normalize_before=pre_norm, 99 | ) 100 | ) 101 | 102 | self.transformer_ffn_layers.append( 103 | FFNLayer( 104 | d_model=embed_dim, 105 | dim_feedforward=dim_feedforward, 106 | dropout=0.0, 107 | normalize_before=pre_norm, 108 | ) 109 | ) 110 | 111 | # extra self-attention layer for learnable query features 112 | self.extra_self_attention_layer = SelfAttentionLayer( 113 | d_model=embed_dim, 114 | nhead=nheads, 115 | dropout=0.0, 116 | normalize_before=pre_norm, 117 | ) 118 | self.feat_embed = nn.Conv2d(embed_dim, mask_dim, 1, bias=True) 119 | 120 | # visual prompt encoder 121 | self.prompt_encoder = prompt_encoder 122 | 123 | # level embedding (we always use 3 scales) 124 | self.num_feature_levels = 3 125 | self.level_embed = nn.Embedding(self.num_feature_levels, embed_dim) 126 | self.input_proj = nn.ModuleList() 127 | for _ in range(self.num_feature_levels): 128 | if in_channels != embed_dim or enforce_input_project: 129 | self.input_proj.append(Conv2d(in_channels, embed_dim, kernel_size=1)) 130 | weight_init.c2_xavier_fill(self.input_proj[-1]) 131 | else: 132 | self.input_proj.append(nn.Sequential()) 133 | 134 | # mask branch 135 | self.decoder_norm = nn.LayerNorm(embed_dim) 136 | self.mask_embed = nn.ModuleList( 
137 | [MLP(embed_dim, embed_dim, mask_dim, 3) for _ in range(5)] 138 | ) 139 | self.iou_embed = nn.ModuleList( 140 | [MLP(embed_dim, embed_dim, 1, 3) for _ in range(5)] 141 | ) 142 | 143 | @classmethod 144 | def from_config(cls, cfg, in_channels): 145 | ret = {} 146 | ret["in_channels"] = in_channels 147 | ret["embed_dim"] = cfg.MODEL.OVSEG.EMBED_DIM 148 | # Transformer parameters: 149 | ret["nheads"] = cfg.MODEL.OVSEG.NHEADS 150 | ret["dim_feedforward"] = cfg.MODEL.OVSEG.DIM_FEEDFORWARD 151 | 152 | # NOTE: because we add learnable query features which requires supervision, 153 | # we add minus 1 to decoder layers to be consistent with our loss 154 | # implementation: that is, number of auxiliary losses is always 155 | # equal to number of decoder layers. With learnable query features, the number of 156 | # auxiliary losses equals number of decoders plus 1. 157 | assert cfg.MODEL.OVSEG.DEC_LAYERS >= 1 158 | ret["dec_layers"] = cfg.MODEL.OVSEG.DEC_LAYERS - 1 159 | ret["pre_norm"] = cfg.MODEL.OVSEG.PRE_NORM 160 | ret["enforce_input_project"] = cfg.MODEL.OVSEG.ENFORCE_INPUT_PROJ 161 | ret["mask_dim"] = cfg.MODEL.OVSEG.MASK_DIM 162 | ret["prompt_encoder"] = build_prompt_encoder(cfg) 163 | 164 | return ret 165 | 166 | def get_center(self, points: List[torch.Tensor]): 167 | centers = [] 168 | for point in points: 169 | center = (point[..., 2:] + point[..., :2]) / 2 170 | center[..., 0] /= self.prompt_encoder.image_size[1] 171 | center[..., 1] /= self.prompt_encoder.image_size[0] 172 | centers.append(center) 173 | centers = torch.stack(centers, dim=0) 174 | return centers 175 | 176 | def sample_center(self, centers: torch.Tensor, mask_features: torch.Tensor): 177 | if centers.dim() == 4: 178 | N, Q, K, _ = centers.shape 179 | centers = centers.reshape(N, Q * K, 2) 180 | feature_c = point_sample(mask_features, centers, align_corners=False) 181 | feature_c = feature_c.permute(0, 2, 1).reshape(N, Q, K, -1) # N, Q, K, C 182 | else: 183 | feature_c = point_sample(mask_features, centers, align_corners=False) 184 | feature_c = feature_c.permute(0, 2, 1) # N, Q, C 185 | K = 2 186 | return feature_c, K 187 | 188 | def forward( 189 | self, 190 | x: List[torch.Tensor], 191 | mask_features: torch.Tensor, 192 | points: List[torch.Tensor] = None, 193 | boxes: List[torch.Tensor] = None, 194 | points_multi: List[torch.Tensor] = None, 195 | ): 196 | # x is a list of multi-scale feature 197 | assert len(x) == self.num_feature_levels 198 | src, pos, size_list = [], [], [] 199 | 200 | for i in range(self.num_feature_levels): 201 | size_list.append(x[i].shape[-2:]) 202 | # flatten NxCxHxW to HWxNxC 203 | pos.append(self.pe_layer(x[i], None).flatten(2).permute(2, 0, 1)) 204 | src.append( 205 | self.input_proj[i](x[i]).flatten(2).permute(2, 0, 1) 206 | + self.level_embed.weight[i][None, None, :] 207 | ) 208 | # # flatten NxCxHxW to NxHWxC 209 | # pos.append(self.pe_layer(x[i], None).flatten(2).permute(0, 2, 1)) 210 | # src.append( 211 | # self.input_proj[i](x[i]).flatten(2).permute(0, 2, 1) 212 | # + self.level_embed.weight[i][None, None, :] 213 | # ) 214 | 215 | # calculate centers of points (points is a box with the xyxy format). 
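        # e.g. (illustrative): with CROP_SIZE = 1024, a prompt box whose center lies at pixel (300, 500) is normalized by get_center() to roughly (0.29, 0.49), and that coordinate is then used to sample a content feature from mask_features via point_sample.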
216 | # N, Q, 2 / N, Q, K, 2 217 | if points is not None: 218 | # Q, 2 219 | points_c = self.get_center(points) 220 | if boxes is not None: 221 | # Q, 2 222 | points_c = self.get_center(boxes) 223 | if points_multi is not None: 224 | # Q, K, 2 225 | points_c = self.get_center(points_multi[0]) 226 | # sample feature vectors corresponding to centers by grid sample 227 | # N, Q, K, C / N, Q, C 228 | feature_c, K_ = self.sample_center(points_c, mask_features) 229 | 230 | output, query_embed = self.prompt_encoder( 231 | points, boxes, points_multi, feature_c 232 | ) # N, Q, K, C 233 | N, Q, K, C = output.shape 234 | 235 | # prediction heads on learnable query features 236 | predictions_mask = [] 237 | predictions_iou = [] 238 | output = output.reshape(N * Q, K, C).permute(1, 0, 2) 239 | output = self.extra_self_attention_layer( 240 | output, 241 | query_pos=None, 242 | tgt_mask=None, 243 | ) 244 | output = output.permute(1, 0, 2).reshape(N, Q, K, C) 245 | outputs_mask, outputs_iou, attn_mask = self.forward_prediction_heads( 246 | output, 247 | mask_features, 248 | attn_mask_target_size=size_list[0], 249 | multimask=points is not None, 250 | num_prompt=K_, 251 | ) 252 | predictions_mask.append(outputs_mask) 253 | predictions_iou.append(outputs_iou) 254 | 255 | for i in range(self.num_layers): 256 | level_index = i % self.num_feature_levels 257 | attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False 258 | # attention: cross-attention first 259 | output = output.reshape(N, Q * K, C).permute(1, 0, 2) 260 | output = self.transformer_cross_attention_layers[i]( 261 | output, 262 | src[level_index], 263 | memory_mask=attn_mask, 264 | memory_key_padding_mask=None, # here we do not apply masking on padded region 265 | pos=pos[level_index], 266 | query_pos=query_embed.reshape(N, Q * K, C).permute(1, 0, 2), # QK, N, C 267 | ) 268 | output = output.reshape(Q, K, N, C).permute(1, 2, 0, 3).reshape(K, N * Q, C) 269 | # output = output.reshape(N, Q, K, C).reshape(N * Q, K, C) 270 | output = self.transformer_self_attention_layers[i]( 271 | output, 272 | tgt_mask=None, 273 | tgt_key_padding_mask=None, 274 | query_pos=query_embed.reshape(N * Q, K, C).permute(1, 0, 2), 275 | ) 276 | output = self.transformer_ffn_layers[i](output) 277 | output = output.permute(1, 0, 2).reshape(N, Q, K, C) 278 | # output = output.reshape(N, Q, K, C) 279 | 280 | ( 281 | outputs_mask, 282 | outputs_iou, 283 | attn_mask, 284 | ) = self.forward_prediction_heads( 285 | output, 286 | mask_features, 287 | attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], 288 | multimask=points is not None, 289 | num_prompt=K_, 290 | ) 291 | predictions_mask.append(outputs_mask) 292 | predictions_iou.append(outputs_iou) 293 | 294 | assert len(predictions_mask) == self.num_layers + 1 295 | 296 | out = { 297 | "pred_masks": predictions_mask[-1], 298 | "pred_ious": predictions_iou[-1], 299 | "aux_outputs": self._set_aux_loss(predictions_mask, predictions_iou), 300 | } 301 | return out 302 | 303 | def forward_prediction_heads( 304 | self, 305 | output, 306 | mask_features, 307 | attn_mask_target_size, 308 | multimask=False, 309 | num_prompt=None, 310 | ): 311 | N, Q, K, _ = output.shape 312 | _, C, H, W = mask_features.shape 313 | mask_features = self.feat_embed(mask_features) 314 | decoder_output = self.decoder_norm(output) 315 | decoder_output = decoder_output[:, :, num_prompt:, :] 316 | 317 | if multimask: 318 | K_ = K - (num_prompt + 1) 319 | decoder_output = decoder_output[:, :, 1:, :] 320 | mask_embed = torch.stack( 321 
| [self.mask_embed[i + 1](decoder_output[:, :, i]) for i in range(K_)], 322 | dim=2, 323 | ) 324 | iou_embed = torch.stack( 325 | [self.iou_embed[i + 1](decoder_output[:, :, i]) for i in range(K_)], 326 | dim=2, 327 | ) 328 | else: 329 | K_ = 1 330 | mask_embed = self.mask_embed[0](decoder_output[:, :, :1, :]) 331 | iou_embed = self.iou_embed[0](decoder_output[:, :, :1, :]) 332 | 333 | # mask branch 334 | outputs_mask = torch.einsum("bqkc,bchw->bqkhw", mask_embed, mask_features) 335 | outputs_iou = iou_embed.squeeze(-1).sigmoid() 336 | 337 | # NOTE: prediction is of higher-resolution 338 | # [B, Q, K, H, W] -> [B, QK, H, W] -> [B, Q, K, H*W] 339 | attn_mask = outputs_mask.reshape(N, Q * K_, H, W) 340 | attn_mask = F.interpolate( 341 | attn_mask, 342 | size=attn_mask_target_size, 343 | mode="bilinear", 344 | align_corners=False, 345 | ).reshape(N, Q, K_, attn_mask_target_size[0], attn_mask_target_size[1]) 346 | 347 | # must use bool type 348 | # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. 349 | # [B, Q, K, H, W] -> [B, Q, 1, H, W] -> [B, 1, Q, 1, H, W] -> [B, 1, Q, 1, H*W] -> [B, h, Q, K, H*W] -> [B*h, Q*K, H*W] 350 | attn_mask = ( 351 | attn_mask.sigmoid() 352 | .ge(0.5) 353 | .sum(dim=2, keepdim=True) 354 | .bool() 355 | .unsqueeze(1) 356 | .flatten(-2) 357 | .repeat(1, self.num_heads, 1, K, 1) 358 | .reshape( 359 | N * self.num_heads, 360 | Q * K, 361 | attn_mask_target_size[0] * attn_mask_target_size[1], 362 | ) 363 | ).detach() 364 | attn_mask = ~attn_mask 365 | 366 | return outputs_mask, outputs_iou, attn_mask 367 | 368 | @torch.jit.unused 369 | def _set_aux_loss(self, predictions_mask, predictions_iou): 370 | # this is a workaround to make torchscript happy, as torchscript 371 | # doesn't support dictionary with non-homogeneous values, such 372 | # as a dict having both a Tensor and a list. 373 | return [ 374 | {"pred_masks": mask, "pred_ious": iou} 375 | for mask, iou in zip(predictions_mask[:-1], predictions_iou[:-1]) 376 | ] 377 | -------------------------------------------------------------------------------- /lib/models/transformer_decoder/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class PositionEmbeddingRandom(nn.Module): 10 | """ 11 | Positional encoding using random spatial frequencies. 12 | Based on my personal understanding, random frequencies are used to 13 | simulate inaccurate box inputs from users. 14 | """ 15 | 16 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 17 | super().__init__() 18 | if scale is None or scale <= 0.0: 19 | scale = 1.0 20 | self.register_buffer( 21 | "positional_encoding_gaussian_matrix", 22 | scale * torch.randn((2, num_pos_feats)), 23 | ) 24 | 25 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 26 | """Positionally encode points that are normalized to [0,1].""" 27 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 28 | coords = 2 * coords - 1 # -1 ~ 1 29 | coords = coords @ self.positional_encoding_gaussian_matrix 30 | coords = 2 * np.pi * coords 31 | # outputs d_1 x ... 
x d_n x C shape 32 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 33 | 34 | def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: 35 | """Generate positional encoding for a grid of the specified size.""" 36 | b, _, h, w = x.shape 37 | device: Any = self.positional_encoding_gaussian_matrix.device 38 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 39 | y_embed = grid.cumsum(dim=0) - 0.5 40 | x_embed = grid.cumsum(dim=1) - 0.5 41 | y_embed = y_embed / h 42 | x_embed = x_embed / w 43 | 44 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 45 | pe = pe.permute(2, 0, 1) # C x H x W 46 | pe = pe.unsqueeze(0).repeat(b, 1, 1, 1) 47 | 48 | return pe 49 | 50 | def forward_with_coords( 51 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 52 | ) -> torch.Tensor: 53 | """Positionally encode points that are not normalized to [0,1].""" 54 | coords = coords_input.clone() 55 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 56 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 57 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 58 | 59 | 60 | class PositionEmbeddingSine(nn.Module): 61 | """ 62 | This is a more standard version of the position embedding, very similar to the one 63 | used by the Attention is all you need paper, generalized to work on images. 64 | """ 65 | 66 | def __init__( 67 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None 68 | ): 69 | super().__init__() 70 | self.num_pos_feats = num_pos_feats 71 | self.temperature = temperature 72 | self.normalize = normalize 73 | if scale is not None and normalize is False: 74 | raise ValueError("normalize should be True if scale is passed") 75 | if scale is None: 76 | scale = 2 * math.pi 77 | self.scale = scale 78 | 79 | def forward(self, x, mask=None): 80 | if mask is None: 81 | mask = torch.zeros( 82 | (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool 83 | ) 84 | not_mask = ~mask 85 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 86 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 87 | if self.normalize: 88 | eps = 1e-6 89 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 90 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 91 | 92 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 93 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 94 | 95 | pos_x = x_embed[:, :, :, None] / dim_t 96 | pos_y = y_embed[:, :, :, None] / dim_t 97 | pos_x = torch.stack( 98 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 99 | ).flatten(3) 100 | pos_y = torch.stack( 101 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 102 | ).flatten(3) 103 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 104 | return pos 105 | 106 | def forward_with_coords( 107 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 108 | ) -> torch.Tensor: 109 | eps = 1e-6 110 | # num_instances, num_points, 2 -> num_instances, num_points 111 | x_embed = coords_input[:, :, 0] 112 | y_embed = coords_input[:, :, 1] 113 | 114 | if self.normalize: 115 | x_embed = x_embed / (image_size[1] + eps) * self.scale 116 | y_embed = y_embed / (image_size[0] + eps) * self.scale 117 | 118 | dim_t = torch.arange( 119 | self.num_pos_feats, dtype=torch.float32, device=coords_input.device 120 | ) 121 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 122 | 123 | pos_x = x_embed[:, :, None] / dim_t 124 | pos_y = y_embed[:, :, None] / dim_t 125 
| pos_x = torch.stack( 126 | (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=-1 127 | ).flatten( 128 | -2 129 | ) # num_instances, num_points, num_pos_feats 130 | pos_y = torch.stack( 131 | (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=-1 132 | ).flatten( 133 | -2 134 | ) # num_instances, num_points, num_pos_feats 135 | 136 | pos = torch.cat( 137 | [pos_y, pos_x], dim=-1 138 | ) # num_instances, num_points, 2 * num_pos_feats 139 | return pos 140 | 141 | def __repr__(self, _repr_indent=4): 142 | head = "Positional encoding " + self.__class__.__name__ 143 | body = [ 144 | "num_pos_feats: {}".format(self.num_pos_feats), 145 | "temperature: {}".format(self.temperature), 146 | "normalize: {}".format(self.normalize), 147 | "scale: {}".format(self.scale), 148 | ] 149 | # _repr_indent = 4 150 | lines = [head] + [" " * _repr_indent + line for line in body] 151 | return "\n".join(lines) 152 | -------------------------------------------------------------------------------- /lib/models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | 10 | def _get_activation_fn(activation): 11 | """Return an activation function given a string""" 12 | if activation == "relu": 13 | return F.relu 14 | if activation == "gelu": 15 | return F.gelu 16 | if activation == "glu": 17 | return F.glu 18 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 19 | 20 | 21 | class LoRALayer: 22 | def __init__( 23 | self, 24 | r: int, 25 | lora_alpha: int, 26 | lora_dropout: float, 27 | merge_weights: bool, 28 | ): 29 | self.r = r 30 | self.lora_alpha = lora_alpha 31 | # Optional dropout 32 | if lora_dropout > 0.0: 33 | self.lora_dropout = nn.Dropout(p=lora_dropout) 34 | else: 35 | self.lora_dropout = lambda x: x 36 | # Mark the weight as unmerged 37 | self.merged = False 38 | self.merge_weights = merge_weights 39 | 40 | 41 | class Linear(nn.Linear, LoRALayer): 42 | # LoRA implemented in a dense layer 43 | def __init__( 44 | self, 45 | in_features: int, 46 | out_features: int, 47 | r: int = 0, 48 | lora_alpha: int = 1, 49 | lora_dropout: float = 0.0, 50 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) 51 | merge_weights: bool = True, 52 | **kwargs, 53 | ): 54 | nn.Linear.__init__(self, in_features, out_features, **kwargs) 55 | LoRALayer.__init__( 56 | self, 57 | r=r, 58 | lora_alpha=lora_alpha, 59 | lora_dropout=lora_dropout, 60 | merge_weights=merge_weights, 61 | ) 62 | 63 | self.fan_in_fan_out = fan_in_fan_out 64 | # Actual trainable parameters 65 | if r > 0: 66 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) 67 | self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) 68 | self.scaling = self.lora_alpha / self.r 69 | # Freezing the pre-trained weight matrix 70 | self.weight.requires_grad = False 71 | self.reset_parameters() 72 | if fan_in_fan_out: 73 | self.weight.data = self.weight.data.transpose(0, 1) 74 | 75 | def reset_parameters(self): 76 | nn.Linear.reset_parameters(self) 77 | if hasattr(self, "lora_A"): 78 | # initialize B the same way as the default for nn.Linear and A to zero 79 | # this is different than what is described in the paper but should not affect performance 80 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 81 | nn.init.zeros_(self.lora_B) 82 | 83 | def 
train(self, mode: bool = True): 84 | def T(w): 85 | return w.transpose(0, 1) if self.fan_in_fan_out else w 86 | 87 | nn.Linear.train(self, mode) 88 | if mode: 89 | if self.merge_weights and self.merged: 90 | # Make sure that the weights are not merged 91 | if self.r > 0: 92 | self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling 93 | self.merged = False 94 | else: 95 | if self.merge_weights and not self.merged: 96 | # Merge the weights and mark it 97 | if self.r > 0: 98 | self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling 99 | self.merged = True 100 | 101 | def forward(self, x: torch.Tensor): 102 | def T(w): 103 | return w.transpose(0, 1) if self.fan_in_fan_out else w 104 | 105 | if self.r > 0 and not self.merged: 106 | result = F.linear(x, T(self.weight), bias=self.bias) 107 | result += ( 108 | self.lora_dropout(x) 109 | @ self.lora_A.transpose(0, 1) 110 | @ self.lora_B.transpose(0, 1) 111 | ) * self.scaling 112 | return result 113 | else: 114 | return F.linear(x, T(self.weight), bias=self.bias) 115 | 116 | 117 | # Ref: https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923 118 | class MaskPooling(nn.Module): 119 | def __init__( 120 | self, 121 | ): 122 | super().__init__() 123 | 124 | def forward( 125 | self, 126 | x: torch.Tensor, 127 | mask: torch.Tensor, 128 | thres: float = 0.5, 129 | use_sigmoid: bool = True, 130 | ): 131 | """Forward of MaskPooling 132 | 133 | Args: 134 | x (torch.Tensor): input features with a shape of [B, C, H, W] 135 | mask (torch.Tensor): guided mask with a shape of [B, Q, H, W] 136 | thres (float, optional): binary threshold. Defaults to 0.5. 137 | use_sigmoid (bool, optional): whether to use sigmoid. Defaults to True. 138 | 139 | Returns: 140 | torch.Tensor: pooled embeddings with a shape of [B, Q, C] 141 | """ 142 | if not x.shape[-2:] == mask.shape[-2:]: 143 | # reshape mask to x 144 | mask = F.interpolate( 145 | mask, size=x.shape[-2:], mode="bilinear", align_corners=False 146 | ) 147 | 148 | with torch.no_grad(): 149 | if use_sigmoid: 150 | mask = mask.sigmoid().detach() 151 | else: 152 | mask = mask.detach() 153 | mask = mask.ge(thres).to(mask.dtype) 154 | denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8 155 | mask_pooled_x = torch.einsum( 156 | "bchw,bqhw->bqc", 157 | x, 158 | mask / denorm, 159 | ) 160 | return mask_pooled_x 161 | 162 | 163 | class LayerNorm2d(nn.Module): 164 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 165 | super().__init__() 166 | self.weight = nn.Parameter(torch.ones(num_channels)) 167 | self.bias = nn.Parameter(torch.zeros(num_channels)) 168 | self.eps = eps 169 | 170 | def forward(self, x: torch.Tensor) -> torch.Tensor: 171 | u = x.mean(1, keepdim=True) 172 | s = (x - u).pow(2).mean(1, keepdim=True) 173 | x = (x - u) / torch.sqrt(s + self.eps) 174 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 175 | return x 176 | 177 | 178 | class MLP(nn.Module): 179 | """Very simple multi-layer perceptron (also called FFN)""" 180 | 181 | def __init__( 182 | self, input_dim, hidden_dim, output_dim, num_layers, activation="relu" 183 | ): 184 | super().__init__() 185 | self.num_layers = num_layers 186 | h = [hidden_dim] * (num_layers - 1) 187 | self.layers = nn.ModuleList( 188 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 189 | ) 190 | self.activation = _get_activation_fn(activation) 191 | 192 | def forward(self, x): 193 | for i, layer in enumerate(self.layers): 194 | x = self.activation(layer(x)) 
if i < self.num_layers - 1 else layer(x) 195 | return x 196 | 197 | 198 | class MLP_lora(nn.Module): 199 | """Very simple multi-layer perceptron (also called FFN)""" 200 | 201 | def __init__( 202 | self, input_dim, hidden_dim, output_dim, num_layers, r, activation="relu" 203 | ): 204 | super().__init__() 205 | self.num_layers = num_layers 206 | h = [hidden_dim] * (num_layers - 1) 207 | self.layers = nn.ModuleList( 208 | Linear(n, k, r) for n, k in zip([input_dim] + h, h + [output_dim]) 209 | ) 210 | self.activation = _get_activation_fn(activation) 211 | 212 | def forward(self, x): 213 | for i, layer in enumerate(self.layers): 214 | x = self.activation(layer(x)) if i < self.num_layers - 1 else layer(x) 215 | return x 216 | 217 | 218 | class Attention(nn.Module): 219 | def __init__( 220 | self, 221 | dim: int, 222 | num_heads: int = 8, 223 | qkv_bias: bool = True, 224 | attn_drop: float = 0.0, 225 | proj_drop: float = 0.0, 226 | ) -> None: 227 | super().__init__() 228 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 229 | self.num_heads = num_heads 230 | self.head_dim = dim // num_heads 231 | self.scale = self.head_dim**-0.5 232 | self.attn_drop = attn_drop 233 | 234 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 235 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 236 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 237 | self.proj = nn.Linear(dim, dim) 238 | self.proj_drop = nn.Dropout(proj_drop) 239 | 240 | def forward( 241 | self, 242 | q: torch.Tensor, 243 | k: torch.Tensor, 244 | v: torch.Tensor, 245 | attn_mask: torch.Tensor = None, 246 | pad_mask: torch.Tensor = None, 247 | ) -> torch.Tensor: 248 | B, N, C = q.shape 249 | q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 250 | k = self.k(k).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 251 | v = self.v(v).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 252 | 253 | x = F.scaled_dot_product_attention( 254 | q, 255 | k, 256 | v, 257 | dropout_p=self.attn_drop if self.training else 0.0, 258 | attn_mask=~attn_mask if attn_mask is not None else None, 259 | ) 260 | 261 | x = x.transpose(1, 2).reshape(B, N, C) 262 | x = self.proj(x) 263 | x = self.proj_drop(x) 264 | return x 265 | 266 | 267 | class CrossAttention(nn.Module): 268 | def __init__( 269 | self, 270 | dim: int, 271 | num_heads: int = 8, 272 | qkv_bias: bool = True, 273 | attn_drop: float = 0.0, 274 | proj_drop: float = 0.0, 275 | ) -> None: 276 | super().__init__() 277 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 278 | self.num_heads = num_heads 279 | self.head_dim = dim // num_heads 280 | self.scale = self.head_dim**-0.5 281 | self.attn_drop = attn_drop 282 | 283 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 284 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 285 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 286 | self.proj = nn.Linear(dim, dim) 287 | self.proj_drop = nn.Dropout(proj_drop) 288 | 289 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 290 | return tensor if pos is None else tensor + pos 291 | 292 | def forward( 293 | self, 294 | q: torch.Tensor, 295 | k: torch.Tensor, 296 | v: torch.Tensor, 297 | attn_mask: torch.Tensor = None, 298 | pad_mask: torch.Tensor = None, 299 | ) -> torch.Tensor: 300 | B, N, C = q.shape 301 | _, M, _ = k.shape 302 | q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 303 | k = self.k(k).reshape(B, M, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 304 | v = self.v(v).reshape(B, M, self.num_heads, 
self.head_dim).permute(0, 2, 1, 3) 305 | 306 | x = F.scaled_dot_product_attention( 307 | q, 308 | k, 309 | v, 310 | dropout_p=self.attn_drop if self.training else 0.0, 311 | attn_mask=~attn_mask if attn_mask is not None else None, 312 | scale=self.scale, 313 | ) 314 | 315 | x = x.transpose(1, 2).reshape(B, N, C) 316 | x = self.proj(x) 317 | x = self.proj_drop(x) 318 | return x 319 | 320 | 321 | class SelfAttentionLayer(nn.Module): 322 | def __init__( 323 | self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False 324 | ): 325 | super().__init__() 326 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 327 | # self.self_attn = Attention(d_model, nhead, attn_drop=dropout, proj_drop=dropout) 328 | 329 | self.norm = nn.LayerNorm(d_model) 330 | self.dropout = nn.Dropout(dropout) 331 | 332 | self.activation = _get_activation_fn(activation) 333 | self.normalize_before = normalize_before 334 | 335 | self._reset_parameters() 336 | 337 | def _reset_parameters(self): 338 | for p in self.parameters(): 339 | if p.dim() > 1: 340 | nn.init.xavier_uniform_(p) 341 | 342 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 343 | return tensor if pos is None else tensor + pos 344 | 345 | def forward_post( 346 | self, 347 | tgt, 348 | tgt_mask: Optional[Tensor] = None, 349 | tgt_key_padding_mask: Optional[Tensor] = None, 350 | query_pos: Optional[Tensor] = None, 351 | ): 352 | q = k = self.with_pos_embed(tgt, query_pos) 353 | tgt2 = self.self_attn( 354 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 355 | )[0] 356 | # tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask, pad_mask=tgt_key_padding_mask) 357 | tgt = tgt + self.dropout(tgt2) 358 | tgt = self.norm(tgt) 359 | 360 | return tgt 361 | 362 | def forward_pre( 363 | self, 364 | tgt, 365 | tgt_mask: Optional[Tensor] = None, 366 | tgt_key_padding_mask: Optional[Tensor] = None, 367 | query_pos: Optional[Tensor] = None, 368 | ): 369 | tgt2 = self.norm(tgt) 370 | q = k = self.with_pos_embed(tgt2, query_pos) 371 | tgt2 = self.self_attn( 372 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 373 | )[0] 374 | tgt = tgt + self.dropout(tgt2) 375 | 376 | return tgt 377 | 378 | def forward( 379 | self, 380 | tgt, 381 | tgt_mask: Optional[Tensor] = None, 382 | tgt_key_padding_mask: Optional[Tensor] = None, 383 | query_pos: Optional[Tensor] = None, 384 | ): 385 | if self.normalize_before: 386 | return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) 387 | return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) 388 | 389 | 390 | class CrossAttentionLayer(nn.Module): 391 | def __init__( 392 | self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False 393 | ): 394 | super().__init__() 395 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 396 | # self.multihead_attn = CrossAttention(d_model, nhead, attn_drop=dropout, proj_drop=dropout) 397 | 398 | self.norm = nn.LayerNorm(d_model) 399 | self.dropout = nn.Dropout(dropout) 400 | 401 | self.activation = _get_activation_fn(activation) 402 | self.normalize_before = normalize_before 403 | 404 | self._reset_parameters() 405 | 406 | def _reset_parameters(self): 407 | for p in self.parameters(): 408 | if p.dim() > 1: 409 | nn.init.xavier_uniform_(p) 410 | 411 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 412 | return tensor if pos is None else tensor + pos 413 | 414 | def forward_post( 415 | self, 416 | tgt, 417 | memory, 418 | 
memory_mask: Optional[Tensor] = None, 419 | memory_key_padding_mask: Optional[Tensor] = None, 420 | pos: Optional[Tensor] = None, 421 | query_pos: Optional[Tensor] = None, 422 | ): 423 | tgt2 = self.multihead_attn( 424 | query=self.with_pos_embed(tgt, query_pos), 425 | key=self.with_pos_embed(memory, pos), 426 | value=memory, 427 | attn_mask=memory_mask, 428 | key_padding_mask=memory_key_padding_mask, 429 | )[0] 430 | # tgt2 = self.multihead_attn( 431 | # x=tgt, 432 | # mem=memory, 433 | # pos=query_pos, 434 | # mem_pos=pos, 435 | # attn_mask=memory_mask, 436 | # pad_mask=memory_key_padding_mask, 437 | # ) 438 | tgt = tgt + self.dropout(tgt2) 439 | tgt = self.norm(tgt) 440 | 441 | return tgt 442 | 443 | def forward_pre( 444 | self, 445 | tgt, 446 | memory, 447 | memory_mask: Optional[Tensor] = None, 448 | memory_key_padding_mask: Optional[Tensor] = None, 449 | pos: Optional[Tensor] = None, 450 | query_pos: Optional[Tensor] = None, 451 | ): 452 | tgt2 = self.norm(tgt) 453 | tgt2 = self.multihead_attn( 454 | query=self.with_pos_embed(tgt2, query_pos), 455 | key=self.with_pos_embed(memory, pos), 456 | value=memory, 457 | attn_mask=memory_mask, 458 | key_padding_mask=memory_key_padding_mask, 459 | )[0] 460 | tgt = tgt + self.dropout(tgt2) 461 | 462 | return tgt 463 | 464 | def forward( 465 | self, 466 | tgt, 467 | memory, 468 | memory_mask: Optional[Tensor] = None, 469 | memory_key_padding_mask: Optional[Tensor] = None, 470 | pos: Optional[Tensor] = None, 471 | query_pos: Optional[Tensor] = None, 472 | ): 473 | if self.normalize_before: 474 | return self.forward_pre( 475 | tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos 476 | ) 477 | return self.forward_post( 478 | tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos 479 | ) 480 | 481 | 482 | class FFNLayer(nn.Module): 483 | def __init__( 484 | self, 485 | d_model, 486 | dim_feedforward=2048, 487 | dropout=0.0, 488 | activation="relu", 489 | normalize_before=False, 490 | ): 491 | super().__init__() 492 | # Implementation of Feedforward model 493 | self.linear1 = nn.Linear(d_model, dim_feedforward) 494 | self.dropout = nn.Dropout(dropout) 495 | self.linear2 = nn.Linear(dim_feedforward, d_model) 496 | 497 | self.norm = nn.LayerNorm(d_model) 498 | 499 | self.activation = _get_activation_fn(activation) 500 | self.normalize_before = normalize_before 501 | 502 | self._reset_parameters() 503 | 504 | def _reset_parameters(self): 505 | for p in self.parameters(): 506 | if p.dim() > 1: 507 | nn.init.xavier_uniform_(p) 508 | 509 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 510 | return tensor if pos is None else tensor + pos 511 | 512 | def forward_post(self, tgt): 513 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 514 | tgt = tgt + self.dropout(tgt2) 515 | tgt = self.norm(tgt) 516 | return tgt 517 | 518 | def forward_pre(self, tgt): 519 | tgt2 = self.norm(tgt) 520 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 521 | tgt = tgt + self.dropout(tgt2) 522 | return tgt 523 | 524 | def forward(self, tgt): 525 | if self.normalize_before: 526 | return self.forward_pre(tgt) 527 | return self.forward_post(tgt) 528 | 529 | 530 | class FFNLayer_lora(nn.Module): 531 | def __init__( 532 | self, 533 | d_model, 534 | dim_feedforward=2048, 535 | dropout=0.0, 536 | activation="relu", 537 | normalize_before=False, 538 | r=0, 539 | ): 540 | super().__init__() 541 | # Implementation of Feedforward model 542 | self.linear1 = Linear(d_model, dim_feedforward, r) 543 | self.dropout = 
nn.Dropout(dropout) 544 | self.linear2 = Linear(dim_feedforward, d_model, r) 545 | 546 | self.norm = nn.LayerNorm(d_model) 547 | 548 | self.activation = _get_activation_fn(activation) 549 | self.normalize_before = normalize_before 550 | 551 | self._reset_parameters() 552 | 553 | def _reset_parameters(self): 554 | for p in self.parameters(): 555 | if p.dim() > 1: 556 | nn.init.xavier_uniform_(p) 557 | 558 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 559 | return tensor if pos is None else tensor + pos 560 | 561 | def forward_post(self, tgt): 562 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 563 | tgt = tgt + self.dropout(tgt2) 564 | tgt = self.norm(tgt) 565 | return tgt 566 | 567 | def forward_pre(self, tgt): 568 | tgt2 = self.norm(tgt) 569 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 570 | tgt = tgt + self.dropout(tgt2) 571 | return tgt 572 | 573 | def forward(self, tgt): 574 | if self.normalize_before: 575 | return self.forward_pre(tgt) 576 | return self.forward_post(tgt) 577 | 578 | 579 | class TransformerDecoderLayer(nn.Module): 580 | def __init__( 581 | self, 582 | d_model: int = 256, 583 | nhead: int = 4, 584 | dim_feedforward: int = 1024, 585 | dropout=0.0, 586 | activation="gelu", 587 | normalize_before=False, 588 | ): 589 | super().__init__() 590 | self.cross_attn = CrossAttentionLayer( 591 | d_model, nhead, dropout, activation, normalize_before 592 | ) 593 | self.self_attn = SelfAttentionLayer( 594 | d_model, nhead, dropout, activation, normalize_before 595 | ) 596 | self.ffn = FFNLayer( 597 | d_model, dim_feedforward, dropout, activation, normalize_before 598 | ) 599 | 600 | def forward( 601 | self, 602 | tgt: Tensor, 603 | memory: Tensor, 604 | memory_mask: Optional[Tensor] = None, 605 | memory_key_padding_mask: Optional[Tensor] = None, 606 | tgt_mask: Optional[Tensor] = None, 607 | tgt_key_padding_mask: Optional[Tensor] = None, 608 | pos: Optional[Tensor] = None, 609 | query_pos: Optional[Tensor] = None, 610 | ): 611 | tgt, _ = self.cross_attn( 612 | tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos 613 | ) 614 | tgt = self.self_attn(tgt, tgt_mask, tgt_key_padding_mask, query_pos) 615 | tgt = self.ffn(tgt) 616 | 617 | return tgt 618 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt import ( 2 | SIMPLE_TEMPLATES, 3 | VILD_TEMPLATES, 4 | OPENAI_IMAGENET_TEMPLATES, 5 | OPENAI_IMAGENET_VILD_TEMPLATES, 6 | ) 7 | from .post_process import ( 8 | mask_nms, 9 | batched_mask_nms, 10 | pairwise_iou, 11 | get_classification_logits_fcclip, 12 | sem_seg_postprocess, 13 | ) 14 | from .misc import ( 15 | calculate_uncertainty, 16 | get_uncertain_point_coords_with_randomness, 17 | is_dist_avail_and_initialized, 18 | nested_tensor_from_tensor_list, 19 | point_sample, 20 | ) 21 | from .config import add_ovseg_config 22 | from .test_time_augmentation import SemanticSegmentorWithTTA -------------------------------------------------------------------------------- /lib/utils/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | 4 | def add_ovseg_config(cfg): 5 | """ 6 | Add config for uniovseg. 
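    Typical usage follows the standard detectron2 pattern (a sketch; assumes the usual ``get_cfg`` entry point): ``cfg = get_cfg(); add_ovseg_config(cfg); cfg.merge_from_file("configs/S1_point_seg.yaml")`` before building the model.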
7 | """ 8 | cfg.INPUT.DATASET_ROOT = "/datasets/sharegpt4v" 9 | cfg.INPUT.DATASET_URL = [ 10 | ["/datasets/SA-1B/split1-2m", "datadict_0p5.parquet"], 11 | ] 12 | cfg.INPUT.DATASET_JSON = "/vepfs/home/wangzhaoqing/uni-ovseg/sa1b.json" 13 | cfg.INPUT.FEW_SHOT_JSON = [ 14 | "/workspace/pretrains/coco_fewshot/openvocab_coco_2017_train_panoptic_with_sem_seg_0.1.json", 15 | ] 16 | cfg.INPUT.DATASET_MAPPER_NAME = "sa1b" 17 | cfg.INPUT.COLOR_AUG_SSD = True 18 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 19 | cfg.INPUT.SIZE_DIVISIBILITY = 32 20 | cfg.INPUT.IMG_SIZE = 1024 21 | cfg.INPUT.CROP_SIZE = 1024 22 | cfg.INPUT.MIN_SCALE = 0.8 23 | cfg.INPUT.MAX_SCALE = 1.2 24 | cfg.INPUT.MIN_AREA_RATIO = 0.001 25 | cfg.INPUT.MAX_AREA_RATIO = 0.8 26 | cfg.INPUT.MAX_INSTANCE = 40 27 | 28 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 29 | cfg.SOLVER.OPTIMIZER = "ADAMW" 30 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 31 | cfg.SOLVER.POLY_LR_CONSTANT_ENDING = 0.0 32 | cfg.SOLVER.POLY_LR_POWER = 0.9 33 | 34 | cfg.MODEL.META_ARCHITECTURE = "UniOVSeg_S1" 35 | cfg.MODEL.OVSEG = CN() 36 | cfg.MODEL.OVSEG.CLIP_MODEL_NAME = "convnext_large_d_320" 37 | cfg.MODEL.OVSEG.CLIP_PRETRAINED_WEIGHTS = ( 38 | "/workspace/pretrains/convnext_large_d_320.laion2B-s29B-b131K-ft-soup.pth" 39 | ) 40 | cfg.MODEL.OVSEG.AUX_MODEL_NAME = "convnext_xxlarge" 41 | cfg.MODEL.OVSEG.AUX_PRETRAINED_WEIGHTS = ( 42 | "/workspace/pretrains/convnext_xxlarge.laion2B-s34B-b82K-augreg-soup.pth" 43 | ) 44 | cfg.MODEL.OVSEG.PROMPT_ENCODER_NAME = "PromptEncoder" 45 | cfg.MODEL.OVSEG.PIXEL_DECODER_NAME = "BasePixelDecoder" 46 | cfg.MODEL.OVSEG.TRANSFORMER_ENC_LAYERS = 6 47 | cfg.MODEL.OVSEG.IN_FEATURES = [ 48 | "res2", 49 | "res3", 50 | "res4", 51 | "res5", 52 | ] 53 | cfg.MODEL.OVSEG.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = [ 54 | "res3", 55 | "res4", 56 | "res5", 57 | ] 58 | cfg.MODEL.OVSEG.COMMON_STRIDE = 4 59 | cfg.MODEL.OVSEG.TRANSFORMER_IN_FEATURE = "multi_scale_pixel_decoder" 60 | cfg.MODEL.OVSEG.TRANSFORMER_DECODER_NAME = "MultiScaleMaskDecoder" 61 | cfg.MODEL.OVSEG.MASK_DIM = 256 62 | cfg.MODEL.OVSEG.CONVS_DIM = 256 63 | cfg.MODEL.OVSEG.NORM = "GN" 64 | cfg.MODEL.OVSEG.EMBED_DIM = 256 65 | cfg.MODEL.OVSEG.NHEADS = 8 66 | cfg.MODEL.OVSEG.DIM_FEEDFORWARD = 2048 67 | cfg.MODEL.OVSEG.PRE_NORM = False 68 | cfg.MODEL.OVSEG.DROPOUT = 0.0 69 | cfg.MODEL.OVSEG.ENFORCE_INPUT_PROJ = False 70 | cfg.MODEL.OVSEG.DEEP_SUPERVISION = True 71 | cfg.MODEL.OVSEG.NUM_MASKS = 4 72 | cfg.MODEL.OVSEG.RANK = 8 73 | cfg.MODEL.OVSEG.LORA_INIT = False 74 | cfg.MODEL.OVSEG.CRITERION_SEG = "Many2ManySetCriterion" 75 | cfg.MODEL.OVSEG.CRITERION_ALIGN = "MaskTextAlignCriterion" 76 | cfg.MODEL.OVSEG.DICE_WEIGHT = 1.0 77 | cfg.MODEL.OVSEG.MASK_WEIGHT = 1.0 78 | cfg.MODEL.OVSEG.IOU_WEIGHT = 1.0 79 | cfg.MODEL.OVSEG.ALIGN_WEIGHT = 1.0 80 | cfg.MODEL.OVSEG.MATCHER_NUM_POINTS = 5000 81 | cfg.MODEL.OVSEG.MATCHER_THRES_POS = 0.7 82 | cfg.MODEL.OVSEG.TRAIN_NUM_POINTS = 12544 # 800 * 800 // (8 * 8) 83 | cfg.MODEL.OVSEG.OVERSAMPLE_RATIO = 3.0 84 | cfg.MODEL.OVSEG.IMPORTANCE_SAMPLE_RATIO = 0.75 85 | cfg.MODEL.OVSEG.DEC_LAYERS = 7 86 | cfg.MODEL.OVSEG.LOSS_TOPK = 1.0 87 | cfg.MODEL.OVSEG.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = True 88 | cfg.MODEL.OVSEG.PTS_PER_SIDE = (10, 10) 89 | cfg.MODEL.OVSEG.CLIP_DIM = 1024 90 | cfg.MODEL.OVSEG.SIZE_DIVISIBILITY = 32 91 | cfg.MODEL.OVSEG.INPUT_SIZES = [896, 1024] 92 | 93 | cfg.MODEL.OVSEG.TEST = CN() 94 | cfg.MODEL.OVSEG.TEST.PTS_PER_SIDE = (20, 20) 95 | cfg.MODEL.OVSEG.TEST.OBJECT_MASK_THRESHOLD = 0.5 96 | cfg.MODEL.OVSEG.TEST.OVERLAP_THRESHOLD = 0.5 97 | 
cfg.MODEL.OVSEG.TEST.SEMANTIC_ON = False 98 | cfg.MODEL.OVSEG.TEST.INSTANCE_ON = False 99 | cfg.MODEL.OVSEG.TEST.PANOPTIC_ON = False 100 | cfg.MODEL.OVSEG.TEST.MASKCLS_ON = False 101 | cfg.MODEL.OVSEG.TEST.AUTOLABEL_ON = True 102 | cfg.MODEL.OVSEG.TEST.AUTOLABEL_TYPE = "panoptic-point" 103 | cfg.MODEL.OVSEG.TEST.AUTOLABEL_SAVE = False 104 | -------------------------------------------------------------------------------- /lib/utils/debug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.patches as patches 4 | 5 | 6 | def visualize_boxes(tensor, figsize=(10, 10), image_size=(1024, 1024)): 7 | """ 8 | Visualize bounding boxes. 9 | 10 | Args: 11 | - tensor: bounding-box coordinates with a shape of [1, 400, 4] 12 | - figsize: figure size in inches 13 | - image_size: size of the background image, used to normalize the box coordinates 14 | """ 15 | tensor = tensor.cpu() 16 | fig, ax = plt.subplots(1, figsize=figsize) 17 | # set the axis ranges 18 | ax.set_xlim(0, image_size[0]) 19 | ax.set_ylim(image_size[1], 0) 20 | 21 | # iterate over each bounding box in the tensor and draw it 22 | for box in tensor[0]: 23 | # coordinates are assumed to be relative to an image of size image_size 24 | x_min, y_min, x_max, y_max = box# * np.array(image_size*2) 25 | # create a rectangle patch and add it to the axes 26 | rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=1, edgecolor='r', facecolor='none') 27 | ax.add_patch(rect) 28 | 29 | plt.savefig("debug.png") 30 | -------------------------------------------------------------------------------- /lib/utils/misc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | from torch.distributed.nn import all_gather 7 | import torchvision 8 | from torch import Tensor 9 | from detectron2.utils.comm import get_world_size, get_rank 10 | from detectron2.layers import cat, shapes_to_tensor 11 | 12 | 13 | def calculate_uncertainty(logits): 14 | """ 15 | We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 16 | foreground class in `classes`. 17 | Args: 18 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or 19 | class-agnostic, where R is the total number of predicted masks in all images and C is 20 | the number of foreground classes. The values are logits. 21 | Returns: 22 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 23 | the most uncertain locations having the highest uncertainty score. 24 | """ 25 | assert logits.shape[1] == 1 26 | gt_class_logits = logits.clone() 27 | return -(torch.abs(gt_class_logits)) 28 | 29 | 30 | def point_sample(input, point_coords, **kwargs): 31 | """ 32 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 33 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 34 | [0, 1] x [0, 1] square. 35 | 36 | Args: 37 | input (Tensor): A tensor of shape (N, C, H, W) that contains the feature map on a H x W grid. 38 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 39 | [0, 1] x [0, 1] normalized point coordinates. 40 | 41 | Returns: 42 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 43 | features for points in `point_coords`. The features are obtained via bilinear 44 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
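    Example (illustrative): ``point_sample(torch.randn(2, 256, 64, 64), torch.rand(2, 100, 2), align_corners=False)`` returns a tensor of shape (2, 256, 100).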
45 | """ 46 | add_dim = False 47 | if point_coords.dim() == 3: 48 | add_dim = True 49 | point_coords = point_coords.unsqueeze(2) 50 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 51 | if add_dim: 52 | output = output.squeeze(3) 53 | return output 54 | 55 | 56 | def get_uncertain_point_coords_with_randomness( 57 | coarse_logits, 58 | uncertainty_func, 59 | num_points, 60 | oversample_ratio, 61 | importance_sample_ratio, 62 | ): 63 | """ 64 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties 65 | are calculated for each point using 'uncertainty_func' function that takes point's logit 66 | prediction as input. 67 | See PointRend paper for details. 68 | 69 | Args: 70 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 71 | class-specific or class-agnostic prediction. 72 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 73 | contains logit predictions for P points and returns their uncertainties as a Tensor of 74 | shape (N, 1, P). 75 | num_points (int): The number of points P to sample. 76 | oversample_ratio (int): Oversampling parameter. 77 | importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. 78 | 79 | Returns: 80 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 81 | sampled points. 82 | """ 83 | assert oversample_ratio >= 1 84 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 85 | num_boxes = coarse_logits.shape[0] 86 | num_sampled = int(num_points * oversample_ratio) 87 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 88 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 89 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 90 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 91 | # to incorrect results. 92 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 93 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 94 | # However, if we calculate uncertainties for the coarse predictions first, 95 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 
96 | point_uncertainties = uncertainty_func(point_logits) 97 | num_uncertain_points = int(importance_sample_ratio * num_points) 98 | num_random_points = num_points - num_uncertain_points 99 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 100 | shift = num_sampled * torch.arange( 101 | num_boxes, dtype=torch.long, device=coarse_logits.device 102 | ) 103 | idx += shift[:, None] 104 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 105 | num_boxes, num_uncertain_points, 2 106 | ) 107 | if num_random_points > 0: 108 | point_coords = cat( 109 | [ 110 | point_coords, 111 | torch.rand( 112 | num_boxes, num_random_points, 2, device=coarse_logits.device 113 | ), 114 | ], 115 | dim=1, 116 | ) 117 | return point_coords 118 | 119 | 120 | def all_gather_no_grad(tensor): 121 | world_size = get_world_size() 122 | if world_size == 1: 123 | return tensor 124 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(world_size)] 125 | dist.all_gather(gathered_tensor, tensor) 126 | gathered_tensor = torch.cat(gathered_tensor, dim=0) 127 | return gathered_tensor 128 | 129 | 130 | def gather_no_grad(tensor): 131 | world_size = get_world_size() 132 | if world_size == 1: 133 | return tensor 134 | # get tensor shape from all rank 135 | rank = get_rank() 136 | if rank == 0: 137 | # gather shape 138 | gathered_length = [None for _ in range(world_size)] 139 | dist.gather_object(tensor.shape[0], gathered_length, dst=0) 140 | # gather tensor 141 | gathered_tensor = [ 142 | torch.zeros(gathered_length[i], device=tensor.device) 143 | for i in range(world_size) 144 | ] 145 | dist.gather(tensor, gathered_tensor, dst=0) 146 | else: 147 | dist.gather_object(tensor.shape[0], dst=0) 148 | dist.gather(tensor, dst=0) 149 | gathered_tensor = torch.cat(gathered_tensor, dim=0) 150 | return gathered_tensor 151 | 152 | 153 | def all_gather_with_grad(tensor): 154 | world_size = get_world_size() 155 | if world_size == 1: 156 | return tensor 157 | gathered_tensor = all_gather(tensor) 158 | gathered_tensor = torch.cat(gathered_tensor, dim=0) 159 | return gathered_tensor 160 | 161 | 162 | def _max_by_axis(the_list): 163 | # type: (List[List[int]]) -> List[int] 164 | maxes = the_list[0] 165 | for sublist in the_list[1:]: 166 | for index, item in enumerate(sublist): 167 | maxes[index] = max(maxes[index], item) 168 | return maxes 169 | 170 | 171 | class NestedTensor(object): 172 | def __init__(self, tensors, mask: Optional[Tensor]): 173 | self.tensors = tensors 174 | self.mask = mask 175 | 176 | def to(self, device): 177 | cast_tensor = self.tensors.to(device) 178 | mask = self.mask 179 | if mask is not None: 180 | assert mask is not None 181 | cast_mask = mask.to(device) 182 | else: 183 | cast_mask = None 184 | return NestedTensor(cast_tensor, cast_mask) 185 | 186 | def decompose(self): 187 | return self.tensors, self.mask 188 | 189 | def __repr__(self): 190 | return str(self.tensors) 191 | 192 | 193 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 194 | # TODO make this more general 195 | if tensor_list[0].ndim == 3: 196 | if torchvision._is_tracing(): 197 | # nested_tensor_from_tensor_list() does not export well to ONNX 198 | # call _onnx_nested_tensor_from_tensor_list() instead 199 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 200 | 201 | # TODO make it support different-sized images 202 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 203 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 204 | batch_shape = 
205 |         b, c, h, w = batch_shape
206 |         dtype = tensor_list[0].dtype
207 |         device = tensor_list[0].device
208 |         tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
209 |         mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
210 |         for img, pad_img, m in zip(tensor_list, tensor, mask):
211 |             pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
212 |             m[: img.shape[1], : img.shape[2]] = False
213 |     else:
214 |         raise ValueError("not supported")
215 |     return NestedTensor(tensor, mask)
216 | 
217 | 
218 | # _onnx_nested_tensor_from_tensor_list() is an implementation of
219 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
220 | @torch.jit.unused
221 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
222 |     max_size = []
223 |     for i in range(tensor_list[0].dim()):
224 |         max_size_i = torch.max(
225 |             torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
226 |         ).to(torch.int64)
227 |         max_size.append(max_size_i)
228 |     max_size = tuple(max_size)
229 | 
230 |     # work around for
231 |     # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
232 |     # m[: img.shape[1], :img.shape[2]] = False
233 |     # which is not yet supported in onnx
234 |     padded_imgs = []
235 |     padded_masks = []
236 |     for img in tensor_list:
237 |         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
238 |         padded_img = torch.nn.functional.pad(
239 |             img, (0, padding[2], 0, padding[1], 0, padding[0])
240 |         )
241 |         padded_imgs.append(padded_img)
242 | 
243 |         m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
244 |         padded_mask = torch.nn.functional.pad(
245 |             m, (0, padding[2], 0, padding[1]), "constant", 1
246 |         )
247 |         padded_masks.append(padded_mask.to(torch.bool))
248 | 
249 |     tensor = torch.stack(padded_imgs)
250 |     mask = torch.stack(padded_masks)
251 | 
252 |     return NestedTensor(tensor, mask=mask)
253 | 
254 | 
255 | def is_dist_avail_and_initialized():
256 |     if not dist.is_available():
257 |         return False
258 |     if not dist.is_initialized():
259 |         return False
260 |     return True
261 | 
--------------------------------------------------------------------------------
/lib/utils/post_process.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | 
6 | def sem_seg_postprocess(
7 |     result: torch.Tensor,
8 |     img_size: Tuple[int, int],
9 |     output_size: torch.Tensor,
10 | ):
11 |     """
12 |     Return semantic segmentation predictions in the original resolution.
13 | 
14 |     The input images are often resized when entering the semantic segmentor. Moreover, in some
15 |     cases, they are also padded inside the segmentor to be divisible by the maximum network stride.
16 |     As a result, we often need the predictions of the segmentor in a different
17 |     resolution from its inputs.
18 | 
19 |     Args:
20 |         result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),
21 |             where C is the number of classes, and H, W are the height and width of the prediction.
22 |         img_size (tuple): the image size that the segmentor takes as input.
23 |         output_size (Tensor): the desired output resolution, given as (output_width, output_height).
24 | 
25 |     Returns:
26 |         semantic segmentation prediction (Tensor): A tensor of the shape
27 |             (C, output_height, output_width) that contains per-pixel soft predictions.
28 |     """
29 |     output_width = output_size[0].item()
30 |     output_height = output_size[1].item()
31 |     if img_size[0] == img_size[1]:
32 |         max_side_length = max(output_height, output_width)
33 |         result = F.interpolate(
34 |             result.unsqueeze(0),
35 |             size=(max_side_length, max_side_length),
36 |             mode="bilinear",
37 |             align_corners=False,
38 |         ).squeeze(0)
39 |         result = result[:, :output_height, :output_width]
40 |     else:
41 |         result = F.interpolate(
42 |             result.unsqueeze(0),
43 |             size=(output_height, output_width),
44 |             mode="bilinear",
45 |             align_corners=False,
46 |         ).squeeze(0)
47 |     return result
48 | 
49 | 
50 | @torch.jit.script
51 | def pairwise_iou(masks1: torch.Tensor, masks2: torch.Tensor):
52 |     masks1 = masks1.flatten(1)
53 |     masks2 = masks2.flatten(1)
54 | 
55 |     intersection = torch.einsum("nc,mc->nm", masks1, masks2)
56 |     union = masks1.sum(-1)[:, None] + masks2.sum(-1)[None, :]
57 |     iou = (intersection + 1e-6) / (union - intersection + 1e-6)
58 | 
59 |     return iou
60 | 
61 | 
62 | @torch.jit.script
63 | def pairwise_inner(masks1: torch.Tensor, masks2: torch.Tensor):
64 |     masks1 = masks1.flatten(1)
65 |     masks2 = masks2.flatten(1)
66 | 
67 |     inter = torch.einsum("nc,mc->nm", masks1, masks2)
68 |     inner = (inter + 1e-6) / (masks2.sum(dim=1) + 1e-6)
69 | 
70 |     return inner
71 | 
72 | 
73 | def nms(scores, iou_matrix, iou_threshold=0.7):
74 |     # sort mask indices by score in descending order
75 |     _, order = scores.sort(descending=True)
76 | 
77 |     keep = []  # indices of masks kept after NMS
78 |     while order.numel() > 0:
79 |         i = order[0]  # index of the current highest-scoring mask
80 |         keep.append(i)  # always keep the current highest-scoring mask
81 | 
82 |         if order.numel() == 1:  # only one mask left, so keep it and stop
83 |             break
84 | 
85 |         # IoU between the current mask and all remaining masks
86 |         current_iou = iou_matrix[i, order[1:]]
87 | 
88 |         # masks whose IoU with the current mask is below the threshold are not suppressed
89 |         remain_inds = torch.nonzero(current_iou < iou_threshold).squeeze(dim=1)
90 | 
91 |         # update order, keeping only the masks that were not suppressed by the current mask
92 |         order = order[remain_inds + 1]  # +1 because the current mask was excluded from iou_matrix indexing
93 | 
94 |     # collect the kept indices into a tensor
95 |     keep = torch.as_tensor(keep, dtype=torch.int64, device=scores.device)
96 |     return keep
97 | 
98 | 
99 | def inner_nms(scores, iou_matrix, ratio_matrix, iou_threshold=0.7, ratio_threshold=0.9):
100 |     # sort mask indices by score in descending order
101 |     _, order = scores.sort(descending=True)
102 | 
103 |     keep = []  # indices of masks kept after NMS
104 |     while order.numel() > 0:
105 |         i = order[0]  # index of the current highest-scoring mask
106 |         keep.append(i)  # always keep the current highest-scoring mask
107 | 
108 |         if order.numel() == 1:  # only one mask left, so keep it and stop
109 |             break
110 | 
111 |         # IoU and inner ratio between the current mask and all remaining masks
112 |         current_iou = iou_matrix[i, order[1:]]
113 |         current_ratio = ratio_matrix[i, order[1:]]
114 | 
115 |         # masks below both the IoU and the inner-ratio thresholds are not suppressed
116 |         remain_inds = torch.nonzero(
117 |             (current_iou < iou_threshold) & (current_ratio < ratio_threshold)
118 |         ).squeeze(dim=1)
119 | 
120 |         # update order, keeping only the masks that were not suppressed by the current mask
121 |         order = order[remain_inds + 1]  # +1 because the current mask was excluded from iou_matrix indexing
122 | 
123 |     # collect the kept indices into a tensor
124 |     keep = torch.as_tensor(keep, dtype=torch.int64, device=scores.device)
125 |     return keep
126 | 
127 | 
128 | def mask_nms(
129 |     masks: torch.Tensor,
130 |     scores: torch.Tensor,
131 |     iou_threshold: float = 0.75,
132 |     inner_threshold: float = 0.9,
133 |     nms_type: str = "nms",
134 |     downsample: float = 1.0,
135 | ):
136 |     """
137 |     Performs class-agnostic non-maximum suppression (NMS) on the masks according to their
138 |     intersection-over-union (IoU) overlap. Masks are expected to be in ``(N, H, W)`` format, where N is
139 |     the number of instances.
140 | 
141 |     Args:
142 |         masks (Tensor): A tensor of shape ``(N, H, W)``, representing N masks of height H and width W.
143 |         scores (Tensor): A tensor of shape ``(N,)`` representing the score of each mask.
144 |         iou_threshold (float): A float representing the IoU threshold for deciding whether masks overlap too
145 |             much with respect to each other.
146 | 
147 |     Returns:
148 |         Tensor: A tensor of shape ``(N,)`` representing the indices of the elements that have been kept by NMS.
149 |     """
150 |     # downsample mask
151 |     if downsample < 1.0:
152 |         masks = F.interpolate(
153 |             masks.unsqueeze(0), scale_factor=downsample, mode="bilinear"
154 |         ).squeeze(0)
155 | 
156 |     # flatten all masks
157 |     masks = masks.reshape(masks.shape[0], -1)
158 |     masks = masks.sigmoid().ge(0.5).float()
159 | 
160 |     # nms
161 |     if nms_type == "nms":
162 |         iou_matrix = pairwise_iou(masks, masks)
163 |         keep = nms(scores, iou_matrix, iou_threshold)
164 |     else:
165 |         iou_matrix = pairwise_iou(masks, masks)
166 |         inner_matrix = pairwise_inner(masks, masks)
167 |         keep = inner_nms(
168 |             scores, iou_matrix, inner_matrix, iou_threshold, inner_threshold
169 |         )
170 | 
171 |     return keep
172 | 
173 | 
174 | def batched_mask_nms(
175 |     masks: torch.Tensor,
176 |     scores: torch.Tensor,
177 |     category_idxs: torch.Tensor,
178 |     iou_threshold: float = 0.75,
179 |     downsample: float = 1.0,
180 | ):
181 |     """
182 |     Performs batched non-maximum suppression (NMS) on the masks according to their intersection-over-union (IoU)
183 |     overlap, independently for each category. Masks are expected to be in ``(N, H, W)`` format, where N is
184 |     the number of instances.
185 | 
186 |     Args:
187 |         masks (Tensor): A tensor of shape ``(N, H, W)``, representing N masks of height H and width W.
188 |         scores (Tensor): A tensor of shape ``(N,)`` representing the score of each mask.
189 |         category_idxs (Tensor): A tensor of shape ``(N,)`` representing the category index for each mask.
190 |         iou_threshold (float): A float representing the IoU threshold for deciding whether masks overlap too
191 |             much with respect to each other.
192 | 
193 |     Returns:
194 |         Tensor: A tensor of shape ``(N,)`` representing the indices of the elements that have been kept by NMS.
195 | """ 196 | 197 | if masks.numel() == 0: 198 | return torch.empty((0,), dtype=torch.int64, device=masks.device) 199 | 200 | # downsample mask 201 | if downsample < 1.0: 202 | masks = F.interpolate( 203 | masks.unsqueeze(0), scale_factor=downsample, mode="bilinear" 204 | ).squeeze(0) 205 | 206 | # Flatten masks and threshold 207 | masks_flat = masks.reshape(masks.shape[0], -1) 208 | masks_flat = masks_flat.sigmoid().ge(0.5).float() 209 | 210 | # Initialize tensor to keep track of the indices to keep 211 | keep_indices = torch.empty((0,), dtype=torch.int64, device=masks.device) 212 | 213 | # Process each category separately 214 | for category in torch.unique(category_idxs): 215 | # Filter masks and scores for the current category 216 | category_mask = category_idxs == category 217 | masks_category = masks_flat[category_mask] 218 | scores_category = scores[category_mask] 219 | 220 | # Compute pairwise IoU for masks in the current category 221 | iou = pairwise_iou(masks_category, masks_category) 222 | 223 | # Discard overlaps 224 | iou.triu_(diagonal=1) 225 | iou_max, _ = iou.max(dim=0) 226 | category_keep = (iou_max <= iou_threshold).nonzero(as_tuple=False).squeeze(1) 227 | 228 | # Keep top scoring masks within this category 229 | if category_keep.numel() > 0: 230 | scores_keep = scores_category[category_keep] 231 | _, idx = scores_keep.sort(0, descending=True) 232 | category_keep = category_keep[idx] 233 | 234 | # Add indices (adjusted to original indexing) to keep_indices 235 | keep_indices = torch.cat( 236 | ( 237 | keep_indices, 238 | torch.nonzero(category_mask, as_tuple=False).squeeze(1)[category_keep], 239 | ) 240 | ) 241 | 242 | return keep_indices 243 | 244 | 245 | def get_classification_logits_fcclip( 246 | x, text_classifier, logit_scale=None, num_templates=None 247 | ): 248 | """ 249 | x in shape of [B, *, C] 250 | text_classifier in shape of [num_classes, C] 251 | logit_scale is a learnable scalar https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201 252 | return: [B, *, num_classes] 253 | """ 254 | 255 | # Normalize feature vectors 256 | x = F.normalize(x, dim=-1) 257 | 258 | # Compute initial logits 259 | pred_logits = x @ text_classifier.transpose(-2, -1) # Shape: B, *, N + 1 260 | 261 | # Efficiently compute the max ensemble (used in OpenSeg/ODISE) 262 | max_logits = [] 263 | cur_idx = 0 264 | for num_t in num_templates: 265 | max_logits.append(pred_logits[:, :, cur_idx : cur_idx + num_t].amax(dim=-1)) 266 | cur_idx += num_t 267 | final_pred_logits = torch.stack(max_logits, dim=-1) 268 | 269 | # Apply logit scale 270 | if logit_scale is not None: 271 | logit_scale = torch.clamp(logit_scale.exp(), max=100) 272 | final_pred_logits *= logit_scale 273 | 274 | return final_pred_logits 275 | -------------------------------------------------------------------------------- /lib/utils/prompt.py: -------------------------------------------------------------------------------- 1 | SIMPLE_TEMPLATES = ( 2 | lambda c: f"a photo of a {c}.", 3 | lambda c: f"This is a photo of a {c}", 4 | lambda c: f"There is a {c} in the scene", 5 | lambda c: f"There is the {c} in the scene", 6 | lambda c: f"a photo of a {c} in the scene", 7 | ) 8 | 9 | VILD_TEMPLATES = ( 10 | lambda c: f"a photo of a {c}.", 11 | lambda c: f"This is a photo of a {c}", 12 | lambda c: f"There is a {c} in the scene", 13 | lambda c: f"There is the {c} in the scene", 14 | lambda c: f"a photo of a {c} in the scene", 15 | lambda c: f"a photo of a small {c}.", 16 | lambda c: f"a photo of a medium 
{c}.", 17 | lambda c: f"a photo of a large {c}.", 18 | lambda c: f"This is a photo of a small {c}.", 19 | lambda c: f"This is a photo of a medium {c}.", 20 | lambda c: f"This is a photo of a large {c}.", 21 | lambda c: f"There is a small {c} in the scene.", 22 | lambda c: f"There is a medium {c} in the scene.", 23 | lambda c: f"There is a large {c} in the scene.", 24 | ) 25 | 26 | OPENAI_IMAGENET_TEMPLATES = ( 27 | lambda c: f"a photo of a {c}.", 28 | lambda c: f"a bad photo of a {c}.", 29 | lambda c: f"a photo of many {c}.", 30 | lambda c: f"a sculpture of a {c}.", 31 | lambda c: f"a photo of the hard to see {c}.", 32 | lambda c: f"a low resolution photo of the {c}.", 33 | lambda c: f"a rendering of a {c}.", 34 | lambda c: f"graffiti of a {c}.", 35 | lambda c: f"a bad photo of the {c}.", 36 | lambda c: f"a cropped photo of the {c}.", 37 | lambda c: f"a tattoo of a {c}.", 38 | lambda c: f"the embroidered {c}.", 39 | lambda c: f"a photo of a hard to see {c}.", 40 | lambda c: f"a bright photo of a {c}.", 41 | lambda c: f"a photo of a clean {c}.", 42 | lambda c: f"a photo of a dirty {c}.", 43 | lambda c: f"a dark photo of the {c}.", 44 | lambda c: f"a drawing of a {c}.", 45 | lambda c: f"a photo of my {c}.", 46 | lambda c: f"the plastic {c}.", 47 | lambda c: f"a photo of the cool {c}.", 48 | lambda c: f"a close-up photo of a {c}.", 49 | lambda c: f"a black and white photo of the {c}.", 50 | lambda c: f"a painting of the {c}.", 51 | lambda c: f"a painting of a {c}.", 52 | lambda c: f"a pixelated photo of the {c}.", 53 | lambda c: f"a sculpture of the {c}.", 54 | lambda c: f"a bright photo of the {c}.", 55 | lambda c: f"a cropped photo of a {c}.", 56 | lambda c: f"a plastic {c}.", 57 | lambda c: f"a photo of the dirty {c}.", 58 | lambda c: f"a jpeg corrupted photo of a {c}.", 59 | lambda c: f"a blurry photo of the {c}.", 60 | lambda c: f"a photo of the {c}.", 61 | lambda c: f"a good photo of the {c}.", 62 | lambda c: f"a rendering of the {c}.", 63 | lambda c: f"a {c} in a video game.", 64 | lambda c: f"a photo of one {c}.", 65 | lambda c: f"a doodle of a {c}.", 66 | lambda c: f"a close-up photo of the {c}.", 67 | lambda c: f"a photo of a {c}.", 68 | lambda c: f"the origami {c}.", 69 | lambda c: f"the {c} in a video game.", 70 | lambda c: f"a sketch of a {c}.", 71 | lambda c: f"a doodle of the {c}.", 72 | lambda c: f"a origami {c}.", 73 | lambda c: f"a low resolution photo of a {c}.", 74 | lambda c: f"the toy {c}.", 75 | lambda c: f"a rendition of the {c}.", 76 | lambda c: f"a photo of the clean {c}.", 77 | lambda c: f"a photo of a large {c}.", 78 | lambda c: f"a rendition of a {c}.", 79 | lambda c: f"a photo of a nice {c}.", 80 | lambda c: f"a photo of a weird {c}.", 81 | lambda c: f"a blurry photo of a {c}.", 82 | lambda c: f"a cartoon {c}.", 83 | lambda c: f"art of a {c}.", 84 | lambda c: f"a sketch of the {c}.", 85 | lambda c: f"a embroidered {c}.", 86 | lambda c: f"a pixelated photo of a {c}.", 87 | lambda c: f"itap of the {c}.", 88 | lambda c: f"a jpeg corrupted photo of the {c}.", 89 | lambda c: f"a good photo of a {c}.", 90 | lambda c: f"a plushie {c}.", 91 | lambda c: f"a photo of the nice {c}.", 92 | lambda c: f"a photo of the small {c}.", 93 | lambda c: f"a photo of the weird {c}.", 94 | lambda c: f"the cartoon {c}.", 95 | lambda c: f"art of the {c}.", 96 | lambda c: f"a drawing of the {c}.", 97 | lambda c: f"a photo of the large {c}.", 98 | lambda c: f"a black and white photo of a {c}.", 99 | lambda c: f"the plushie {c}.", 100 | lambda c: f"a dark photo of a {c}.", 101 | lambda c: 
f"itap of a {c}.", 102 | lambda c: f"graffiti of the {c}.", 103 | lambda c: f"a toy {c}.", 104 | lambda c: f"itap of my {c}.", 105 | lambda c: f"a photo of a cool {c}.", 106 | lambda c: f"a photo of a small {c}.", 107 | lambda c: f"a tattoo of the {c}.", 108 | ) 109 | 110 | OPENAI_IMAGENET_VILD_TEMPLATES = ( 111 | lambda c: f"a photo of a {c}.", 112 | lambda c: f"This is a photo of a {c}", 113 | lambda c: f"There is a {c} in the scene", 114 | lambda c: f"There is the {c} in the scene", 115 | lambda c: f"a photo of a {c} in the scene", 116 | lambda c: f"a photo of a small {c}.", 117 | lambda c: f"a photo of a medium {c}.", 118 | lambda c: f"a photo of a large {c}.", 119 | lambda c: f"This is a photo of a small {c}.", 120 | lambda c: f"This is a photo of a medium {c}.", 121 | lambda c: f"This is a photo of a large {c}.", 122 | lambda c: f"There is a small {c} in the scene.", 123 | lambda c: f"There is a medium {c} in the scene.", 124 | lambda c: f"There is a large {c} in the scene.", 125 | lambda c: f"a bad photo of a {c}.", 126 | lambda c: f"a photo of many {c}.", 127 | lambda c: f"a sculpture of a {c}.", 128 | lambda c: f"a photo of the hard to see {c}.", 129 | lambda c: f"a low resolution photo of the {c}.", 130 | lambda c: f"a rendering of a {c}.", 131 | lambda c: f"graffiti of a {c}.", 132 | lambda c: f"a bad photo of the {c}.", 133 | lambda c: f"a cropped photo of the {c}.", 134 | lambda c: f"a tattoo of a {c}.", 135 | lambda c: f"the embroidered {c}.", 136 | lambda c: f"a photo of a hard to see {c}.", 137 | lambda c: f"a bright photo of a {c}.", 138 | lambda c: f"a photo of a clean {c}.", 139 | lambda c: f"a photo of a dirty {c}.", 140 | lambda c: f"a dark photo of the {c}.", 141 | lambda c: f"a drawing of a {c}.", 142 | lambda c: f"a photo of my {c}.", 143 | lambda c: f"the plastic {c}.", 144 | lambda c: f"a photo of the cool {c}.", 145 | lambda c: f"a close-up photo of a {c}.", 146 | lambda c: f"a black and white photo of the {c}.", 147 | lambda c: f"a painting of the {c}.", 148 | lambda c: f"a painting of a {c}.", 149 | lambda c: f"a pixelated photo of the {c}.", 150 | lambda c: f"a sculpture of the {c}.", 151 | lambda c: f"a bright photo of the {c}.", 152 | lambda c: f"a cropped photo of a {c}.", 153 | lambda c: f"a plastic {c}.", 154 | lambda c: f"a photo of the dirty {c}.", 155 | lambda c: f"a jpeg corrupted photo of a {c}.", 156 | lambda c: f"a blurry photo of the {c}.", 157 | lambda c: f"a photo of the {c}.", 158 | lambda c: f"a good photo of the {c}.", 159 | lambda c: f"a rendering of the {c}.", 160 | lambda c: f"a {c} in a video game.", 161 | lambda c: f"a photo of one {c}.", 162 | lambda c: f"a doodle of a {c}.", 163 | lambda c: f"a close-up photo of the {c}.", 164 | lambda c: f"a photo of a {c}.", 165 | lambda c: f"the origami {c}.", 166 | lambda c: f"the {c} in a video game.", 167 | lambda c: f"a sketch of a {c}.", 168 | lambda c: f"a doodle of the {c}.", 169 | lambda c: f"a origami {c}.", 170 | lambda c: f"a low resolution photo of a {c}.", 171 | lambda c: f"the toy {c}.", 172 | lambda c: f"a rendition of the {c}.", 173 | lambda c: f"a photo of the clean {c}.", 174 | lambda c: f"a photo of a large {c}.", 175 | lambda c: f"a rendition of a {c}.", 176 | lambda c: f"a photo of a nice {c}.", 177 | lambda c: f"a photo of a weird {c}.", 178 | lambda c: f"a blurry photo of a {c}.", 179 | lambda c: f"a cartoon {c}.", 180 | lambda c: f"art of a {c}.", 181 | lambda c: f"a sketch of the {c}.", 182 | lambda c: f"a embroidered {c}.", 183 | lambda c: f"a pixelated photo of a {c}.", 
184 | lambda c: f"itap of the {c}.", 185 | lambda c: f"a jpeg corrupted photo of the {c}.", 186 | lambda c: f"a good photo of a {c}.", 187 | lambda c: f"a plushie {c}.", 188 | lambda c: f"a photo of the nice {c}.", 189 | lambda c: f"a photo of the small {c}.", 190 | lambda c: f"a photo of the weird {c}.", 191 | lambda c: f"the cartoon {c}.", 192 | lambda c: f"art of the {c}.", 193 | lambda c: f"a drawing of the {c}.", 194 | lambda c: f"a photo of the large {c}.", 195 | lambda c: f"a black and white photo of a {c}.", 196 | lambda c: f"the plushie {c}.", 197 | lambda c: f"a dark photo of a {c}.", 198 | lambda c: f"itap of a {c}.", 199 | lambda c: f"graffiti of the {c}.", 200 | lambda c: f"a toy {c}.", 201 | lambda c: f"itap of my {c}.", 202 | lambda c: f"a photo of a cool {c}.", 203 | lambda c: f"a photo of a small {c}.", 204 | lambda c: f"a tattoo of the {c}.", 205 | ) 206 | -------------------------------------------------------------------------------- /lib/utils/test_time_augmentation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | from detectron2.data.detection_utils import read_image 6 | from detectron2.modeling import DatasetMapperTTA 7 | from fvcore.transforms import HFlipTransform 8 | from torch import nn 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | __all__ = [ 12 | "SemanticSegmentorWithTTA", 13 | ] 14 | 15 | 16 | class SemanticSegmentorWithTTA(nn.Module): 17 | """ 18 | A SemanticSegmentor with test-time augmentation enabled. 19 | Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. 20 | """ 21 | 22 | def __init__(self, cfg, model, tta_mapper=None, batch_size=1): 23 | """ 24 | Args: 25 | cfg (CfgNode): 26 | model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. 27 | tta_mapper (callable): takes a dataset dict and returns a list of 28 | augmented versions of the dataset dict. Defaults to 29 | `DatasetMapperTTA(cfg)`. 30 | batch_size (int): batch the augmented images into this batch size for inference. 
31 | """ 32 | super().__init__() 33 | if isinstance(model, DistributedDataParallel): 34 | model = model.module 35 | self.cfg = cfg.clone() 36 | 37 | self.model = model 38 | 39 | if tta_mapper is None: 40 | tta_mapper = DatasetMapperTTA(cfg) 41 | self.tta_mapper = tta_mapper 42 | self.batch_size = batch_size 43 | 44 | def __call__(self, batched_inputs): 45 | """ 46 | Same input/output format as :meth:`SemanticSegmentor.forward` 47 | """ 48 | 49 | def _maybe_read_image(dataset_dict): 50 | ret = copy.copy(dataset_dict) 51 | if "image" not in ret: 52 | image = read_image(ret.pop("file_name"), self.model.input_format) 53 | image = torch.from_numpy( 54 | np.ascontiguousarray(image.transpose(2, 0, 1)) 55 | ) # CHW 56 | ret["image"] = image 57 | if "height" not in ret and "width" not in ret: 58 | ret["height"] = image.shape[1] 59 | ret["width"] = image.shape[2] 60 | return ret 61 | 62 | processed_results = [] 63 | for x in batched_inputs: 64 | result = self._inference_one_image(_maybe_read_image(x)) 65 | processed_results.append(result) 66 | return processed_results 67 | 68 | def _inference_one_image(self, input): 69 | """ 70 | Args: 71 | input (dict): one dataset dict with "image" field being a CHW tensor 72 | Returns: 73 | dict: one output dict 74 | """ 75 | orig_shape = (input["height"], input["width"]) 76 | augmented_inputs, tfms = self._get_augmented_inputs(input) 77 | 78 | final_predictions = None 79 | count_predictions = 0 80 | for input, tfm in zip(augmented_inputs, tfms): 81 | count_predictions += 1 82 | with torch.no_grad(): 83 | if final_predictions is None: 84 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 85 | final_predictions = ( 86 | self.model([input])[0].pop("sem_seg").flip(dims=[2]) 87 | ) 88 | else: 89 | final_predictions = self.model([input])[0].pop("sem_seg") 90 | else: 91 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 92 | final_predictions += ( 93 | self.model([input])[0].pop("sem_seg").flip(dims=[2]) 94 | ) 95 | else: 96 | final_predictions += self.model([input])[0].pop("sem_seg") 97 | 98 | final_predictions = final_predictions / count_predictions 99 | return {"sem_seg": final_predictions} 100 | 101 | def _get_augmented_inputs(self, input): 102 | augmented_inputs = self.tta_mapper(input) 103 | tfms = [x.pop("transforms") for x in augmented_inputs] 104 | return augmented_inputs, tfms 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | scipy 3 | shapely 4 | timm 5 | h5py 6 | submitit 7 | scikit-image 8 | opencv-python-headless 9 | open_clip_torch 10 | progressbar 11 | tensorboard==2.13.0 12 | webdataset 13 | ftfy 14 | regex 15 | torchshow 16 | nltk 17 | lvis 18 | cityscapesscripts 19 | pyarrow 20 | fastparquet 21 | torchmetrics 22 | opencv-python 23 | ninja 24 | # git+https://github.com/facebookresearch/detectron2.git 25 | git+https://github.com/cocodataset/panopticapi.git --------------------------------------------------------------------------------
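
As a minimal usage sketch (not part of the repository files above), the snippet below shows how `SIMPLE_TEMPLATES` from `lib/utils/prompt.py` combines with `get_classification_logits_fcclip` and `mask_nms` from `lib/utils/post_process.py`. All tensors are random stand-ins for CLIP text embeddings and predicted mask logits; the class names, the 512-dimensional feature size, and the assumption that the repository root is on `PYTHONPATH` are hypothetical, and in the real pipeline the embeddings come from the CLIP backbone and the mask decoder.

```python
import torch

from lib.utils.post_process import get_classification_logits_fcclip, mask_nms
from lib.utils.prompt import SIMPLE_TEMPLATES

# Hypothetical open vocabulary; each class name is expanded into several prompts.
class_names = ["cat", "dog", "car"]
num_templates = [len(SIMPLE_TEMPLATES)] * len(class_names)
# `prompts` is what would be fed to the CLIP text encoder to build `text_classifier`.
prompts = [template(name) for name in class_names for template in SIMPLE_TEMPLATES]

# Stand-in for L2-normalized CLIP text embeddings: one row per (class, template) pair.
text_classifier = torch.nn.functional.normalize(torch.randn(len(prompts), 512), dim=-1)

# Stand-in for per-query mask embeddings: 2 images, 100 mask queries, 512-d features.
mask_embed = torch.randn(2, 100, 512)
logits = get_classification_logits_fcclip(
    mask_embed,
    text_classifier,
    logit_scale=torch.tensor(4.6052),  # ln(100), mimicking a trained CLIP logit scale
    num_templates=num_templates,
)
print(logits.shape)  # torch.Size([2, 100, 3]): one score per class after the template max-ensemble

# Class-agnostic NMS over the first image's predicted mask logits.
mask_logits = torch.randn(100, 256, 256)
scores = logits[0].softmax(dim=-1).amax(dim=-1)
keep = mask_nms(mask_logits, scores, iou_threshold=0.75, nms_type="nms", downsample=0.5)
print(keep.shape)  # indices of the masks kept after suppression
```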