├── .gitignore ├── README.md ├── __init__.py ├── assets ├── images │ ├── dog.jpg │ └── teaser.jpg ├── requirements │ ├── requirements.txt │ └── requirements_custom.txt └── videos │ ├── example1.mp4 │ └── example2.mp4 ├── configs ├── find │ ├── davitd3_llama_lang.yaml │ ├── davitd3_unicl_lang.yaml │ ├── davitd5_llama_lang.yaml │ ├── davitd5_unicl_lang.yaml │ ├── focall_llama_lang.yaml │ ├── focall_unicl_lang.yaml │ ├── focalt_llama_lang.yaml │ ├── focalt_unicl_lang.yaml │ ├── samb_llama_lang.yaml │ └── samb_unicl_lang.yaml ├── seem │ ├── davitd3_unicl_lang_v0.yaml │ ├── davitd3_unicl_lang_v1.yaml │ ├── davitd5_unicl_lang_v0.yaml │ ├── davitd5_unicl_lang_v1.yaml │ ├── focall_unicl_lang_demo.yaml │ ├── focall_unicl_lang_v0.yaml │ ├── focall_unicl_lang_v1.yaml │ ├── focalt_unicl_lang_demo.yaml │ ├── focalt_unicl_lang_v0.yaml │ ├── focalt_unicl_lang_v1.yaml │ ├── samvitb_unicl_lang_v1.yaml │ └── samvitl_unicl_lang_v1.yaml └── xdecoder │ ├── davitd3_unicl_lang.yaml │ ├── davitd5_unicl_lang.yaml │ ├── focall_unicl_lang.yaml │ └── focalt_unicl_lang.yaml ├── datasets ├── __init__.py ├── build.py ├── dataset_mappers │ ├── __init__.py │ ├── bdd_semseg_dataset_mapper.py │ ├── coco_instance_new_baseline_dataset_mapper.py │ ├── coco_language_interleave_dataset_mapper.py │ ├── coco_panoptic_interactive_dataset_mapper.py │ ├── coco_panoptic_interleave_dataset_mapper.py │ ├── coco_panoptic_new_baseline_dataset_mapper.py │ ├── davis_dataset_mapper.py │ ├── davis_dataset_mapper_ix.py │ ├── grounding_coco_entity_mapper.py │ ├── imagenet_dataset_mapper.py │ ├── mask_former_instance_dataset_mapper.py │ ├── mask_former_panoptic_dataset_mapper.py │ ├── mask_former_semantic_dataset_mapper.py │ ├── pascalvoc_dataset_mapper_ix.py │ ├── refcoco_dataset_mapper.py │ ├── sbd_dataset_mapper.py │ ├── scannet_dataset_mapper.py │ ├── scannet_pano_dataset_mapper.py │ ├── sunrgbd_dataset_mapper.py │ ├── vlp_coco_entity_mapper.py │ ├── vlp_dataset_mapper.py │ ├── vlp_interactive_dataset_mapper.py │ ├── vlp_interleave_dataset_mapper.py │ └── ytvos_dataset_mapper.py ├── evaluation │ ├── __init__.py │ ├── captioning_evaluation.py │ ├── classification_evaluation.py │ ├── grounding_evaluation.py │ ├── instance_evaluation.py │ ├── interactive_evaluation.py │ ├── interleave_evaluation.py │ ├── panoptic_evaluation.py │ ├── retrieval_evaluation.py │ └── segmentation_evaluation.py ├── refer.py ├── registration │ ├── __init__.py │ ├── register_ade20k_full.py │ ├── register_ade20k_instance.py │ ├── register_ade20k_panoptic.py │ ├── register_bdd100k_panoseg.py │ ├── register_bdd100k_semseg.py │ ├── register_coco_lvis_panoptic_annos_caption_grounding.py │ ├── register_coco_lvis_panoptic_annos_caption_grounding_entity.py │ ├── register_coco_panoptic_annos_caption.py │ ├── register_coco_panoptic_annos_caption_grounding.py │ ├── register_coco_panoptic_annos_semseg.py │ ├── register_coco_stuff_10k.py │ ├── register_davis_dataset.py │ ├── register_davis_ixeval.py │ ├── register_grounding_coco_entity.py │ ├── register_imagenet_cls.py │ ├── register_pascalvoc_eval.py │ ├── register_refcoco_dataset.py │ ├── register_sbd_eval.py │ ├── register_scannet_panoptic.py │ ├── register_scannet_semseg.py │ ├── register_sunrgbd_semseg.py │ ├── register_vlp_coco_entity.py │ ├── register_vlp_coco_interleave.py │ ├── register_vlp_datasets.py │ └── register_ytvos_dataset.py ├── semseg_loader.py ├── utils │ ├── refcoco2json.py │ └── refer.py └── visual_sampler │ ├── __init__.py │ ├── circle.py │ ├── mask_generators.py │ ├── point.py │ ├── polygon.py │ ├── sampler.py │ ├── scribble.py │ └── simpleclick_sampler.py ├── demo ├── __init__.py └── find │ ├── __init__.py │ ├── arial.ttf │ ├── demo_interleave_llama.py │ └── utils.py ├── entry.py ├── modeling ├── BaseModel.py ├── __init__.py ├── architectures │ ├── __init__.py │ ├── build.py │ ├── find_model.py │ ├── seem_model_v0.py │ ├── seem_model_v1.py │ └── xdecoder_model.py ├── body │ ├── __init__.py │ ├── build.py │ └── xdecoder_head.py ├── interface │ ├── __init__.py │ ├── build.py │ ├── find.py │ ├── operator │ │ ├── __init__.py │ │ ├── attention.py │ │ └── modules.py │ ├── prototype │ │ ├── __init__.py │ │ ├── attention_data_struct_ging.py │ │ ├── attention_data_struct_seemv0.py │ │ └── attention_data_struct_seemv1.py │ ├── seem_v0.py │ ├── seem_v1.py │ └── xdecoder.py ├── language │ ├── LangEncoder │ │ ├── __init__.py │ │ ├── build.py │ │ ├── modeling_llama.py │ │ └── transformer.py │ ├── Tokenizer │ │ ├── __init__.py │ │ └── custom_tokenizer.py │ ├── __init__.py │ ├── build.py │ ├── llamaencoder.py │ ├── loss.py │ ├── misc.py │ ├── vlpencoder.py │ └── vlpencoder_v1.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── criterion.py │ ├── matcher.py │ ├── point_features.py │ ├── position_encoding.py │ └── postprocessing.py ├── utils │ ├── __init__.py │ ├── attention.py │ ├── box_ops.py │ ├── config.py │ ├── interactive.py │ └── misc.py └── vision │ ├── backbone │ ├── __init__.py │ ├── backbone.py │ ├── build.py │ ├── common.py │ ├── davit.py │ ├── focal.py │ ├── focal_dw.py │ └── vit.py │ └── encoder │ ├── __init__.py │ ├── build.py │ ├── ops │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn_cpu.h │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── ms_deform_attn.h │ │ └── vision.cpp │ └── test.py │ ├── transformer_blocks.py │ ├── transformer_encoder_deform.py │ └── transformer_encoder_fpn.py ├── pipeline ├── FINDPipeline.py ├── XDecoderPipeline.py ├── __init__.py └── utils │ └── misc.py ├── trainer ├── __init__.py ├── default_trainer.py ├── distributed_trainer.py ├── utils │ ├── __init__.py │ ├── hook.py │ ├── misc.py │ ├── mpi_adapter.py │ └── serialization.py ├── utils_trainer.py └── xdecoder_trainer.py ├── utils ├── Config.py ├── __init__.py ├── arguments.py ├── constants.py ├── dataset.py ├── distributed.py ├── misc.py ├── model.py ├── prompt_engineering.py └── visualizer.py └── xy_utils ├── __init__.py ├── annotation ├── annot_interleave_retrieval.py └── find_bench_stat.py ├── evaluation ├── __init__.py ├── compute_grin_visual_features.py ├── eval_gsam_grounding_entity.py ├── eval_seem_interleave_segmentation.py ├── eval_xdecoder_interleave_retrieval.py ├── sam_interactive_best.py └── sam_interactive_iou_box.py ├── gpt4 └── generate_class_description.py ├── image2html ├── __init__.py ├── utils.py └── visualizer.py └── visualization ├── visualize_coco_gpt4_caption_train.py └── visualize_coco_gpt4_caption_val.py /.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ project files 2 | .idea 3 | *.iml 4 | out 5 | gen 6 | 7 | ### Vim template 8 | [._]*.s[a-w][a-z] 9 | [._]s[a-w][a-z] 10 | *.un~ 11 | Session.vim 12 | .netrwhist 13 | *~ 14 | 15 | ### IPythonNotebook template 16 | # Temporary data 17 | .ipynb_checkpoints/ 18 | 19 | ### Python template 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | env/ 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | #lib/ 38 | #lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *,cover 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | *.ipynb 80 | *.params 81 | # *.json 82 | .vscode/ 83 | *.code-workspace/ 84 | 85 | lib/pycocotools/_mask.c 86 | lib/nms/cpu_nms.c 87 | 88 | output 89 | OUTPUT 90 | OUTPUT/* 91 | models/* 92 | DATASET 93 | DATASET/* 94 | external/ 95 | MODELS 96 | MODELS/* 97 | 98 | draws/ 99 | plot/ 100 | 101 | amlt/* 102 | exps/* 103 | 104 | *venv/* 105 | *.pt 106 | *.pth 107 | *.da 108 | 109 | *scripts/* -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/__init__.py -------------------------------------------------------------------------------- /assets/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/assets/images/dog.jpg -------------------------------------------------------------------------------- /assets/images/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/assets/images/teaser.jpg -------------------------------------------------------------------------------- /assets/requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | pillow==9.4.0 2 | opencv-python==4.8.1.78 3 | pyyaml==6.0.1 4 | json_tricks==3.17.3 5 | yacs==0.1.8 6 | scikit-learn==1.3.1 7 | pandas==2.0.3 8 | timm==0.4.12 9 | numpy==1.23.1 10 | einops==0.7.0 11 | fvcore==0.1.5.post20221221 12 | transformers==4.34.0 13 | sentencepiece==0.1.99 14 | ftfy==6.1.1 15 | regex==2023.10.3 16 | nltk==3.8.1 17 | vision-datasets==0.2.2 18 | cython==3.0.2 19 | pycocotools==2.0.7 20 | diffdist==0.1 21 | pyarrow==13.0.0 22 | cityscapesscripts==2.2.2 23 | shapely==1.8.0 24 | scikit-image==0.21.0 25 | mup==1.0.0 26 | accelerate==0.23.0 27 | kornia==0.7.0 28 | deepspeed==0.10.3 29 | wandb==0.15.12 30 | infinibatch==0.1.1 31 | gradio==3.42.0 -------------------------------------------------------------------------------- /assets/requirements/requirements_custom.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/arogozhnikov/einops.git 2 | git+https://github.com/MaureenZOU/detectron2-xyz.git 3 | git+https://github.com/openai/whisper.git 4 | git+https://github.com/cocodataset/panopticapi.git -------------------------------------------------------------------------------- /assets/videos/example1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/assets/videos/example1.mp4 -------------------------------------------------------------------------------- /assets/videos/example2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/assets/videos/example2.mp4 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import registration 2 | from .build import build_train_dataloader, build_eval_dataloader, build_evaluator -------------------------------------------------------------------------------- /datasets/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco_panoptic_interactive_dataset_mapper import COCOPanopticInteractiveDatasetMapper 2 | from .coco_panoptic_interleave_dataset_mapper import COCOPanopticInterleaveDatasetMapper 3 | from .coco_language_interleave_dataset_mapper import COCOLanguageInterleaveDatasetMapper 4 | from .coco_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper 5 | from .coco_panoptic_new_baseline_dataset_mapper import COCOPanopticNewBaselineDatasetMapper 6 | from .mask_former_instance_dataset_mapper import MaskFormerInstanceDatasetMapper 7 | from .mask_former_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapper 8 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper 9 | from .imagenet_dataset_mapper import ImageNetDatasetMapper 10 | from .vlp_dataset_mapper import VLPreDatasetMapper 11 | from .sunrgbd_dataset_mapper import SunRGBDSegDatasetMapper 12 | from .scannet_dataset_mapper import ScanNetSegDatasetMapper 13 | from .bdd_semseg_dataset_mapper import BDDSemDatasetMapper 14 | from .scannet_pano_dataset_mapper import ScanNetPanoDatasetMapper 15 | from .refcoco_dataset_mapper import RefCOCODatasetMapper 16 | from .pascalvoc_dataset_mapper_ix import PascalVOCSegDatasetMapperIX 17 | from .grounding_coco_entity_mapper import GroundingCOCOEntityDatasetMapper 18 | from .vlp_interactive_dataset_mapper import VLPreInteractiveDatasetMapper 19 | from .vlp_interleave_dataset_mapper import VLPreCOCOInterleaveDatasetMapper 20 | from .vlp_coco_entity_mapper import VLPreCOCOEntityDatasetMapper 21 | from .ytvos_dataset_mapper import YTVOSDatasetMapper 22 | from .sbd_dataset_mapper import SBDDatasetMapper 23 | from .davis_dataset_mapper_ix import DAVISDatasetMapperIX 24 | from .davis_dataset_mapper import DAVISDatasetMapper -------------------------------------------------------------------------------- /datasets/dataset_mappers/bdd_semseg_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["BDDSemDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class BDDSemDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def read_semseg(self, file_name): 76 | if '.png' in file_name: 77 | semseg = np.asarray(Image.open(file_name)) 78 | elif '.mat' in file_name: 79 | semseg = scipy.io.loadmat(file_name)['LabelMap'] 80 | return semseg 81 | 82 | def __call__(self, dataset_dict): 83 | """ 84 | Args: 85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 86 | 87 | Returns: 88 | dict: a format that builtin models in detectron2 accept 89 | """ 90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 91 | file_name = dataset_dict['file_name'] 92 | semseg_name = dataset_dict['sem_seg_file_name'] 93 | image = Image.open(file_name).convert('RGB') 94 | 95 | dataset_dict['width'] = image.size[0] 96 | dataset_dict['height'] = image.size[1] 97 | 98 | if self.is_train == False: 99 | image = self.transform(image) 100 | image = torch.from_numpy(np.asarray(image).copy()) 101 | image = image.permute(2,0,1) 102 | 103 | semseg = self.read_semseg(semseg_name) 104 | semseg = torch.from_numpy(semseg.astype(np.int32)) 105 | dataset_dict['image'] = image 106 | dataset_dict['semseg'] = semseg 107 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/davis_dataset_mapper_ix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py 3 | import copy 4 | import os 5 | 6 | import cv2 7 | import scipy.io 8 | import numpy as np 9 | from scipy.io import loadmat 10 | from PIL import Image 11 | 12 | import torch 13 | from torchvision import transforms 14 | from detectron2.structures import BitMasks, Boxes, Instances 15 | 16 | from modeling.utils import configurable 17 | from ..visual_sampler import build_shape_sampler 18 | 19 | __all__ = ["DAVISDatasetMapperIX"] 20 | 21 | 22 | # This is specifically designed for the COCO dataset. 23 | class DAVISDatasetMapperIX: 24 | """ 25 | A callable which takes a dataset dict in Detectron2 Dataset format, 26 | and map it into a format used by MaskFormer. 27 | 28 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 29 | 30 | The callable currently does the following: 31 | 32 | 1. Read the image from "file_name" 33 | 2. Applies geometric transforms to the image and annotation 34 | 3. Find and applies suitable cropping to the image and annotation 35 | 4. Prepare image and annotation to Tensors 36 | """ 37 | 38 | @configurable 39 | def __init__( 40 | self, 41 | is_train=True, 42 | dataset_name='', 43 | min_size_test=None, 44 | max_size_test=None, 45 | shape_sampler=None, 46 | ): 47 | """ 48 | NOTE: this interface is experimental. 49 | Args: 50 | is_train: for training or inference 51 | augmentations: a list of augmentations or deterministic transforms to apply 52 | tfm_gens: data augmentation 53 | image_format: an image format supported by :func:`detection_utils.read_image`. 54 | """ 55 | self.is_train = is_train 56 | self.dataset_name = dataset_name 57 | self.min_size_test = min_size_test 58 | self.max_size_test = max_size_test 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC, max_size=max_size_test)) 62 | self.transform = transforms.Compose(t) 63 | self.shape_sampler = shape_sampler 64 | 65 | @classmethod 66 | def from_config(cls, cfg, is_train=True, dataset_name=''): 67 | shape_sampler = build_shape_sampler(cfg, is_train=is_train, mode=dataset_name.split('_')[-1]) 68 | ret = { 69 | "is_train": is_train, 70 | "dataset_name": dataset_name, 71 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 72 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 73 | "shape_sampler": shape_sampler, 74 | } 75 | return ret 76 | 77 | def __call__(self, dataset_dict): 78 | """ 79 | Args: 80 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 81 | 82 | Returns: 83 | dict: a format that builtin models in detectron2 accept 84 | """ 85 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 86 | file_name = dataset_dict['file_name'] 87 | mask_name = dataset_dict['mask_name'] 88 | image = Image.open(file_name).convert('RGB') 89 | 90 | dataset_dict['width'] = image.size[0] 91 | dataset_dict['height'] = image.size[1] 92 | 93 | if self.is_train == False: 94 | image = self.transform(image) 95 | image = torch.from_numpy(np.asarray(image).copy()) 96 | image = image.permute(2,0,1) 97 | 98 | instances_mask = np.max(cv2.imread(mask_name).astype(np.int32), axis=2) 99 | instances_mask[instances_mask > 0] = 1 100 | 101 | instances = Instances(image.shape[-2:]) 102 | _,h,w = image.shape 103 | # sbd dataset only has one gt mask. 104 | masks = [cv2.resize(instances_mask.astype(np.uint8), (w,h), interpolation=cv2.INTER_CUBIC)] 105 | masks = BitMasks( 106 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 107 | ) 108 | instances.gt_masks = masks 109 | instances.gt_boxes = masks.get_bounding_boxes() 110 | spatial_query_utils = self.shape_sampler(instances) 111 | 112 | dataset_dict['spatial_query'] = spatial_query_utils 113 | dataset_dict['instances'] = instances 114 | dataset_dict['image'] = image 115 | dataset_dict['gt_masks_orisize'] = torch.from_numpy(instances_mask).bool()[None,] # (nm,h,w) 116 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/imagenet_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | from PIL import Image 10 | # import logging 11 | 12 | import cv2 13 | import numpy as np 14 | 15 | import torch 16 | from torchvision import transforms 17 | 18 | from modeling.utils import configurable 19 | 20 | __all__ = ["ImageNetDatasetMapper"] 21 | 22 | 23 | # This is specifically designed for the COCO dataset. 24 | class ImageNetDatasetMapper: 25 | """ 26 | A callable which takes a dataset dict in Detectron2 Dataset format, 27 | and map it into a format used by MaskFormer. 28 | 29 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 30 | 31 | The callable currently does the following: 32 | 33 | 1. Read the image from "file_name" 34 | 2. Applies geometric transforms to the image and annotation 35 | 3. Find and applies suitable cropping to the image and annotation 36 | 4. Prepare image and annotation to Tensors 37 | """ 38 | 39 | @configurable 40 | def __init__( 41 | self, 42 | is_train=True, 43 | size_train=None, 44 | size_test=None, 45 | size_crop=None, 46 | ): 47 | """ 48 | NOTE: this interface is experimental. 49 | Args: 50 | is_train: for training or inference 51 | augmentations: a list of augmentations or deterministic transforms to apply 52 | tfm_gens: data augmentation 53 | image_format: an image format supported by :func:`detection_utils.read_image`. 54 | """ 55 | self.is_train = is_train 56 | self.size_train = size_train 57 | self.size_test = size_test 58 | self.size_crop = size_crop 59 | 60 | t = [] 61 | t.append(transforms.Resize(size_crop, interpolation=Image.BICUBIC)) 62 | t.append(transforms.CenterCrop(size_test)) 63 | self.transform = transforms.Compose(t) 64 | 65 | @classmethod 66 | def from_config(cls, cfg, is_train=True): 67 | ret = { 68 | "is_train": is_train, 69 | "size_train": cfg['INPUT']['SIZE_TRAIN'], 70 | "size_test": cfg['INPUT']['SIZE_TEST'], 71 | "size_crop": cfg['INPUT']['SIZE_CROP'] 72 | } 73 | return ret 74 | 75 | def __call__(self, dataset_dict): 76 | """ 77 | Args: 78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 79 | 80 | Returns: 81 | dict: a format that builtin models in detectron2 accept 82 | """ 83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 84 | file_name = dataset_dict['file_name'] 85 | image = Image.open(file_name).convert('RGB') 86 | 87 | if self.is_train == False: 88 | image = self.transform(image) 89 | image = torch.from_numpy(np.asarray(image).copy()) 90 | image = image.permute(2,0,1) 91 | 92 | dataset_dict['image'] = image 93 | dataset_dict['height'] = image.shape[1] 94 | dataset_dict['width'] = image.shape[2] 95 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/sbd_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py 3 | import copy 4 | import os 5 | 6 | import cv2 7 | import scipy.io 8 | import numpy as np 9 | from scipy.io import loadmat 10 | from PIL import Image 11 | 12 | import torch 13 | from torchvision import transforms 14 | from detectron2.structures import BitMasks, Boxes, Instances 15 | 16 | from modeling.utils import configurable 17 | from ..visual_sampler import build_shape_sampler 18 | 19 | __all__ = ["SBDDatasetMapper"] 20 | 21 | 22 | # This is specifically designed for the COCO dataset. 23 | class SBDDatasetMapper: 24 | """ 25 | A callable which takes a dataset dict in Detectron2 Dataset format, 26 | and map it into a format used by MaskFormer. 27 | 28 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 29 | 30 | The callable currently does the following: 31 | 32 | 1. Read the image from "file_name" 33 | 2. Applies geometric transforms to the image and annotation 34 | 3. Find and applies suitable cropping to the image and annotation 35 | 4. Prepare image and annotation to Tensors 36 | """ 37 | 38 | @configurable 39 | def __init__( 40 | self, 41 | is_train=True, 42 | dataset_name='', 43 | min_size_test=None, 44 | max_size_test=None, 45 | shape_sampler=None, 46 | ): 47 | """ 48 | NOTE: this interface is experimental. 49 | Args: 50 | is_train: for training or inference 51 | augmentations: a list of augmentations or deterministic transforms to apply 52 | tfm_gens: data augmentation 53 | image_format: an image format supported by :func:`detection_utils.read_image`. 54 | """ 55 | self.is_train = is_train 56 | self.dataset_name = dataset_name 57 | self.min_size_test = min_size_test 58 | self.max_size_test = max_size_test 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC, max_size=max_size_test)) 62 | self.transform = transforms.Compose(t) 63 | self.shape_sampler = shape_sampler 64 | 65 | @classmethod 66 | def from_config(cls, cfg, is_train=True, dataset_name=''): 67 | shape_sampler = build_shape_sampler(cfg, is_train=is_train, mode=dataset_name.split('_')[-1]) 68 | ret = { 69 | "is_train": is_train, 70 | "dataset_name": dataset_name, 71 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 72 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 73 | "shape_sampler": shape_sampler, 74 | } 75 | return ret 76 | 77 | def __call__(self, dataset_dict): 78 | """ 79 | Args: 80 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 81 | 82 | Returns: 83 | dict: a format that builtin models in detectron2 accept 84 | """ 85 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 86 | file_name = dataset_dict['file_name'] 87 | inst_info_name = dataset_dict['inst_info_name'] 88 | inst_id = dataset_dict['inst_id'] 89 | image = Image.open(file_name).convert('RGB') 90 | 91 | dataset_dict['width'] = image.size[0] 92 | dataset_dict['height'] = image.size[1] 93 | 94 | if self.is_train == False: 95 | image = self.transform(image) 96 | image = torch.from_numpy(np.asarray(image).copy()) 97 | image = image.permute(2,0,1) 98 | 99 | instances_mask = loadmat(str(inst_info_name))['GTinst'][0][0][0].astype(np.int32) 100 | instances_mask[instances_mask != inst_id] = 0 101 | instances_mask[instances_mask > 0] = 1 102 | 103 | instances = Instances(image.shape[-2:]) 104 | _,h,w = image.shape 105 | # sbd dataset only has one gt mask. 106 | masks = [cv2.resize(instances_mask.astype(np.uint8), (w,h), interpolation=cv2.INTER_CUBIC)] 107 | masks = BitMasks( 108 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 109 | ) 110 | instances.gt_masks = masks 111 | instances.gt_boxes = masks.get_bounding_boxes() 112 | spatial_query_utils = self.shape_sampler(instances) 113 | 114 | dataset_dict['spatial_query'] = spatial_query_utils 115 | dataset_dict['instances'] = instances 116 | dataset_dict['image'] = image 117 | dataset_dict['gt_masks_orisize'] = torch.from_numpy(instances_mask).bool()[None,] # (nm,h,w) 118 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/scannet_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["ScanNetSegDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class ScanNetSegDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def read_semseg(self, file_name): 76 | if '.png' in file_name: 77 | semseg = np.asarray(Image.open(file_name)) 78 | elif '.mat' in file_name: 79 | semseg = scipy.io.loadmat(file_name)['LabelMap'] 80 | return semseg 81 | 82 | def __call__(self, dataset_dict): 83 | """ 84 | Args: 85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 86 | 87 | Returns: 88 | dict: a format that builtin models in detectron2 accept 89 | """ 90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 91 | file_name = dataset_dict['file_name'] 92 | semseg_name = dataset_dict['sem_seg_file_name'] 93 | image = Image.open(file_name).convert('RGB') 94 | 95 | dataset_dict['width'] = image.size[0] 96 | dataset_dict['height'] = image.size[1] 97 | 98 | if self.is_train == False: 99 | image = self.transform(image) 100 | image = torch.from_numpy(np.asarray(image).copy()) 101 | image = image.permute(2,0,1) 102 | 103 | semseg = self.read_semseg(semseg_name) 104 | semseg = torch.from_numpy(semseg.astype(np.int32)) 105 | dataset_dict['image'] = image 106 | dataset_dict['semseg'] = semseg 107 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/scannet_pano_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["ScanNetPanoDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class ScanNetPanoDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def __call__(self, dataset_dict): 76 | """ 77 | Args: 78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 79 | 80 | Returns: 81 | dict: a format that builtin models in detectron2 accept 82 | """ 83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 84 | file_name = dataset_dict['file_name'] 85 | image = Image.open(file_name).convert('RGB') 86 | 87 | dataset_dict['file_name'] = '_'.join(file_name.split('/')[-3:]) # HACK for /tmp file storage on predictions. 88 | dataset_dict['width'] = image.size[0] 89 | dataset_dict['height'] = image.size[1] 90 | 91 | image = self.transform(image) 92 | image = torch.from_numpy(np.asarray(image).copy()) 93 | image = image.permute(2,0,1) 94 | dataset_dict['image'] = image 95 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/sunrgbd_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["SunRGBDSegDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class SunRGBDSegDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def read_semseg(self, file_name): 76 | if '.png' in file_name: 77 | semseg = np.asarray(Image.open(file_name)) 78 | elif '.mat' in file_name: 79 | semseg = scipy.io.loadmat(file_name)['LabelMap'] 80 | return semseg 81 | 82 | def __call__(self, dataset_dict): 83 | """ 84 | Args: 85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 86 | 87 | Returns: 88 | dict: a format that builtin models in detectron2 accept 89 | """ 90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 91 | file_name = dataset_dict['file_name'] 92 | semseg_name = dataset_dict['sem_seg_file_name'] 93 | image = Image.open(file_name).convert('RGB') 94 | 95 | dataset_dict['width'] = image.size[0] 96 | dataset_dict['height'] = image.size[1] 97 | 98 | if self.is_train == False: 99 | image = self.transform(image) 100 | image = torch.from_numpy(np.asarray(image).copy()) 101 | image = image.permute(2,0,1) 102 | 103 | semseg = self.read_semseg(semseg_name) 104 | semseg = torch.from_numpy(semseg.astype(np.int32)) 105 | dataset_dict['image'] = image 106 | dataset_dict['semseg'] = semseg 107 | return dataset_dict -------------------------------------------------------------------------------- /datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .instance_evaluation import * 2 | from .classification_evaluation import * 3 | from .segmentation_evaluation import * 4 | from .retrieval_evaluation import * 5 | from .captioning_evaluation import * 6 | from .panoptic_evaluation import * 7 | from .grounding_evaluation import * 8 | from .interactive_evaluation import * 9 | from .interleave_evaluation import * -------------------------------------------------------------------------------- /datasets/evaluation/classification_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import logging 11 | 12 | from detectron2.evaluation.evaluator import DatasetEvaluator 13 | 14 | from utils.misc import AverageMeter 15 | from utils.distributed import get_world_size 16 | 17 | 18 | @torch.no_grad() 19 | def accuracy(output, target, topk=(1,)): 20 | """Computes the precision@k for the specified values of k""" 21 | if isinstance(output, list): 22 | output = output[-1] 23 | 24 | n_classes = output.size()[1] 25 | maxk = min(max(topk), n_classes) 26 | batch_size = target.size(0) 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 34 | res.append(correct_k.mul_(100.0 / batch_size).item()) 35 | return res 36 | 37 | class ClassificationEvaluator(DatasetEvaluator): 38 | def __init__(self, *args): 39 | self.top1 = AverageMeter() 40 | self.top5 = AverageMeter() 41 | self._logger = logging.getLogger(__name__) 42 | 43 | def reset(self): 44 | self.top1.reset() 45 | self.top5.reset() 46 | 47 | def process(self, inputs, outputs): 48 | logits = torch.stack([o['pred_class'] for o in outputs]) 49 | y = torch.tensor([t['class_id'] for t in inputs], device=logits.device) 50 | prec1, prec5 = accuracy(logits, y, (1, 5)) 51 | self.top1.update(prec1, y.size(0)) 52 | self.top5.update(prec5, y.size(0)) 53 | 54 | def evaluate(self): 55 | if get_world_size() > 1: 56 | tmp_tensor = torch.tensor( 57 | [self.top1.sum, self.top5.sum, self.top1.count], 58 | device=torch.cuda.current_device() 59 | ) 60 | torch.distributed.all_reduce( 61 | tmp_tensor, torch.distributed.ReduceOp.SUM 62 | ) 63 | top1_sum, top5_sum, count = tmp_tensor.tolist() 64 | else: 65 | top1_sum = self.top1.sum 66 | top5_sum = self.top5.sum 67 | count = self.top1.count 68 | 69 | results = {} 70 | scores = { 71 | 'top1': top1_sum / count, 72 | "top5": top5_sum / count 73 | } 74 | results['class'] = scores 75 | self._logger.info(results) 76 | return results 77 | -------------------------------------------------------------------------------- /datasets/evaluation/grounding_evaluation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import logging 8 | import torch 9 | from torchvision.ops import box_iou 10 | 11 | from detectron2.structures import BoxMode 12 | from detectron2.data import MetadataCatalog 13 | from detectron2.utils.comm import all_gather, is_main_process, synchronize 14 | from detectron2.evaluation.evaluator import DatasetEvaluator 15 | 16 | 17 | class GroundingEvaluator(DatasetEvaluator): 18 | """ 19 | Evaluate grounding segmentation metrics. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | dataset_name, 25 | compute_box=False, 26 | distributed=True, 27 | ): 28 | self._logger = logging.getLogger(__name__) 29 | self._dataset_name = dataset_name 30 | self._distributed = distributed 31 | self._cpu_device = torch.device("cpu") 32 | self._compute_box = compute_box 33 | meta = MetadataCatalog.get(dataset_name) 34 | 35 | def reset(self): 36 | self.cum_I = 0 37 | self.cum_U = 0 38 | self.mIoU = 0 39 | self.eval_seg_iou_list = [.5, .6, .7, .8, .9] 40 | self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device) 41 | self.seg_total = 0 42 | if self._compute_box: 43 | self.mIoU_box = 0 44 | self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device) 45 | 46 | @staticmethod 47 | def computeIoU(pred_seg, gd_seg): 48 | I = (pred_seg & gd_seg) 49 | U = (pred_seg | gd_seg) 50 | return I, U 51 | 52 | def process(self, inputs, outputs): 53 | for input, output in zip(inputs, outputs): 54 | pred = output['grounding_mask'] > 0.0 55 | gt = input['groundings']['masks'].bool() 56 | bsi = len(pred) 57 | I, U = self.computeIoU(pred, gt) 58 | self.cum_I += I.sum().cpu() 59 | self.cum_U += U.sum().cpu() 60 | IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6) 61 | self.mIoU += IoU.sum().cpu() 62 | 63 | if self._compute_box: 64 | pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) 65 | gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu() 66 | IoU_box = box_iou(pred_box, gt_box).diagonal() 67 | self.mIoU_box += IoU_box.sum() 68 | 69 | for idx in range(len(self.eval_seg_iou_list)): 70 | eval_seg_iou = self.eval_seg_iou_list[idx] 71 | self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu() 72 | if self._compute_box: 73 | self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu() 74 | self.seg_total += bsi 75 | 76 | def evaluate(self): 77 | if self._distributed: 78 | synchronize() 79 | self.cum_I = torch.stack(all_gather(self.cum_I)).sum() 80 | self.cum_U = torch.stack(all_gather(self.cum_U)).sum() 81 | self.mIoU = torch.stack(all_gather(self.mIoU)).sum() 82 | self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0) 83 | self.seg_total = sum(all_gather(self.seg_total)) 84 | 85 | if self._compute_box: 86 | self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum() 87 | self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0) 88 | if not is_main_process(): 89 | return 90 | 91 | results = {} 92 | for idx in range(len(self.eval_seg_iou_list)): 93 | result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx]) 94 | results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item() 95 | results['cIoU'] = (self.cum_I*100./self.cum_U).item() 96 | results['mIoU'] = (self.mIoU*100./self.seg_total).item() 97 | 98 | if self._compute_box: 99 | for idx in range(len(self.eval_seg_iou_list)): 100 | result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx]) 101 | results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item() 102 | results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item() 103 | 104 | self._logger.info(results) 105 | return {'grounding': results} -------------------------------------------------------------------------------- /datasets/evaluation/instance_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import contextlib 3 | import copy 4 | import io 5 | import itertools 6 | import json 7 | import logging 8 | import numpy as np 9 | import os 10 | import pickle 11 | from collections import OrderedDict 12 | import pycocotools.mask as mask_util 13 | import torch 14 | from pycocotools.coco import COCO 15 | from pycocotools.cocoeval import COCOeval 16 | from tabulate import tabulate 17 | 18 | import detectron2.utils.comm as comm 19 | from detectron2.config import CfgNode 20 | from detectron2.data import MetadataCatalog 21 | from detectron2.data.datasets.coco import convert_to_coco_json 22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco 23 | from detectron2.evaluation.fast_eval_api import COCOeval_opt 24 | from detectron2.structures import Boxes, BoxMode, pairwise_iou 25 | from detectron2.utils.file_io import PathManager 26 | from detectron2.utils.logger import create_small_table 27 | 28 | 29 | # modified from COCOEvaluator for instance segmetnat 30 | class InstanceSegEvaluator(COCOEvaluator): 31 | """ 32 | Evaluate AR for object proposals, AP for instance detection/segmentation, AP 33 | for keypoint detection outputs using COCO's metrics. 34 | See http://cocodataset.org/#detection-eval and 35 | http://cocodataset.org/#keypoints-eval to understand its metrics. 36 | The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means 37 | the metric cannot be computed (e.g. due to no predictions made). 38 | 39 | In addition to COCO, this evaluator is able to support any bounding box detection, 40 | instance segmentation, or keypoint detection dataset. 41 | """ 42 | 43 | def _eval_predictions(self, predictions, img_ids=None): 44 | """ 45 | Evaluate predictions. Fill self._results with the metrics of the tasks. 46 | """ 47 | self._logger.info("Preparing results for COCO format ...") 48 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) 49 | tasks = self._tasks or self._tasks_from_predictions(coco_results) 50 | 51 | # unmap the category ids for COCO 52 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): 53 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id 54 | # all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) 55 | # num_classes = len(all_contiguous_ids) 56 | # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 57 | 58 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} 59 | for result in coco_results: 60 | category_id = result["category_id"] 61 | # assert category_id < num_classes, ( 62 | # f"A prediction has class={category_id}, " 63 | # f"but the dataset only has {num_classes} classes and " 64 | # f"predicted class id should be in [0, {num_classes - 1}]." 65 | # ) 66 | assert category_id in reverse_id_mapping, ( 67 | f"A prediction has class={category_id}, " 68 | f"but the dataset only has class ids in {dataset_id_to_contiguous_id}." 69 | ) 70 | result["category_id"] = reverse_id_mapping[category_id] 71 | 72 | if self._output_dir: 73 | file_path = os.path.join(self._output_dir, "coco_instances_results.json") 74 | self._logger.info("Saving results to {}".format(file_path)) 75 | with PathManager.open(file_path, "w") as f: 76 | f.write(json.dumps(coco_results)) 77 | f.flush() 78 | 79 | if not self._do_evaluation: 80 | self._logger.info("Annotations are not available for evaluation.") 81 | return 82 | 83 | self._logger.info( 84 | "Evaluating predictions with {} COCO API...".format( 85 | "unofficial" if self._use_fast_impl else "official" 86 | ) 87 | ) 88 | for task in sorted(tasks): 89 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" 90 | coco_eval = ( 91 | _evaluate_predictions_on_coco( 92 | self._coco_api, 93 | coco_results, 94 | task, 95 | kpt_oks_sigmas=self._kpt_oks_sigmas, 96 | use_fast_impl=self._use_fast_impl, 97 | img_ids=img_ids, 98 | max_dets_per_image=self._max_dets_per_image, 99 | ) 100 | if len(coco_results) > 0 101 | else None # cocoapi does not handle empty results very well 102 | ) 103 | 104 | res = self._derive_coco_results( 105 | coco_eval, task, class_names=self._metadata.get("thing_classes") 106 | ) 107 | self._results[task] = res 108 | -------------------------------------------------------------------------------- /datasets/evaluation/interactive_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision.ops import box_iou 8 | 9 | from detectron2.structures import BoxMode 10 | from detectron2.data import MetadataCatalog 11 | from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize 12 | from detectron2.evaluation.evaluator import DatasetEvaluator 13 | 14 | 15 | class InteractiveEvaluator(DatasetEvaluator): 16 | """ 17 | Evaluate point interactive IoU metrics. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dataset_name, 23 | output_dir, 24 | max_clicks=20, 25 | iou_iter=1, 26 | compute_box=False, 27 | distributed=True, 28 | ): 29 | self._logger = logging.getLogger(__name__) 30 | self._dataset_name = dataset_name 31 | self._distributed = distributed 32 | self._cpu_device = torch.device("cpu") 33 | self._output_dir = output_dir 34 | 35 | self.max_clicks = max_clicks 36 | self.iou_iter = iou_iter 37 | meta = MetadataCatalog.get(dataset_name) 38 | 39 | def reset(self): 40 | self.iou_list = [] 41 | self.num_samples = 0 42 | self.all_ious = [0.5, 0.8, 0.85, 0.9] 43 | 44 | def process(self, inputs, outputs): 45 | self.iou_list += [o['mask_iou'] for o in outputs] 46 | self.num_samples += len(outputs) 47 | 48 | def compute_noc(self): 49 | def _get_noc(iou_arr, iou_thr): 50 | vals = iou_arr >= iou_thr 51 | return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks 52 | 53 | noc_list = {} 54 | for iou_thr in self.all_ious: 55 | scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list] 56 | noc_list[str(iou_thr)] = scores_arr 57 | 58 | iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1] 59 | noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()} 60 | 61 | if self._distributed: 62 | num_samples = sum(all_gather(self.num_samples)) 63 | noc_list_sum_gather = all_gather(noc_list_sum) 64 | iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu()) 65 | 66 | noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]} 67 | for nlg in noc_list_sum_gather: 68 | for key, value in nlg.items(): 69 | noc_list_sum[key] += value 70 | 71 | pred_noc = {} 72 | if self._distributed and (not is_main_process()): 73 | return pred_noc 74 | 75 | for key, value in noc_list_sum.items(): 76 | pred_noc[key] = value / num_samples 77 | 78 | pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples 79 | return pred_noc 80 | 81 | def evaluate(self): 82 | pred_noc = self.compute_noc() 83 | 84 | if self._distributed and (not is_main_process()): 85 | return 86 | 87 | def draw_iou_curve(iou_list, save_dir): 88 | iou_list = torch.stack(iou_list, dim=0) 89 | iou_list = iou_list.mean(dim=0).cpu().numpy() 90 | 91 | if len(iou_list) > 1: 92 | # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib 93 | import matplotlib.pyplot as plt 94 | plt.figure() 95 | plt.plot(range(1, self.max_clicks+1), iou_list) 96 | plt.xlabel('Number of clicks') 97 | plt.ylabel('IoU') 98 | 99 | 100 | # create directory if not exist 101 | import os 102 | output_dir = os.path.join(save_dir, 'iou_by_clicks') 103 | if not os.path.exists(output_dir): 104 | os.makedirs(output_dir) 105 | 106 | # get current time and format in 10 digits 107 | import time 108 | current_time = time.time() 109 | current_time = int(current_time) 110 | current_time = str(current_time) 111 | 112 | # save iou curve 113 | plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time))) 114 | 115 | draw_iou_curve(self.iou_list, self._output_dir) 116 | results = {} 117 | for idx in range(len(self.all_ious)): 118 | result_str = 'noc@{}'.format(self.all_ious[idx]) 119 | results[result_str] = pred_noc[str(self.all_ious[idx])] 120 | 121 | results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter'] 122 | 123 | self._logger.info(results) 124 | return {'interactive': results} -------------------------------------------------------------------------------- /datasets/registration/__init__.py: -------------------------------------------------------------------------------- 1 | register_functions = [ 2 | "register_refcoco_dataset", 3 | "register_ade20k_full", 4 | "register_ade20k_panoptic", 5 | "register_coco_stuff_10k", 6 | "register_coco_panoptic_annos_semseg", 7 | "register_coco_panoptic_annos_caption", 8 | "register_coco_panoptic_annos_caption_grounding", 9 | "register_coco_lvis_panoptic_annos_caption_grounding", 10 | "register_coco_lvis_panoptic_annos_caption_grounding_entity", 11 | "register_ade20k_instance", 12 | "register_vlp_datasets", 13 | "register_sunrgbd_semseg", 14 | "register_scannet_semseg", 15 | "register_bdd100k_semseg", 16 | "register_scannet_panoptic", 17 | "register_bdd100k_panoseg", 18 | "register_pascalvoc_eval", 19 | "register_grounding_coco_entity", 20 | "register_vlp_coco_entity", 21 | "register_vlp_coco_interleave", 22 | "register_davis_dataset", 23 | "register_ytvos_dataset", 24 | "register_davis_ixeval", 25 | "register_sbd_eval", 26 | ] 27 | 28 | for func_name in register_functions: 29 | try: 30 | exec(f"from . import {func_name}") 31 | except Exception as e: 32 | print(f"Error with {func_name}: {e}") -------------------------------------------------------------------------------- /datasets/registration/register_ade20k_instance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import json 3 | import logging 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | 8 | from detectron2.data import DatasetCatalog, MetadataCatalog 9 | from detectron2.data.datasets.coco import load_coco_json, register_coco_instances 10 | from detectron2.utils.file_io import PathManager 11 | 12 | ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}] 13 | 14 | 15 | _PREDEFINED_SPLITS = { 16 | # point annotations without masks 17 | "ade20k_instance_train": ( 18 | "ADEChallengeData2016/images/training", 19 | "ADEChallengeData2016/ade20k_instance_train.json", 20 | ), 21 | "ade20k_instance_val": ( 22 | "ADEChallengeData2016/images/validation", 23 | "ADEChallengeData2016/ade20k_instance_val.json", 24 | ), 25 | } 26 | 27 | 28 | def _get_ade_instances_meta(): 29 | thing_ids = [k["id"] for k in ADE_CATEGORIES] 30 | assert len(thing_ids) == 100, len(thing_ids) 31 | # Mapping from the incontiguous ADE category id to an id in [0, 99] 32 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 33 | thing_classes = [k["name"] for k in ADE_CATEGORIES] 34 | ret = { 35 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 36 | "thing_classes": thing_classes, 37 | } 38 | return ret 39 | 40 | 41 | def register_all_ade20k_instance(root): 42 | for key, (image_root, json_file) in _PREDEFINED_SPLITS.items(): 43 | # Assume pre-defined datasets live in `./datasets`. 44 | register_coco_instances( 45 | key, 46 | _get_ade_instances_meta(), 47 | os.path.join(root, json_file) if "://" not in json_file else json_file, 48 | os.path.join(root, image_root), 49 | ) 50 | 51 | 52 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 53 | register_all_ade20k_instance(_root) 54 | -------------------------------------------------------------------------------- /datasets/registration/register_bdd100k_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import BDD_SEM 17 | 18 | __all__ = ["load_scannet_instances", "register_scannet_context"] 19 | 20 | 21 | def load_bdd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load BDD annotations to Detectron2 format. 24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | img_folder = os.path.join(dirname, 'images', '10k', split) 31 | img_pths = sorted(glob.glob(os.path.join(img_folder, '*.jpg'))) 32 | 33 | sem_folder = os.path.join(dirname, 'labels', 'sem_seg', 'masks', split) 34 | sem_pths = sorted(glob.glob(os.path.join(sem_folder, '*.png'))) 35 | 36 | assert len(img_pths) == len(sem_pths) 37 | 38 | dicts = [] 39 | for img_pth, sem_pth in zip(img_pths, sem_pths): 40 | r = { 41 | "file_name": img_pth, 42 | "sem_seg_file_name": sem_pth, 43 | "image_id": img_pth.split('/')[-1].split('.')[0], 44 | } 45 | dicts.append(r) 46 | return dicts 47 | 48 | 49 | def register_bdd_context(name, dirname, split, class_names=BDD_SEM): 50 | DatasetCatalog.register(name, lambda: load_bdd_instances(name, dirname, split, class_names)) 51 | MetadataCatalog.get(name).set( 52 | stuff_classes=class_names, 53 | dirname=dirname, 54 | split=split, 55 | ignore_label=[255], 56 | thing_dataset_id_to_contiguous_id={}, 57 | class_offset=0, 58 | keep_sem_bgd=False 59 | ) 60 | 61 | 62 | def register_all_sunrgbd_seg(root): 63 | SPLITS = [ 64 | ("bdd10k_val_sem_seg", "bdd100k", "val"), 65 | ] 66 | 67 | for name, dirname, split in SPLITS: 68 | register_bdd_context(name, os.path.join(root, dirname), split) 69 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 70 | 71 | 72 | _root = os.getenv("DATASET", "datasets") 73 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_davis_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import os 4 | import glob 5 | import json 6 | from typing import List, Tuple, Union 7 | 8 | import cv2 9 | import numpy as np 10 | from scipy.io import loadmat 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | 17 | __all__ = ["load_davis_instances", "register_davis_context"] 18 | 19 | def load_davis_instances(name: str, dirname: str, split: str, year: str): 20 | """ 21 | Load Pascal VOC detection annotations to Detectron2 format. 22 | 23 | Args: 24 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 25 | split (str): one of "train", "test", "val", "trainval" 26 | class_names: list or tuple of class names 27 | """ 28 | meta_txt = os.path.join(dirname, 'ImageSets', year, "{}.txt".format(split)) 29 | meta_json = os.path.join(dirname, 'video_objects_info.json') 30 | meta_json = json.load(open(meta_json))['videos'] 31 | video_names = [line.strip() for line in open(meta_txt).readlines()] 32 | 33 | video_dir = os.path.join(dirname, 'JPEGImages', '480p') 34 | mask_dir = os.path.join(dirname, 'Annotations', '480p') 35 | scibble_dir = os.path.join(dirname, 'Scribbles', '480p') 36 | semantic_dir = os.path.join(dirname, 'Annotations_semantics', '480p') 37 | 38 | dicts = [] 39 | for vid_name in video_names: 40 | objects = meta_json[vid_name]['objects'] 41 | r = { 42 | "file_name": os.path.join(video_dir, vid_name), 43 | "mask_name": os.path.join(mask_dir, vid_name), 44 | "scibble_name": os.path.join(scibble_dir, vid_name), 45 | "semantic_name": os.path.join(semantic_dir, vid_name), 46 | "objects": objects, 47 | } 48 | dicts.append(r) 49 | return dicts 50 | 51 | def register_davis_context(name, dirname, split, year): 52 | load_davis_instances(name, dirname, split, year) 53 | DatasetCatalog.register("{}".format(name), lambda: load_davis_instances(name, dirname, split, year)) 54 | MetadataCatalog.get("{}".format(name)).set( 55 | dirname=dirname, 56 | thing_dataset_id_to_contiguous_id={}, 57 | ) 58 | 59 | def register_all_davis(root): 60 | SPLITS = [ 61 | ("davis17_val", "DAVIS17", "val", "2017"), 62 | ("davis16_val", "DAVIS17", "val", "2016"), 63 | ] 64 | 65 | for name, dirname, split, year in SPLITS: 66 | register_davis_context(name, os.path.join(root, dirname), split, year) 67 | MetadataCatalog.get("{}".format(name)).evaluator_type = None 68 | 69 | _root = os.getenv("DATASET", "datasets") 70 | register_all_davis(_root) -------------------------------------------------------------------------------- /datasets/registration/register_davis_ixeval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import os 4 | import glob 5 | from typing import List, Tuple, Union 6 | 7 | import cv2 8 | import numpy as np 9 | from scipy.io import loadmat 10 | 11 | from detectron2.data import DatasetCatalog, MetadataCatalog 12 | from detectron2.structures import BoxMode 13 | from detectron2.utils.file_io import PathManager 14 | 15 | 16 | __all__ = ["load_davis_instances", "register_davis_context"] 17 | 18 | def load_davis_instances(name: str, dirname: str, mode: str, split: str): 19 | """ 20 | Load Pascal VOC detection annotations to Detectron2 format. 21 | 22 | Args: 23 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 24 | split (str): one of "train", "test", "val", "trainval" 25 | class_names: list or tuple of class names 26 | """ 27 | image_pths = sorted(glob.glob(os.path.join(dirname, "img", "*.jpg"))) 28 | mask_pths = sorted(glob.glob(os.path.join(dirname, "gt", "*.png"))) 29 | assert len(image_pths) == len(mask_pths) 30 | 31 | dicts = [] 32 | for image_pth, mask_pth in zip(image_pths, mask_pths): 33 | r = { 34 | "file_name": image_pth, 35 | "mask_name": mask_pth, 36 | } 37 | dicts.append(r) 38 | return dicts 39 | 40 | def register_davis_context(name, dirname, mode, split): 41 | DatasetCatalog.register("{}_{}".format(name, mode), lambda: load_davis_instances(name, dirname, mode, split)) 42 | MetadataCatalog.get("{}_{}".format(name, mode)).set( 43 | dirname=dirname, 44 | thing_dataset_id_to_contiguous_id={}, 45 | ) 46 | 47 | def register_all_davis(root): 48 | SPLITS = [ 49 | ("openimage600_val", "open-image600", "Point", "val"), 50 | ("openimage600_val", "open-image600", "Scribble", "val"), 51 | ("openimage600_val", "open-image600", "Polygon", "val"), 52 | ("openimage600_val", "open-image600", "Circle", "val"), 53 | ("openimage600_val", "open-image600", "Box", "val"), 54 | ("ade600_val", "ADE600", "Point", "val"), 55 | ("ade600_val", "ADE600", "Scribble", "val"), 56 | ("ade600_val", "ADE600", "Polygon", "val"), 57 | ("ade600_val", "ADE600", "Circle", "val"), 58 | ("ade600_val", "ADE600", "Box", "val"), 59 | ("davis_val", "DAVIS345", "Point", "val"), 60 | ("davis_val", "DAVIS345", "Scribble", "val"), 61 | ("davis_val", "DAVIS345", "Polygon", "val"), 62 | ("davis_val", "DAVIS345", "Circle", "val"), 63 | ("davis_val", "DAVIS345", "Box", "val"), 64 | ("cocomini_val", "COCO_MVal", "Point", "val"), 65 | ("cocomini_val", "COCO_MVal", "Scribble", "val"), 66 | ("cocomini_val", "COCO_MVal", "Polygon", "val"), 67 | ("cocomini_val", "COCO_MVal", "Circle", "val"), 68 | ("cocomini_val", "COCO_MVal", "Box", "val"), 69 | ] 70 | 71 | for name, dirname, mode, split in SPLITS: 72 | register_davis_context(name, os.path.join(root, dirname), mode, split) 73 | MetadataCatalog.get("{}_{}".format(name, mode)).evaluator_type = "interactive" 74 | 75 | _root = os.getenv("DATASET", "datasets") 76 | register_all_davis(_root) -------------------------------------------------------------------------------- /datasets/registration/register_imagenet_cls.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import glob 10 | from typing import List, Tuple, Union 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import IMAGENET_CLASSES, IMAGENET_FOLDER_NAMES 17 | 18 | __all__ = ["load_imagenet_images", "register_imagenet"] 19 | 20 | 21 | def load_imagenet_images(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load ImageNet annotations to Detectron2 format. 24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | image_folders = sorted(glob.glob(os.path.join(dirname, split, 'n*'))) 31 | 32 | dicts = [] 33 | for image_folder in image_folders: 34 | folder_name = image_folder.split('/')[-1] 35 | image_pths = sorted(glob.glob(os.path.join(image_folder, "*.JPEG"))) 36 | for img_pth in image_pths: 37 | r = { 38 | "file_name": img_pth, 39 | "class_name": IMAGENET_CLASSES[IMAGENET_FOLDER_NAMES.index(folder_name)], 40 | "class_id": IMAGENET_FOLDER_NAMES.index(folder_name), 41 | } 42 | dicts.append(r) 43 | return dicts 44 | 45 | 46 | def register_imagenet(name, dirname, split, year, class_names=IMAGENET_CLASSES): 47 | DatasetCatalog.register(name, lambda: load_imagenet_images(dirname, split, class_names)) 48 | MetadataCatalog.get(name).set( 49 | thing_classes=list(class_names), dirname=dirname, year=year, split=split 50 | ) 51 | 52 | 53 | def register_all_imagenet(root): 54 | SPLITS = [ 55 | ("imagenet_val", "imagenet", "val", "2012"), 56 | ] 57 | for name, dirname, split, year in SPLITS: 58 | register_imagenet(name, os.path.join(root, dirname), split, year) 59 | MetadataCatalog.get(name).evaluator_type = "classification" 60 | 61 | 62 | _root = os.getenv("DATASET", "datasets") 63 | register_all_imagenet(_root) -------------------------------------------------------------------------------- /datasets/registration/register_pascalvoc_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import os 4 | import glob 5 | from typing import List, Tuple, Union 6 | import xml.etree.ElementTree as ET 7 | 8 | import cv2 9 | import numpy as np 10 | from scipy.io import loadmat 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | 17 | __all__ = ["load_pascalvoc_instances", "register_pascalvoc_context"] 18 | 19 | def get_labels_with_sizes(x): 20 | obj_sizes = np.bincount(x.flatten()) 21 | labels = np.nonzero(obj_sizes)[0].tolist() 22 | labels = [x for x in labels if x != 0] 23 | return labels, obj_sizes[labels].tolist() 24 | 25 | def load_pascalvoc_instances(name: str, dirname: str, mode: str, split: str): 26 | """ 27 | Load Pascal VOC detection annotations to Detectron2 format. 28 | 29 | Args: 30 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 31 | split (str): one of "train", "test", "val", "trainval" 32 | class_names: list or tuple of class names 33 | """ 34 | with PathManager.open(os.path.join(dirname, 'ImageSets', 'Segmentation', split + ".txt")) as f: 35 | fileids = np.loadtxt(f, dtype=np.str) 36 | 37 | dicts = [] 38 | for field in fileids: 39 | anno_path = os.path.join(dirname, "Annotations", "{}.xml".format(field)) 40 | image_path = os.path.join(dirname, "JPEGImages", "{}.jpg".format(field)) 41 | inst_path = os.path.join(dirname, "SegmentationObject", "{}.png".format(field)) 42 | semseg_path = os.path.join(dirname, "SegmentationClass", "{}.png".format(field)) 43 | 44 | instances_mask = cv2.imread(inst_path) 45 | instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) 46 | 47 | objects_ids = np.unique(instances_mask) 48 | objects_ids = [x for x in objects_ids if x != 0 and x != 220] 49 | 50 | slice_size = 5 51 | for i in range(0, len(objects_ids), slice_size): 52 | r = { 53 | "file_name": image_path, 54 | "inst_name": inst_path, 55 | "semseg_name": semseg_path, 56 | "objects_ids": objects_ids[i:i+slice_size], 57 | } 58 | dicts.append(r) 59 | return dicts 60 | 61 | def register_pascalvoc_context(name, dirname, mode, split): 62 | DatasetCatalog.register("{}_{}".format(name, mode), lambda: load_pascalvoc_instances(name, dirname, mode, split)) 63 | MetadataCatalog.get("{}_{}".format(name, mode)).set( 64 | dirname=dirname, 65 | thing_dataset_id_to_contiguous_id={}, 66 | ) 67 | 68 | def register_all_sbd(root): 69 | SPLITS = [ 70 | ("pascalvoc_val", "PascalVOC", "Point", "val"), 71 | ("pascalvoc_val", "PascalVOC", "Scribble", "val"), 72 | ("pascalvoc_val", "PascalVOC", "Polygon", "val"), 73 | ("pascalvoc_val", "PascalVOC", "Circle", "val"), 74 | ("pascalvoc_val", "PascalVOC", "Box", "val"), 75 | ] 76 | 77 | for name, dirname, mode, split in SPLITS: 78 | register_pascalvoc_context(name, os.path.join(root, dirname), mode, split) 79 | MetadataCatalog.get("{}_{}".format(name, mode)).evaluator_type = "interactive" 80 | 81 | _root = os.getenv("DATASET", "datasets") 82 | register_all_sbd(_root) -------------------------------------------------------------------------------- /datasets/registration/register_refcoco_dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import json 8 | import os 9 | import collections 10 | 11 | from detectron2.data import DatasetCatalog, MetadataCatalog 12 | from detectron2.data.datasets import load_sem_seg 13 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 14 | from detectron2.utils.file_io import PathManager 15 | 16 | 17 | _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = { 18 | # "refcocog_train_umd": ( 19 | # "coco/train2017", # image_root 20 | # "coco/annotations/refcocog_umd_train.json", # annot_root 21 | # ), 22 | # "refcocog_val_google": ( 23 | # "coco/train2017", # image_root 24 | # "coco/annotations/refcocog_google.json", # annot_root 25 | # ), 26 | # "refcocop_val_unc": ( 27 | # "coco/train2017", # image_root 28 | # "coco/annotations/refcocop_unc.json", # annot_root 29 | # ), 30 | # "refcoco_val_unc": ( 31 | # "coco/train2017", # image_root 32 | # "coco/annotations/refcoco_unc.json", # annot_root 33 | # ), 34 | "refcocog_val_umd": ( 35 | "coco/train2017", # image_root 36 | "coco/annotations/refcocog_umd_val.json", # annot_root 37 | ), 38 | } 39 | 40 | 41 | def get_metadata(): 42 | meta = {} 43 | return meta 44 | 45 | 46 | def load_refcoco_json(image_root, annot_json, metadata): 47 | """ 48 | Args: 49 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 50 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". 51 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". 52 | Returns: 53 | list[dict]: a list of dicts in Detectron2 standard format. (See 54 | `Using Custom Datasets `_ ) 55 | """ 56 | 57 | with PathManager.open(annot_json) as f: 58 | json_info = json.load(f) 59 | 60 | # build dictionary for grounding 61 | grd_dict = collections.defaultdict(list) 62 | for grd_ann in json_info['annotations']: 63 | image_id = int(grd_ann["image_id"]) 64 | grd_dict[image_id].append(grd_ann) 65 | 66 | ret = [] 67 | for image in json_info["images"]: 68 | image_id = int(image["id"]) 69 | image_file = os.path.join(image_root, image['file_name']) 70 | grounding_anno = grd_dict[image_id] 71 | ret.append( 72 | { 73 | "file_name": image_file, 74 | "image_id": image_id, 75 | "grounding_info": grounding_anno, 76 | } 77 | ) 78 | assert len(ret), f"No images found in {image_root}!" 79 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] 80 | return ret 81 | 82 | 83 | def register_refcoco( 84 | name, metadata, image_root, annot_json): 85 | DatasetCatalog.register( 86 | name, 87 | lambda: load_refcoco_json(image_root, annot_json, metadata), 88 | ) 89 | MetadataCatalog.get(name).set( 90 | image_root=image_root, 91 | json_file=annot_json, 92 | evaluator_type="grounding_refcoco", 93 | ignore_label=255, 94 | label_divisor=1000, 95 | **metadata, 96 | ) 97 | 98 | 99 | def register_all_refcoco(root): 100 | for ( 101 | prefix, 102 | (image_root, annot_root), 103 | ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items(): 104 | register_refcoco( 105 | prefix, 106 | get_metadata(), 107 | os.path.join(root, image_root), 108 | os.path.join(root, annot_root), 109 | ) 110 | 111 | 112 | _root = os.getenv("DATASET", "datasets") 113 | register_all_refcoco(_root) 114 | -------------------------------------------------------------------------------- /datasets/registration/register_sbd_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import os 4 | import glob 5 | from typing import List, Tuple, Union 6 | 7 | import numpy as np 8 | from scipy.io import loadmat 9 | 10 | from detectron2.data import DatasetCatalog, MetadataCatalog 11 | from detectron2.structures import BoxMode 12 | from detectron2.utils.file_io import PathManager 13 | 14 | 15 | __all__ = ["load_sbd_instances", "register_sbd_context"] 16 | 17 | def get_labels_with_sizes(x): 18 | obj_sizes = np.bincount(x.flatten()) 19 | labels = np.nonzero(obj_sizes)[0].tolist() 20 | labels = [x for x in labels if x != 0] 21 | return labels, obj_sizes[labels].tolist() 22 | 23 | def load_sbd_instances(name: str, dirname: str, mode: str, split: str): 24 | """ 25 | Load Pascal VOC detection annotations to Detectron2 format. 26 | 27 | Args: 28 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 29 | split (str): one of "train", "test", "val", "trainval" 30 | class_names: list or tuple of class names 31 | """ 32 | with PathManager.open(os.path.join(dirname, split + ".txt")) as f: 33 | fileids = np.loadtxt(f, dtype=np.str) 34 | 35 | dicts = [] 36 | for field in fileids: 37 | image_path = os.path.join(dirname, "img", "{}.jpg".format(field)) 38 | inst_info_path = os.path.join(dirname, "inst", "{}.mat".format(field)) 39 | 40 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) 41 | instances_ids, _ = get_labels_with_sizes(instances_mask) 42 | 43 | for instances_id in instances_ids: 44 | r = { 45 | "file_name": image_path, 46 | "inst_info_name": inst_info_path, 47 | "inst_id": instances_id, 48 | } 49 | dicts.append(r) 50 | return dicts 51 | 52 | def register_sbd_context(name, dirname, mode, split): 53 | DatasetCatalog.register("{}_{}".format(name, mode), lambda: load_sbd_instances(name, dirname, mode, split)) 54 | MetadataCatalog.get("{}_{}".format(name, mode)).set( 55 | dirname=dirname, 56 | thing_dataset_id_to_contiguous_id={}, 57 | ) 58 | 59 | def register_all_sbd(root): 60 | SPLITS = [ 61 | ("sbd_val", "SBD", "Point", "val"), 62 | ("sbd_val", "SBD", "Scribble", "val"), 63 | ("sbd_val", "SBD", "Polygon", "val"), 64 | ("sbd_val", "SBD", "Circle", "val"), 65 | ] 66 | 67 | for name, dirname, mode, split in SPLITS: 68 | register_sbd_context(name, os.path.join(root, dirname), mode, split) 69 | MetadataCatalog.get("{}_{}".format(name, mode)).evaluator_type = "interactive" 70 | 71 | _root = os.getenv("DATASET", "datasets") 72 | register_all_sbd(_root) -------------------------------------------------------------------------------- /datasets/registration/register_scannet_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import os 9 | import glob 10 | from typing import List, Tuple, Union 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import SCAN_37, SCAN_40, SCAN_20 17 | 18 | __all__ = ["load_scannet_instances", "register_scannet_context"] 19 | 20 | name2folder = {"scannet_41_val_seg": "label41", 21 | "scannet_38_val_seg": "label38", 22 | "scannet_21_val_seg": "label21",} 23 | 24 | name2class = {"scannet_41_val_seg": SCAN_40, 25 | "scannet_38_val_seg": SCAN_37, 26 | "scannet_21_val_seg": SCAN_20} 27 | 28 | 29 | def load_scannet_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 30 | """ 31 | Load ScanNet annotations to Detectron2 format. 32 | 33 | Args: 34 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 35 | split (str): one of "train", "test", "val", "trainval" 36 | class_names: list or tuple of class names 37 | """ 38 | with PathManager.open(os.path.join(dirname, "meta", split + ".txt")) as f: 39 | fileids = np.loadtxt(f, dtype=np.str) 40 | 41 | dicts = [] 42 | for field in fileids: 43 | image_dir = os.path.join(dirname, 'images', field[0]) 44 | semseg_dir = image_dir.replace('color', name2folder[name]).replace('jpg', 'png') 45 | r = { 46 | "file_name": image_dir, 47 | "sem_seg_file_name": semseg_dir, 48 | "image_id": semseg_dir.split('/')[-3] + semseg_dir.split('/')[-1].split('.')[0], 49 | } 50 | dicts.append(r) 51 | return dicts 52 | 53 | 54 | def register_scannet_context(name, dirname, split, class_names=name2class): 55 | DatasetCatalog.register(name, lambda: load_scannet_instances(name, dirname, split, class_names)) 56 | MetadataCatalog.get(name).set( 57 | stuff_classes=class_names[name], 58 | dirname=dirname, 59 | split=split, 60 | ignore_label=[0], 61 | thing_dataset_id_to_contiguous_id={}, 62 | class_offset=1, 63 | keep_sem_bgd=False 64 | ) 65 | 66 | 67 | def register_all_sunrgbd_seg(root): 68 | SPLITS = [ 69 | ("scannet_41_val_seg", "scannet_frames_25k", "val"), 70 | ("scannet_38_val_seg", "scannet_frames_25k", "val"), 71 | ("scannet_21_val_seg", "scannet_frames_25k", "val"), 72 | ] 73 | 74 | for name, dirname, split in SPLITS: 75 | register_scannet_context(name, os.path.join(root, dirname), split) 76 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 77 | 78 | 79 | _root = os.getenv("DATASET", "datasets") 80 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_sunrgbd_semseg.py: -------------------------------------------------------------------------------- 1 | 2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | import numpy as np 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.structures import BoxMode 15 | from detectron2.utils.file_io import PathManager 16 | 17 | from utils.constants import SUN_RGBD_37 18 | 19 | __all__ = ["load_sunrgbd_instances", "register_sunrgbd_context"] 20 | 21 | def load_sunrgbd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load SUN-RGBD detection annotations to Detectron2 format. 24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | if split == 'val': 31 | split = 'test' 32 | 33 | # Needs to read many small annotation files. Makes sense at local 34 | image_pths = sorted(glob.glob(os.path.join(dirname, 'image', split, '*.jpg'))) 35 | semseg_pths = sorted(glob.glob(os.path.join(dirname, 'label37', split, '*.png'))) 36 | 37 | assert len(image_pths) == len(semseg_pths) 38 | 39 | dicts = [] 40 | for image_dir, semseg_dir in zip(image_pths, semseg_pths): 41 | r = { 42 | "file_name": image_dir, 43 | "sem_seg_file_name": semseg_dir, 44 | "image_id": semseg_dir.split('/')[-1].split('.')[0], 45 | } 46 | dicts.append(r) 47 | return dicts 48 | 49 | 50 | def register_sun_context(name, dirname, split, class_names=SUN_RGBD_37): 51 | DatasetCatalog.register(name, lambda: load_sunrgbd_instances(name, dirname, split, class_names)) 52 | MetadataCatalog.get(name).set( 53 | stuff_classes=class_names, 54 | dirname=dirname, 55 | split=split, 56 | ignore_label=[0], 57 | thing_dataset_id_to_contiguous_id={}, 58 | class_offset=1, 59 | keep_sem_bgd=False 60 | ) 61 | 62 | 63 | def register_all_sunrgbd_seg(root): 64 | SPLITS = [ 65 | ("sunrgbd_37_val_seg", "sun_rgbd", "val"), 66 | ] 67 | 68 | for name, dirname, split in SPLITS: 69 | register_sun_context(name, os.path.join(root, dirname), split) 70 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 71 | 72 | 73 | _root = os.getenv("DATASET", "datasets") 74 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_vlp_coco_entity.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import collections 4 | 5 | from detectron2.data import DatasetCatalog, MetadataCatalog 6 | from detectron2.data.datasets import load_sem_seg 7 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 8 | from detectron2.utils.file_io import PathManager 9 | 10 | 11 | _PREDEFINED_SPLITS_PRETRAIN = { 12 | "vlp_coco_entity_val_long": ( 13 | "coco/val2017", 14 | "coco/annotations/entity_val2017_long.json", 15 | ), 16 | "vlp_coco_entity_val": ( 17 | "coco/val2017", 18 | "coco/annotations/entity_val2017.json", 19 | ), 20 | "vlp_coco_entity_val_retrieval": ( 21 | "coco/val2017", 22 | "coco/annotations/entity_val2017.json", 23 | ), 24 | "vlp_coco_entity_val_retrieval_long": ( 25 | "coco/val2017", 26 | "coco/annotations/entity_val2017_long.json", 27 | ), 28 | } 29 | 30 | evaluator_mapper = {'vlp_coco_entity_val': 'retrieval_interleave_text', 'vlp_coco_entity_val_retrieval': 'retrieval', 'vlp_coco_entity_val_retrieval_long': 'retrieval', 'vlp_coco_entity_val_long': 'retrieval_interleave_text'} 31 | 32 | def get_metadata(name): 33 | return {} 34 | 35 | def load_pretrain_data(image_root, entity_root, meta, name): 36 | """ 37 | Args: 38 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 39 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". 40 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". 41 | Returns: 42 | list[dict]: a list of dicts in Detectron2 standard format. (See 43 | `Using Custom Datasets `_ ) 44 | """ 45 | 46 | with PathManager.open(entity_root) as f: 47 | entity_info = json.load(f) 48 | 49 | # build dictionary for entity 50 | entity_dict = collections.defaultdict(list) 51 | for entity_ann in entity_info['annotations']: 52 | image_id = int(entity_ann["image_id"]) 53 | entity_dict[image_id].append(entity_ann) 54 | 55 | image_dict = collections.defaultdict(list) 56 | for image_ann in entity_info['images']: 57 | image_id = int(image_ann["id"]) 58 | image_dict[image_id] = image_ann['file_name'] 59 | 60 | ret = [] 61 | for image_id in entity_dict.keys(): 62 | file_name = os.path.join(image_root, image_dict[image_id]) 63 | ret.append({ 64 | "file_name": file_name, 65 | "image_id": image_id, 66 | "captions": [entity_dict[image_id][i]['sentence'] for i in range(len(entity_dict[image_id]))], 67 | }) 68 | return ret 69 | 70 | def register_pretrain( 71 | name, metadata, image_root, entity_root, 72 | ): 73 | semantic_name = name 74 | DatasetCatalog.register( 75 | semantic_name, 76 | lambda: load_pretrain_data(image_root, entity_root, metadata, name), 77 | ) 78 | MetadataCatalog.get(semantic_name).set( 79 | evaluator_type=evaluator_mapper[semantic_name], 80 | **metadata, 81 | ) 82 | 83 | def register_all_pretrain(root): 84 | for ( 85 | prefix, 86 | (image_root, entity_root,), 87 | ) in _PREDEFINED_SPLITS_PRETRAIN.items(): 88 | register_pretrain( 89 | prefix, 90 | get_metadata(prefix), 91 | os.path.join(root, image_root), 92 | os.path.join(root, entity_root), 93 | ) 94 | 95 | 96 | _root = os.getenv("DATASET", "datasets") 97 | register_all_pretrain(_root) 98 | -------------------------------------------------------------------------------- /datasets/registration/register_ytvos_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import os 4 | import glob 5 | import json 6 | from typing import List, Tuple, Union 7 | 8 | import cv2 9 | import numpy as np 10 | from scipy.io import loadmat 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | 17 | __all__ = ["load_ytovs_instances", "register_ytvos_context"] 18 | 19 | def load_ytvos_instances(name: str, dirname: str, split: str): 20 | """ 21 | Load Pascal VOC detection annotations to Detectron2 format. 22 | 23 | Args: 24 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 25 | split (str): one of "train", "test", "val", "trainval" 26 | class_names: list or tuple of class names 27 | """ 28 | meta_json = os.path.join(dirname, split, "meta.json") 29 | video_dir = os.path.join(dirname, split, 'JPEGImages') 30 | mask_dir = os.path.join(dirname, split, 'Annotations') 31 | video_names = os.listdir(video_dir) 32 | meta = json.load(open(meta_json))['videos'] 33 | 34 | dicts = [] 35 | for vid_name in video_names: 36 | objects = meta[vid_name]['objects'] 37 | r = { 38 | "file_name": os.path.join(video_dir, vid_name), 39 | "mask_name": os.path.join(mask_dir, vid_name), 40 | "objects": objects, 41 | } 42 | dicts.append(r) 43 | 44 | return dicts 45 | 46 | def register_ytvos_context(name, dirname, split): 47 | DatasetCatalog.register("{}".format(name), lambda: load_ytvos_instances(name, dirname, split)) 48 | MetadataCatalog.get("{}".format(name)).set( 49 | dirname=dirname, 50 | thing_dataset_id_to_contiguous_id={}, 51 | ) 52 | 53 | def register_all_davis(root): 54 | SPLITS = [ 55 | ("ytvos19_val", "ytvos2019", "valid"), 56 | ("ytvos18_val", "ytvos2018", "valid"), 57 | ] 58 | 59 | for name, dirname, split in SPLITS: 60 | register_ytvos_context(name, os.path.join(root, dirname), split) 61 | MetadataCatalog.get("{}".format(name)).evaluator_type = None 62 | 63 | _root = os.getenv("DATASET", "datasets") 64 | register_all_davis(_root) -------------------------------------------------------------------------------- /datasets/semseg_loader.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import scipy.io 3 | import numpy as np 4 | 5 | def load_semseg(filename, loader_type): 6 | if loader_type == 'PIL': 7 | semseg = np.array(Image.open(filename), dtype=np.int) 8 | elif loader_type == 'MAT': 9 | semseg = scipy.io.loadmat(filename)['LabelMap'] 10 | return semseg -------------------------------------------------------------------------------- /datasets/utils/refcoco2json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from refer import REFER 4 | 5 | coco_root = '/pth/to/coco' 6 | ref_root = '/pth/to/refcocoseg' 7 | 8 | coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json'))) 9 | coco_train_id = [] 10 | image_annot = {} 11 | for i in range(len(coco_train_annot['images'])): 12 | coco_train_id.append(coco_train_annot['images'][i]['id']) 13 | image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i] 14 | 15 | refg = REFER(data_root=ref_root, 16 | dataset='refcocog', splitBy='umd') 17 | refg_val_ids = refg.getRefIds(split='val') 18 | 19 | full_anno = [] 20 | for ref_id in refg_val_ids: 21 | ref = refg.loadRefs(ref_id)[0] 22 | anno = refg.refToAnn[ref_id] 23 | anno.update(ref) 24 | full_anno.append(anno) 25 | 26 | imageid_list = [] 27 | final_anno = {} 28 | for anno in full_anno: 29 | imageid_list += [anno['image_id']] 30 | final_anno[anno['ann_id']] = anno 31 | 32 | annotations = [value for key, value in final_anno.items()] 33 | 34 | iamges = [] 35 | for image_id in list(set(imageid_list)): 36 | iamges += [image_annot[image_id]] 37 | 38 | outputs = {'images': iamges, 'annotations': annotations} 39 | print(len(iamges)) 40 | print(len(annotations)) 41 | json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_train.json'), 'w')) 42 | -------------------------------------------------------------------------------- /datasets/visual_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import ShapeSampler 2 | from .simpleclick_sampler import SimpleClickSampler 3 | 4 | 5 | def build_shape_sampler(cfg, **kwargs): 6 | sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE'] 7 | if sampler_name == 'random': 8 | return ShapeSampler(cfg, **kwargs) 9 | elif sampler_name in ['best', 'best_random']: 10 | return SimpleClickSampler(cfg, **kwargs) 11 | else: 12 | assert False, "not implemented" -------------------------------------------------------------------------------- /datasets/visual_sampler/circle.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from .mask_generators import get_mask_by_input_strokes 5 | 6 | class Circle: 7 | def __init__(self, cfg, is_train=True): 8 | self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES'] 9 | self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET'] 10 | self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB'] 11 | self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'] 12 | self.is_train = is_train 13 | 14 | @staticmethod 15 | def get_stroke_preset(stroke_preset): 16 | if stroke_preset == 'object_like': 17 | return { 18 | "nVertexBound": [5, 30], 19 | "maxHeadSpeed": 15, 20 | "maxHeadAcceleration": (10, 1.5), 21 | "brushWidthBound": (20, 50), 22 | "nMovePointRatio": 0.5, 23 | "maxPiontMove": 10, 24 | "maxLineAcceleration": (5, 0.5), 25 | "boarderGap": None, 26 | "maxInitSpeed": 10, 27 | } 28 | elif stroke_preset == 'object_like_middle': 29 | return { 30 | "nVertexBound": [5, 15], 31 | "maxHeadSpeed": 8, 32 | "maxHeadAcceleration": (4, 1.5), 33 | "brushWidthBound": (20, 50), 34 | "nMovePointRatio": 0.5, 35 | "maxPiontMove": 5, 36 | "maxLineAcceleration": (5, 0.5), 37 | "boarderGap": None, 38 | "maxInitSpeed": 10, 39 | } 40 | elif stroke_preset == 'object_like_small': 41 | return { 42 | "nVertexBound": [5, 20], 43 | "maxHeadSpeed": 7, 44 | "maxHeadAcceleration": (3.5, 1.5), 45 | "brushWidthBound": (10, 30), 46 | "nMovePointRatio": 0.5, 47 | "maxPiontMove": 5, 48 | "maxLineAcceleration": (3, 0.5), 49 | "boarderGap": None, 50 | "maxInitSpeed": 4, 51 | } 52 | else: 53 | raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.') 54 | 55 | def get_random_points_from_mask(self, mask, n=5): 56 | h,w = mask.shape 57 | view_mask = mask.reshape(h*w) 58 | non_zero_idx = view_mask.nonzero()[:,0] 59 | selected_idx = torch.randperm(len(non_zero_idx))[:n] 60 | non_zero_idx = non_zero_idx[selected_idx] 61 | y = (non_zero_idx // w)*1.0 62 | x = (non_zero_idx % w)*1.0 63 | return torch.cat((x[:,None], y[:,None]), dim=1).numpy() 64 | 65 | def draw(self, mask=None, box=None): 66 | if mask.sum() < 10: # if mask is nearly empty 67 | return torch.zeros(mask.shape).bool() 68 | if not self.is_train: 69 | return self.draw_eval(mask=mask, box=box) 70 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use 71 | preset = Circle.get_stroke_preset(stroke_preset_name) 72 | nStroke = min(random.randint(1, self.num_stroke), mask.sum().item()) 73 | h,w = mask.shape 74 | points = self.get_random_points_from_mask(mask, n=nStroke) 75 | rand_mask = get_mask_by_input_strokes( 76 | init_points=points, 77 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset) 78 | rand_mask = (~torch.from_numpy(rand_mask)) * mask 79 | return rand_mask 80 | 81 | def draw_eval(self, mask=None, box=None): 82 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use 83 | preset = Circle.get_stroke_preset(stroke_preset_name) 84 | nStroke = min(self.max_eval, mask.sum().item()) 85 | h,w = mask.shape 86 | points = self.get_random_points_from_mask(mask, n=nStroke) 87 | rand_masks = [] 88 | for i in range(len(points)): 89 | rand_mask = get_mask_by_input_strokes( 90 | init_points=points[:i+1], 91 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset) 92 | rand_masks += [(~torch.from_numpy(rand_mask)) * mask] 93 | return torch.stack(rand_masks) 94 | 95 | @staticmethod 96 | def draw_by_points(points, mask, h, w): 97 | stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use 98 | preset = Circle.get_stroke_preset(stroke_preset_name) 99 | rand_mask = get_mask_by_input_strokes( 100 | init_points=points, 101 | imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,] 102 | rand_masks = (~torch.from_numpy(rand_mask)) * mask 103 | return rand_masks 104 | 105 | def __repr__(self,): 106 | return 'circle' -------------------------------------------------------------------------------- /datasets/visual_sampler/point.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from scipy import ndimage 6 | 7 | 8 | class Point: 9 | def __init__(self, cfg, is_train=True): 10 | self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS'] 11 | self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'] 12 | self.is_train = is_train 13 | 14 | def draw(self, mask=None, box=None): 15 | if mask.sum() < 10: 16 | return torch.zeros(mask.shape).bool() # if mask is empty 17 | if not self.is_train: 18 | return self.draw_eval(mask=mask, box=box) 19 | max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number 20 | num_points = random.randint(1, max_points) # get a random number of points 21 | h,w = mask.shape 22 | view_mask = mask.view(-1) 23 | non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask 24 | selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id 25 | non_zero_idx = non_zero_idx[selected_idx] # select non-zero index 26 | rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask 27 | rand_mask[non_zero_idx] = True # get non zero place to zero 28 | # dilate 29 | # struct = ndimage.generate_binary_structure(2, 2) 30 | # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype))) 31 | # return rand_mask 32 | return rand_mask.reshape(h, w) 33 | 34 | def draw_eval(self, mask=None, box=None): 35 | background = ~mask 36 | neg_num = min(self.max_eval // 2, background.sum().item()) 37 | pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1 38 | 39 | h,w = mask.shape 40 | view_mask = mask.view(-1) 41 | non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask 42 | selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id 43 | non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index 44 | pos_idx = torch.ones(non_zero_idx_pos.shape) 45 | 46 | view_background = background.view(-1) 47 | non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask 48 | selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id 49 | non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index 50 | neg_idx = torch.ones(non_zero_idx_neg.shape) * -1 51 | 52 | non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg]) 53 | idx = torch.cat([pos_idx, neg_idx]) 54 | rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long() 55 | non_zero_idx = non_zero_idx[rand_idx] 56 | idx = idx[rand_idx] 57 | 58 | rand_masks = [] 59 | for i in range(0, len(non_zero_idx)): 60 | rand_mask = torch.zeros(view_mask.shape) # init rand mask 61 | rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero 62 | # struct = ndimage.generate_binary_structure(2, 2) 63 | # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype))) 64 | rand_masks += [rand_mask.reshape(h, w)] 65 | 66 | # kernel_size = 3 67 | rand_masks = torch.stack(rand_masks) 68 | # rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0] 69 | # rand_masks[rand_masks>0] = 1 70 | # rand_masks[rand_masks<0] = -1 71 | return rand_masks 72 | 73 | def __repr__(self,): 74 | return 'point' -------------------------------------------------------------------------------- /datasets/visual_sampler/sampler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .point import Point 9 | from .polygon import Polygon 10 | from .scribble import Scribble 11 | from .circle import Circle 12 | 13 | from modeling.utils import configurable 14 | 15 | 16 | class ShapeSampler(nn.Module): 17 | @configurable 18 | def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True): 19 | super().__init__() 20 | self.max_candidate = max_candidate 21 | self.shape_prob = shape_prob 22 | self.shape_candidate = shape_candidate 23 | self.is_train = is_train 24 | 25 | @classmethod 26 | def from_config(cls, cfg, is_train=True, mode=None): 27 | max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE'] 28 | candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS'] 29 | candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES'] 30 | 31 | if mode == 'hack_train': 32 | candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names] 33 | else: 34 | # overwrite condidate_prob 35 | if not is_train: 36 | candidate_probs = [0.0 for x in range(len(candidate_names))] 37 | candidate_probs[candidate_names.index(mode)] = 1.0 38 | candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names] 39 | 40 | # Build augmentation 41 | return { 42 | "max_candidate": max_candidate, 43 | "shape_prob": candidate_probs, 44 | "shape_candidate": candidate_classes, 45 | "is_train": is_train, 46 | } 47 | 48 | def forward(self, instances): 49 | masks = instances.gt_masks.tensor 50 | boxes = instances.gt_boxes.tensor 51 | 52 | if len(masks) == 0: 53 | gt_masks = torch.zeros(masks.shape[-2:]).bool() 54 | rand_masks = torch.zeros(masks.shape[-2:]).bool() 55 | return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']} 56 | indices = [x for x in range(len(masks))] 57 | 58 | if self.is_train: 59 | random.shuffle(indices) 60 | max_candidate = max(1, int(time.time()) % self.max_candidate) 61 | candidate_mask = masks[indices[:max_candidate]] 62 | candidate_box = boxes[indices[:max_candidate]] 63 | else: 64 | candidate_mask = masks 65 | candidate_box = boxes 66 | 67 | draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask)) 68 | rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)] 69 | types = [repr(x) for x in draw_funcs] 70 | for i in range(0, len(rand_shapes)): 71 | if rand_shapes[i].sum() == 0: 72 | candidate_mask[i] = candidate_mask[i] * 0 73 | types[i] = 'none' 74 | 75 | # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c) 76 | return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self} 77 | 78 | def build_shape_sampler(cfg, **kwargs): 79 | return ShapeSampler(cfg, **kwargs) -------------------------------------------------------------------------------- /datasets/visual_sampler/scribble.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from .mask_generators import get_mask_by_input_strokes 6 | 7 | class Scribble: 8 | def __init__(self, cfg, is_train): 9 | self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES'] 10 | self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET'] 11 | self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB'] 12 | self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'] 13 | self.is_train = is_train 14 | 15 | @staticmethod 16 | def get_stroke_preset(stroke_preset): 17 | if stroke_preset == 'rand_curve': 18 | return { 19 | "nVertexBound": [10, 30], 20 | "maxHeadSpeed": 20, 21 | "maxHeadAcceleration": (15, 0.5), 22 | "brushWidthBound": (3, 10), 23 | "nMovePointRatio": 0.5, 24 | "maxPiontMove": 3, 25 | "maxLineAcceleration": (5, 0.5), 26 | "boarderGap": None, 27 | "maxInitSpeed": 6 28 | } 29 | elif stroke_preset == 'rand_curve_small': 30 | return { 31 | "nVertexBound": [6, 22], 32 | "maxHeadSpeed": 12, 33 | "maxHeadAcceleration": (8, 0.5), 34 | "brushWidthBound": (2.5, 5), 35 | "nMovePointRatio": 0.5, 36 | "maxPiontMove": 1.5, 37 | "maxLineAcceleration": (3, 0.5), 38 | "boarderGap": None, 39 | "maxInitSpeed": 3 40 | } 41 | else: 42 | raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.') 43 | 44 | def get_random_points_from_mask(self, mask, n=5): 45 | h,w = mask.shape 46 | view_mask = mask.reshape(h*w) 47 | non_zero_idx = view_mask.nonzero()[:,0] 48 | selected_idx = torch.randperm(len(non_zero_idx))[:n] 49 | non_zero_idx = non_zero_idx[selected_idx] 50 | y = (non_zero_idx // w)*1.0 51 | x = (non_zero_idx % w)*1.0 52 | return torch.cat((x[:,None], y[:,None]), dim=1).numpy() 53 | 54 | def draw(self, mask=None, box=None): 55 | if mask.sum() < 10: 56 | return torch.zeros(mask.shape).bool() # if mask is empty 57 | if not self.is_train: 58 | return self.draw_eval(mask=mask, box=box) 59 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] 60 | preset = Scribble.get_stroke_preset(stroke_preset_name) 61 | nStroke = random.randint(1, min(self.num_stroke, mask.sum().item())) 62 | h,w = mask.shape 63 | points = self.get_random_points_from_mask(mask, n=nStroke) 64 | rand_mask = get_mask_by_input_strokes( 65 | init_points=points, 66 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset) 67 | rand_mask = (~torch.from_numpy(rand_mask)) * mask 68 | return rand_mask 69 | 70 | def draw_eval(self, mask=None, box=None): 71 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] 72 | preset = Scribble.get_stroke_preset(stroke_preset_name) 73 | nStroke = min(self.eval_stroke, mask.sum().item()) 74 | h,w = mask.shape 75 | points = self.get_random_points_from_mask(mask, n=nStroke) 76 | rand_masks = [] 77 | for i in range(len(points)): 78 | rand_mask = get_mask_by_input_strokes( 79 | init_points=points[:i+1], 80 | imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset) 81 | rand_mask = (~torch.from_numpy(rand_mask)) * mask 82 | rand_masks += [rand_mask] 83 | return torch.stack(rand_masks) 84 | 85 | @staticmethod 86 | def draw_by_points(points, mask, h, w): 87 | stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0] 88 | preset = Scribble.get_stroke_preset(stroke_preset_name) 89 | rand_mask = get_mask_by_input_strokes( 90 | init_points=points, 91 | imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,] 92 | rand_masks = (~torch.from_numpy(rand_mask)) * mask 93 | return rand_masks 94 | 95 | def __repr__(self,): 96 | return 'scribble' -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/demo/__init__.py -------------------------------------------------------------------------------- /demo/find/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/demo/find/__init__.py -------------------------------------------------------------------------------- /demo/find/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/demo/find/arial.ttf -------------------------------------------------------------------------------- /entry.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import torch 11 | import logging 12 | import wandb 13 | 14 | from utils.arguments import load_opt_command 15 | 16 | logging.basicConfig(level=logging.INFO) 17 | logger = logging.getLogger(__name__) 18 | 19 | def init_wandb(args, job_dir, entity='xueyanz', project='xdecoder', job_name='tmp'): 20 | wandb_dir = os.path.join(job_dir, 'wandb') 21 | os.makedirs(wandb_dir, exist_ok=True) 22 | runid = None 23 | if os.path.exists(f"{wandb_dir}/runid.txt"): 24 | runid = open(f"{wandb_dir}/runid.txt").read() 25 | 26 | wandb.init(project=project, 27 | name=job_name, 28 | dir=wandb_dir, 29 | entity=entity, 30 | resume="allow", 31 | id=runid, 32 | config={"hierarchical": True},) 33 | 34 | open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id) 35 | wandb.config.update({k: args[k] for k in args if k not in wandb.config}) 36 | 37 | def main(args=None): 38 | ''' 39 | [Main function for the entry point] 40 | 1. Set environment variables for distributed training. 41 | 2. Load the config file and set up the trainer. 42 | ''' 43 | 44 | opt, cmdline_args = load_opt_command(args) 45 | command = cmdline_args.command 46 | 47 | if cmdline_args.user_dir: 48 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 49 | opt['base_path'] = absolute_user_dir 50 | 51 | # update_opt(opt, command) 52 | world_size = 1 53 | if 'OMPI_COMM_WORLD_SIZE' in os.environ: 54 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 55 | 56 | if opt['TRAINER'] == 'xdecoder': 57 | from trainer import XDecoder_Trainer as Trainer 58 | else: 59 | assert False, "The trainer type: {} is not defined!".format(opt['TRAINER']) 60 | 61 | trainer = Trainer(opt) 62 | os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL' 63 | 64 | if command == "train": 65 | if opt['rank'] == 0 and opt['WANDB']: 66 | wandb.login(key=os.environ['WANDB_KEY']) 67 | init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder) 68 | trainer.train() 69 | elif command == "evaluate": 70 | trainer.eval() 71 | else: 72 | raise ValueError(f"Unknown command: {command}") 73 | 74 | if __name__ == "__main__": 75 | main() 76 | sys.exit(0) 77 | -------------------------------------------------------------------------------- /modeling/BaseModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from utils.model import align_and_update_state_dicts 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BaseModel(nn.Module): 13 | def __init__(self, opt, module: nn.Module): 14 | super(BaseModel, self).__init__() 15 | self.opt = opt 16 | self.model = module 17 | 18 | def forward(self, *inputs, **kwargs): 19 | outputs = self.model(*inputs, **kwargs) 20 | return outputs 21 | 22 | def save_pretrained(self, save_dir): 23 | torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt")) 24 | 25 | def from_pretrained(self, load_dir): 26 | state_dict = torch.load(load_dir, map_location=self.opt['device']) 27 | state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict) 28 | self.model.load_state_dict(state_dict, strict=False) 29 | return self -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .architectures import build_model -------------------------------------------------------------------------------- /modeling/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder_model import * 2 | from .seem_model_v0 import * 3 | from .seem_model_v1 import * 4 | from .find_model import * 5 | from .build import build_model -------------------------------------------------------------------------------- /modeling/architectures/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def build_model(config, **kwargs): 5 | model_name = config['MODEL']['NAME'] 6 | 7 | if not is_model(model_name): 8 | raise ValueError(f'Unkown model: {model_name}') 9 | 10 | return model_entrypoints(model_name)(config, **kwargs) 11 | 12 | def register_model(fn): 13 | module_name_split = fn.__module__.split('.') 14 | model_name = module_name_split[-1] 15 | _model_entrypoints[model_name] = fn 16 | return fn 17 | 18 | def model_entrypoints(model_name): 19 | return _model_entrypoints[model_name] 20 | 21 | def is_model(model_name): 22 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/body/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder_head import * 2 | from .build import * 3 | 4 | def build_xdecoder_head(config, *args, **kwargs): 5 | model_name = config['MODEL']['HEAD'] 6 | if not is_model(model_name): 7 | raise ValueError(f'Unkown model: {model_name}') 8 | 9 | body = model_entrypoints(model_name)(config, *args, **kwargs) 10 | return body -------------------------------------------------------------------------------- /modeling/body/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_body(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/interface/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder import * 2 | from .seem_v0 import * 3 | from .seem_v1 import * 4 | from .find import * 5 | from .build import * 6 | 7 | def build_decoder(config, *args, **kwargs): 8 | model_name = config['MODEL']['DECODER']['NAME'] 9 | 10 | if not is_model(model_name): 11 | raise ValueError(f'Unkown model: {model_name}') 12 | 13 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /modeling/interface/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_decoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/interface/operator/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import * 2 | from .modules import * -------------------------------------------------------------------------------- /modeling/interface/operator/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP 5 | 6 | 7 | class ContentAttention(nn.Module): 8 | def __init__(self, num_layers, hidden_dim, nheads, pre_norm, dropout=0.0): 9 | super().__init__() 10 | 11 | self.layers = nn.ModuleList() 12 | for _ in range(num_layers): 13 | self.layers.append( 14 | CrossAttentionLayer( 15 | d_model=hidden_dim, 16 | nhead=nheads, 17 | dropout=0.0, 18 | normalize_before=pre_norm, 19 | ) 20 | ) 21 | 22 | def forward(self, layer_id, content_variables): 23 | outputs = [] 24 | for key, value in content_variables.items(): 25 | output, _ = self.layers[layer_id]( 26 | value['output'], value['src'], 27 | memory_mask=value['memory_mask'], 28 | memory_key_padding_mask=None if 'memory_key_padding_mask' not in value else value['memory_key_padding_mask'], # here we do not apply masking on padded region 29 | pos=value['pos'], query_pos=value['query_pos'] 30 | ) 31 | outputs += [output] 32 | return torch.cat(outputs) 33 | 34 | class ConditionAttention(nn.Module): 35 | def __init__(self, num_layers, hidden_dim, nheads, pre_norm, dropout=0.0): 36 | super().__init__() 37 | 38 | self.layers = nn.ModuleList() 39 | for _ in range(num_layers): 40 | self.layers.append( 41 | SelfAttentionLayer( 42 | d_model=hidden_dim, 43 | nhead=nheads, 44 | dropout=0.0, 45 | normalize_before=pre_norm, 46 | ) 47 | ) 48 | 49 | def forward(self, layer_id, output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): 50 | output = self.layers[layer_id]( 51 | output, tgt_mask=tgt_mask, 52 | tgt_key_padding_mask=tgt_key_padding_mask, 53 | query_pos=query_pos 54 | ) 55 | return output -------------------------------------------------------------------------------- /modeling/interface/prototype/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/modeling/interface/prototype/__init__.py -------------------------------------------------------------------------------- /modeling/language/LangEncoder/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTokenizer, CLIPTokenizerFast 2 | from transformers import AutoTokenizer 3 | 4 | from .transformer import * 5 | from .build import * 6 | from .modeling_llama import * 7 | from ..Tokenizer.custom_tokenizer import * 8 | 9 | 10 | def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): 11 | model_name = config_encoder['NAME'] 12 | 13 | if not is_lang_encoder(model_name): 14 | raise ValueError(f'Unkown model: {model_name}') 15 | 16 | return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) 17 | 18 | def build_tokenizer(config_encoder): 19 | tokenizer = None 20 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 21 | if config_encoder['TOKENIZER'] == 'clip': 22 | pretrained_tokenizer = config_encoder.get( 23 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 24 | ) 25 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) 26 | tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) 27 | elif config_encoder['TOKENIZER'] == 'clip-fast': 28 | pretrained_tokenizer = config_encoder.get( 29 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 30 | ) 31 | tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True) 32 | elif config_encoder['TOKENIZER'] == 'clip-token': 33 | pretrained_tokenizer = config_encoder.get( 34 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 35 | ) 36 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) 37 | tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) 38 | tokenizer = CustomCLIPTokenizer(tokenizer, config_encoder['CONTEXT_LENGTH']) 39 | else: 40 | tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) 41 | 42 | return tokenizer -------------------------------------------------------------------------------- /modeling/language/LangEncoder/build.py: -------------------------------------------------------------------------------- 1 | _lang_encoders = {} 2 | 3 | 4 | def register_lang_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | 8 | _lang_encoders[model_name] = fn 9 | 10 | return fn 11 | 12 | def lang_encoders(model_name): 13 | return _lang_encoders[model_name] 14 | 15 | def is_lang_encoder(model_name): 16 | return model_name in _lang_encoders 17 | -------------------------------------------------------------------------------- /modeling/language/Tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/modeling/language/Tokenizer/__init__.py -------------------------------------------------------------------------------- /modeling/language/Tokenizer/custom_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def split_by_ordered_substrings(sentence, substrings): 6 | results = [] 7 | substring_indices = [] 8 | 9 | start_index = 0 10 | for i, substring in enumerate(substrings): 11 | # Find the start of the substring in the remaining part of the sentence 12 | index = sentence[start_index:].find(substring) 13 | 14 | if index == -1: 15 | continue 16 | 17 | # Append any text before the substring to the results, including spaces 18 | if index > 0: 19 | results.append(sentence[start_index:start_index+index]) 20 | substring_indices.append(None) # No match in the `substrings` list for this segment 21 | 22 | # Append the substring to the results 23 | results.append(substring) 24 | substring_indices.append(i) # Append the index from the `substrings` list 25 | start_index += index + len(substring) 26 | 27 | # If there's any remaining part of the sentence after all substrings, append it to the results 28 | if start_index < len(sentence): 29 | results.append(sentence[start_index:]) 30 | substring_indices.append(None) # No match in the `substrings` list for this segment 31 | 32 | return results, substring_indices 33 | 34 | class CustomCLIPTokenizer(nn.Module): 35 | def __init__(self, tokenizer, max_token_num): 36 | super().__init__() 37 | self.tokenizer = tokenizer 38 | self.max_token_num = max_token_num 39 | 40 | def forward(self, sentence, entity_list): 41 | substrings, index = split_by_ordered_substrings(sentence, entity_list) 42 | tokens = self.tokenizer( 43 | substrings, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' 44 | ) 45 | 46 | _input_ids, _attention_mask = [], [] 47 | start_end_condition = torch.zeros(len(entity_list), len(tokens['input_ids'][0])).bool() 48 | 49 | start_idx = 0 50 | end_idx = 0 51 | for j, (input_id, attn_mask, idx) in enumerate(zip(tokens['input_ids'], tokens['attention_mask'], index)): 52 | if j == 0: 53 | _input_ids += [input_id[attn_mask.bool()][:-1]] 54 | _attention_mask += [attn_mask[attn_mask!=0][:-1]] 55 | elif j == len(tokens['input_ids']) - 1: 56 | _input_ids += [input_id[attn_mask.bool()][1:]] 57 | _attention_mask += [attn_mask[attn_mask!=0][1:]] 58 | else: 59 | _input_ids += [input_id[attn_mask.bool()][1:-1]] 60 | _attention_mask += [attn_mask[attn_mask!=0][1:-1]] 61 | 62 | end_idx += len(_input_ids[-1]) 63 | 64 | if idx is not None: 65 | start_end_condition[idx,start_idx:end_idx] = True 66 | start_idx = end_idx 67 | 68 | _input_ids_all = torch.ones(self.max_token_num, dtype=torch.long) * self.tokenizer.pad_token_id 69 | _attention_mask_all = torch.zeros(self.max_token_num, dtype=torch.long) 70 | 71 | _input_ids = torch.cat(_input_ids)[:self.max_token_num] 72 | _attention_mask = torch.cat(_attention_mask)[:self.max_token_num] 73 | 74 | _input_ids_all[:len(_input_ids)] = _input_ids 75 | _attention_mask_all[:len(_attention_mask)] = _attention_mask 76 | if len(_attention_mask) < self.max_token_num: 77 | _attention_mask_all[len(_attention_mask)] = 1 78 | 79 | tokens = {"input_ids": _input_ids_all[None,], "attention_mask": _attention_mask_all[None,]} 80 | return tokens, start_end_condition 81 | 82 | -------------------------------------------------------------------------------- /modeling/language/__init__.py: -------------------------------------------------------------------------------- 1 | from .vlpencoder import * 2 | from .llamaencoder import * 3 | from .vlpencoder_v1 import * 4 | from .build import * 5 | 6 | def build_language_encoder(config, **kwargs): 7 | model_name = config['MODEL']['TEXT']['ARCH'] 8 | 9 | if not is_model(model_name): 10 | raise ValueError(f'Unkown model: {model_name}') 11 | 12 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /modeling/language/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_model(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/language/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import nltk 5 | nltk.data.path.append('/mnt/data/nltk_data') 6 | import numpy as np 7 | 8 | from utils.constants import IMAGENET_DEFAULT_TEMPLATES 9 | 10 | 11 | def get_tag(tokenized, tags): 12 | if not isinstance(tags, (list, tuple)): 13 | tags = [tags] 14 | ret = [] 15 | for (word, pos) in nltk.pos_tag(tokenized): 16 | for tag in tags: 17 | if pos == tag: 18 | ret.append(word) 19 | return ret 20 | 21 | def get_noun_phrase(tokenized): 22 | # Taken from Su Nam Kim Paper... 23 | grammar = r""" 24 | NBAR: 25 | {*} # Nouns and Adjectives, terminated with Nouns 26 | 27 | NP: 28 | {} 29 | {} # Above, connected with in/of/etc... 30 | """ 31 | chunker = nltk.RegexpParser(grammar) 32 | 33 | chunked = chunker.parse(nltk.pos_tag(tokenized)) 34 | continuous_chunk = [] 35 | current_chunk = [] 36 | 37 | for subtree in chunked: 38 | if isinstance(subtree, nltk.Tree): 39 | current_chunk.append(' '.join([token for token, pos in subtree.leaves()])) 40 | elif current_chunk: 41 | named_entity = ' '.join(current_chunk) 42 | if named_entity not in continuous_chunk: 43 | continuous_chunk.append(named_entity) 44 | current_chunk = [] 45 | else: 46 | continue 47 | 48 | return continuous_chunk 49 | 50 | def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True): 51 | tokenized = nltk.word_tokenize(text) 52 | 53 | if random.random() >= phrase_prob: 54 | nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP']) 55 | else: 56 | nouns = get_noun_phrase(tokenized) 57 | 58 | 59 | prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns] 60 | 61 | if append_text: 62 | prompt_texts += [text] 63 | nouns += [text] 64 | 65 | return prompt_texts, nouns -------------------------------------------------------------------------------- /modeling/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .point_features import * 2 | from .position_encoding import * 3 | from .postprocessing import * 4 | from .attention import * 5 | from .criterion import * 6 | from .matcher import * -------------------------------------------------------------------------------- /modeling/modules/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=x.dtype) 34 | x_embed = not_mask.cumsum(2, dtype=x.dtype) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | def __repr__(self, _repr_indent=4): 55 | head = "Positional encoding " + self.__class__.__name__ 56 | body = [ 57 | "num_pos_feats: {}".format(self.num_pos_feats), 58 | "temperature: {}".format(self.temperature), 59 | "normalize: {}".format(self.normalize), 60 | "scale: {}".format(self.scale), 61 | ] 62 | # _repr_indent = 4 63 | lines = [head] + [" " * _repr_indent + line for line in body] 64 | return "\n".join(lines) 65 | -------------------------------------------------------------------------------- /modeling/modules/postprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from detectron2.structures import Instances, ROIMasks 6 | 7 | 8 | # perhaps should rename to "resize_instance" 9 | def detector_postprocess( 10 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 11 | ): 12 | """ 13 | Resize the output instances. 14 | The input images are often resized when entering an object detector. 15 | As a result, we often need the outputs of the detector in a different 16 | resolution from its inputs. 17 | 18 | This function will resize the raw outputs of an R-CNN detector 19 | to produce outputs according to the desired output resolution. 20 | 21 | Args: 22 | results (Instances): the raw outputs from the detector. 23 | `results.image_size` contains the input image resolution the detector sees. 24 | This object might be modified in-place. 25 | output_height, output_width: the desired output resolution. 26 | 27 | Returns: 28 | Instances: the resized output from the model, based on the output resolution 29 | """ 30 | if isinstance(output_width, torch.Tensor): 31 | # This shape might (but not necessarily) be tensors during tracing. 32 | # Converts integer tensors to float temporaries to ensure true 33 | # division is performed when computing scale_x and scale_y. 34 | output_width_tmp = output_width.float() 35 | output_height_tmp = output_height.float() 36 | new_size = torch.stack([output_height, output_width]) 37 | else: 38 | new_size = (output_height, output_width) 39 | output_width_tmp = output_width 40 | output_height_tmp = output_height 41 | 42 | scale_x, scale_y = ( 43 | output_width_tmp / results.image_size[1], 44 | output_height_tmp / results.image_size[0], 45 | ) 46 | results = Instances(new_size, **results.get_fields()) 47 | 48 | if results.has("pred_boxes"): 49 | output_boxes = results.pred_boxes 50 | elif results.has("proposal_boxes"): 51 | output_boxes = results.proposal_boxes 52 | else: 53 | output_boxes = None 54 | assert output_boxes is not None, "Predictions must contain boxes!" 55 | 56 | output_boxes.scale(scale_x, scale_y) 57 | output_boxes.clip(results.image_size) 58 | 59 | results = results[output_boxes.nonempty()] 60 | 61 | if results.has("pred_masks"): 62 | if isinstance(results.pred_masks, ROIMasks): 63 | roi_masks = results.pred_masks 64 | else: 65 | # pred_masks is a tensor of shape (N, 1, M, M) 66 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) 67 | results.pred_masks = roi_masks.to_bitmasks( 68 | results.pred_boxes, output_height, output_width, mask_threshold 69 | ).tensor # TODO return ROIMasks/BitMask object in the future 70 | 71 | if results.has("pred_keypoints"): 72 | results.pred_keypoints[:, :, 0] *= scale_x 73 | results.pred_keypoints[:, :, 1] *= scale_y 74 | 75 | return results 76 | 77 | def bbox_postprocess(result, input_size, img_size, output_height, output_width): 78 | """ 79 | result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h] 80 | """ 81 | if result is None: 82 | return None 83 | 84 | scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device) 85 | result = result.sigmoid() * scale 86 | x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2 87 | h,w = img_size 88 | 89 | x1 = x1.clamp(min=0, max=w) 90 | y1 = y1.clamp(min=0, max=h) 91 | x2 = x2.clamp(min=0, max=w) 92 | y2 = y2.clamp(min=0, max=h) 93 | 94 | box = torch.stack([x1,y1,x2,y2]).permute(1,0) 95 | scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device) 96 | box = box*scale 97 | return box 98 | 99 | def sem_seg_postprocess(result, img_size, output_height, output_width): 100 | """ 101 | Return semantic segmentation predictions in the original resolution. 102 | 103 | The input images are often resized when entering semantic segmentor. Moreover, in same 104 | cases, they also padded inside segmentor to be divisible by maximum network stride. 105 | As a result, we often need the predictions of the segmentor in a different 106 | resolution from its inputs. 107 | 108 | Args: 109 | result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W), 110 | where C is the number of classes, and H, W are the height and width of the prediction. 111 | img_size (tuple): image size that segmentor is taking as input. 112 | output_height, output_width: the desired output resolution. 113 | 114 | Returns: 115 | semantic segmentation prediction (Tensor): A tensor of the shape 116 | (C, output_height, output_width) that contains per-pixel soft predictions. 117 | """ 118 | result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) 119 | result = F.interpolate( 120 | result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True 121 | )[0] 122 | return result 123 | -------------------------------------------------------------------------------- /modeling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .misc import * 3 | from .interactive import * 4 | from .attention import * -------------------------------------------------------------------------------- /modeling/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | def box_xywh_to_xyxy(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [x0, y0, (x0 + x1), (y0 + y1)] 25 | return torch.stack(b, dim=-1) 26 | 27 | 28 | # modified from torchvision to also return the union 29 | def box_iou(boxes1, boxes2): 30 | area1 = box_area(boxes1) 31 | area2 = box_area(boxes2) 32 | 33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clamp(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | iou = inter / (union+1e-6) 42 | return iou, union 43 | 44 | 45 | def generalized_box_iou(boxes1, boxes2): 46 | """ 47 | Generalized IoU from https://giou.stanford.edu/ 48 | 49 | The boxes should be in [x0, y0, x1, y1] format 50 | 51 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 52 | and M = len(boxes2) 53 | """ 54 | # degenerate boxes gives inf / nan results 55 | # so do an early check 56 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 57 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 58 | iou, union = box_iou(boxes1, boxes2) 59 | 60 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 61 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 62 | 63 | wh = (rb - lt).clamp(min=0) # [N,M,2] 64 | area = wh[:, :, 0] * wh[:, :, 1] 65 | 66 | return iou - (area - union) / (area+1e-6) 67 | 68 | 69 | def masks_to_boxes(masks): 70 | """Compute the bounding boxes around the provided masks 71 | 72 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 73 | 74 | Returns a [N, 4] tensors, with the boxes in xyxy format 75 | """ 76 | if masks.numel() == 0: 77 | return torch.zeros((0, 4), device=masks.device) 78 | 79 | h, w = masks.shape[-2:] 80 | 81 | y = torch.arange(0, h, dtype=torch.float) 82 | x = torch.arange(0, w, dtype=torch.float) 83 | y, x = torch.meshgrid(y, x) 84 | 85 | x_mask = (masks * x.unsqueeze(0)) 86 | x_max = x_mask.flatten(1).max(-1)[0] 87 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 88 | 89 | y_mask = (masks * y.unsqueeze(0)) 90 | y_max = y_mask.flatten(1).max(-1)[0] 91 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 92 | 93 | return torch.stack([x_min, y_min, x_max, y_max], 1) -------------------------------------------------------------------------------- /modeling/utils/interactive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import math 4 | 5 | import torch 6 | from torch import nn, Tensor 7 | import torch.nn.functional as F 8 | 9 | 10 | def rand_sample(x, divisor, max_len): 11 | # non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']] 12 | if len(x.nonzero()) == 0: 13 | return x.nonzero().t() 14 | 15 | non_zero_point_index = (x.nonzero()/divisor).t() 16 | mask_ids = non_zero_point_index[0].unique().long() 17 | 18 | # compute probability for each sample 19 | probs = torch.zeros_like(non_zero_point_index[0]) 20 | for idx in mask_ids: 21 | prob = 1./(len(mask_ids)*((non_zero_point_index[0:1]==idx).sum())) 22 | probs[non_zero_point_index[0]==idx] = prob 23 | 24 | indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0] 25 | non_zero_point_index = non_zero_point_index[:,indices] 26 | return non_zero_point_index # [n, 512] 27 | 28 | def rand_sample_plain(x, max_len): 29 | if x.shape[1] <= max_len: 30 | return x 31 | else: 32 | rand_idx = torch.randperm(x.shape[1])[:max_len] 33 | return x[:,rand_idx] 34 | 35 | def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed): 36 | src = [] 37 | pos = [] 38 | size_list = [] 39 | 40 | # disable mask, it does not affect performance 41 | for i in range(num_feature_levels): 42 | size_list.append(x[i].shape[-2:]) 43 | pos.append(pe_layer(x[i], None).flatten(2)) 44 | src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None]) 45 | 46 | # flatten NxCxHxW to HWxNxC 47 | pos[-1] = pos[-1].permute(2, 0, 1) 48 | src[-1] = src[-1].permute(2, 0, 1) 49 | return src, pos, size_list -------------------------------------------------------------------------------- /modeling/vision/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .focal import * 2 | from .focal_dw import * 3 | from .davit import * 4 | from .vit import * 5 | from .backbone import * 6 | from .build import * 7 | 8 | 9 | def build_backbone(config, **kwargs): 10 | model_name = config['MODEL']['BACKBONE']['NAME'] 11 | if not is_model(model_name): 12 | raise ValueError(f'Unkown model: {model_name}') 13 | 14 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /modeling/vision/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch.nn as nn 3 | 4 | from detectron2.modeling import ShapeSpec 5 | 6 | # from ..layers import ShapeSpec 7 | 8 | __all__ = ["Backbone"] 9 | 10 | 11 | class Backbone(nn.Module): 12 | """ 13 | Abstract base class for network backbones. 14 | """ 15 | 16 | def __init__(self): 17 | """ 18 | The `__init__` method of any subclass can specify its own set of arguments. 19 | """ 20 | super().__init__() 21 | 22 | def forward(self): 23 | """ 24 | Subclasses must override this method, but adhere to the same return type. 25 | 26 | Returns: 27 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor 28 | """ 29 | pass 30 | 31 | @property 32 | def size_divisibility(self) -> int: 33 | """ 34 | Some backbones require the input height and width to be divisible by a 35 | specific integer. This is typically true for encoder / decoder type networks 36 | with lateral connection (e.g., FPN) for which feature maps need to match 37 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific 38 | input size divisibility is required. 39 | """ 40 | return 0 41 | 42 | def output_shape(self): 43 | """ 44 | Returns: 45 | dict[str->ShapeSpec] 46 | """ 47 | # this is a backward-compatible default 48 | return { 49 | name: ShapeSpec( 50 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 51 | ) 52 | for name in self._out_features 53 | } 54 | -------------------------------------------------------------------------------- /modeling/vision/backbone/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_backbone(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/vision/backbone/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /modeling/vision/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_encoder_fpn import * 2 | from .transformer_encoder_deform import * 3 | from .build import * 4 | 5 | 6 | def build_encoder(config, *args, **kwargs): 7 | model_name = config['MODEL']['ENCODER']['NAME'] 8 | 9 | if not is_model(model_name): 10 | raise ValueError(f'Unkown model: {model_name}') 11 | 12 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /modeling/vision/encoder/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd modeling/vision/encoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 | print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/pipeline/__init__.py -------------------------------------------------------------------------------- /pipeline/utils/misc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | def hook_opt(opt): 7 | logger.warning('Need to make the name of SEEM and FIND compatible.') 8 | grounding_flag, spatial_flag = False, False 9 | if 'seem_model' in opt['MODEL']['NAME']: 10 | grounding_flag = opt['REF']['INPUT']['SPATIAL'] 11 | spatial_flag = opt['STROKE_SAMPLER']['EVAL']['GROUNDING'] 12 | 13 | if grounding_flag: 14 | opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['grounding'] = ['queries_grounding', 'tokens_grounding', 'tokens_spatial'] 15 | if spatial_flag: 16 | opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['spatial'] = ['queries_spatial', 'tokens_spatial', 'memories_spatial', 'tokens_grounding'] 17 | 18 | return opt 19 | 20 | # HACK for evalution 21 | def hook_metadata(metadata, name): 22 | return metadata 23 | 24 | # HACK for evalution 25 | def hook_switcher(model, name): 26 | mappings = {} 27 | if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'context_59_val_seg', 'context_459_val_seg', 'voc_2012_val_seg', 'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']: 28 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False} 29 | elif name in ['cityscapes_fine_instance_seg_val'] or 'seginw' in name: 30 | mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False} 31 | elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']: 32 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True} 33 | elif name in ['coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev']: 34 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True} 35 | # else: 36 | # pass 37 | # if name not in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd", "pascalvoc_val_Point", "grounding_coco_entity_val", "grounding_coco_entity_val_long", "vlp_coco_entity_val", "vlp_coco_interleave_val", "vlp_coco_entity_val_long", "vlp_coco_entity_val_retrieval", 'vlp_coco_entity_val_retrieval_long', "vlp_coco_interleave_val_long"]: 38 | # assert False, "dataset switcher is not defined" 39 | 40 | for key, value in mappings.items(): 41 | if key == 'SEMANTIC_ON': 42 | model.model.semantic_on = value 43 | if key == 'INSTANCE_ON': 44 | model.model.instance_on = value 45 | if key == 'PANOPTIC_ON': 46 | model.model.panoptic_on = value -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder_trainer import * -------------------------------------------------------------------------------- /trainer/distributed_trainer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import logging 10 | from mpi4py import MPI 11 | 12 | import torch 13 | 14 | from .utils.hook import add_hook 15 | from .utils.mpi_adapter import MPIAdapter 16 | from .utils.misc import save_opt_to_yaml 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class DistributedTrainer: 22 | def __init__(self, opt): 23 | self.opt = opt 24 | 25 | # parse environment information for distributed training 26 | adapter = MPIAdapter(self.opt['PORT']) 27 | self.opt['world_size'] = adapter.world_size 28 | self.opt['local_size'] = adapter.local_size 29 | self.opt['rank'] = adapter.rank 30 | self.opt['local_rank'] = adapter.local_rank 31 | 32 | self.set_opt_hook() 33 | 34 | # set up device 35 | if not self.opt['CUDA']: 36 | self.opt['device'] = torch.device("cpu") 37 | logger.info("Using CPU") 38 | else: 39 | torch.cuda.set_device(self.opt['local_rank']) 40 | self.opt['device'] = torch.device("cuda", self.opt['local_rank']) 41 | logger.info("Using CUDA") 42 | 43 | # init distributed training 44 | adapter.log_info() 45 | if torch.distributed.is_available() and self.opt['world_size'] > 1: 46 | adapter.init_process_group(backend='nccl') 47 | 48 | # save config file 49 | self.save_folder = self.opt['SAVE_DIR'] 50 | 51 | if self.opt['world_size'] > 1: 52 | torch.distributed.barrier() 53 | 54 | if self.opt['rank'] == 0: 55 | os.makedirs(self.save_folder, exist_ok=True) 56 | 57 | logger.info(f"Save config file to {os.path.join(self.save_folder, 'conf_copy.yaml')}") 58 | save_opt_to_yaml(self.opt, os.path.join(self.save_folder, 'conf_copy.yaml')) 59 | 60 | # ddp: log stats and update learning rate 61 | self.grad_acc_steps = self.opt['GRADIENT_ACCUMULATE_STEP'] 62 | logger.info(f"Base learning rate: {self.opt['SOLVER']['BASE_LR']}") 63 | logger.info(f"Number of GPUs: {self.opt['world_size']}") 64 | logger.info(f"Gradient accumulation steps: {self.grad_acc_steps}") 65 | 66 | if self.opt['world_size'] > 1: 67 | add_hook() 68 | 69 | # prepare metadata for save folder 70 | conf_file = self.opt['conf_files'][0] 71 | if 'BASENAME' not in self.opt: 72 | self.opt['BASENAME'] = os.path.basename(conf_file) 73 | 74 | self.init_save_folder() 75 | 76 | def set_opt_hook(self): 77 | # Fill in the default values for required keywords 78 | self.opt['CUDA'] = self.opt.get('CUDA', True) and torch.cuda.is_available() 79 | self.opt['FP16'] = self.opt.get('FP16', False) and self.opt['CUDA'] 80 | self.opt['GRADIENT_ACCUMULATE_STEP'] = int(self.opt.get('GRADIENT_ACCUMULATE_STEP', 1)) 81 | self.opt['EVAL_PER_UPDATE_NUM'] = int(self.opt.get('EVAL_PER_UPDATE_NUM', 0)) 82 | self.opt['LR_SCHEDULER_PARAMS'] = self.opt.get('LR_SCHEDULER_PARAMS', {}) 83 | 84 | if 'SAVE_DIR' not in self.opt: 85 | assert False, "Please initialize SAVE_DIR in your config file." 86 | self.opt['SAVE_DIR'] = os.path.normpath(self.opt['SAVE_DIR']) 87 | logger.info(f"Setting SAVE_DIR as {self.opt['SAVE_DIR']}") 88 | 89 | def init_save_folder(self): 90 | """ 91 | Initialize the save folder for logs, model, checkpoint, and evaluation. 92 | """ 93 | runid = 1 94 | 95 | if self.opt['world_size'] > 1: 96 | torch.distributed.barrier() 97 | 98 | if self.opt['rank'] == 0: 99 | while True: 100 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}") 101 | try: 102 | os.makedirs(save_folder, exist_ok=False) 103 | break 104 | except FileExistsError: 105 | runid = runid + 1 106 | 107 | if self.opt['world_size'] > 1: 108 | torch.distributed.barrier() 109 | 110 | if self.opt['world_size'] > 1: 111 | runid = 1 112 | while True: 113 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}") 114 | if not os.path.exists(save_folder): 115 | break 116 | else: 117 | runid += 1 118 | 119 | runid -= 1 120 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}") 121 | # this second os.makedirs() call on all ranks is to force sync the save_folder creation between blobFuse and local fs 122 | os.makedirs(save_folder, exist_ok=True) 123 | 124 | self.save_folder = save_folder -------------------------------------------------------------------------------- /trainer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/trainer/utils/__init__.py -------------------------------------------------------------------------------- /trainer/utils/hook.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | _orig_except_hook = None 7 | 8 | 9 | def _global_except_hook(exctype, value, traceback): 10 | """Catches an unhandled exception and call MPI_Abort().""" 11 | try: 12 | if _orig_except_hook: 13 | _orig_except_hook(exctype, value, traceback) 14 | else: 15 | sys.__excepthook__(exctype, value, traceback) 16 | 17 | finally: 18 | import mpi4py.MPI 19 | rank = mpi4py.MPI.COMM_WORLD.Get_rank() 20 | logger.warning("******************************************") 21 | logger.warning("DefaultTrainer:") 22 | logger.warning(f" Uncaught exception on rank {rank}.") 23 | logger.warning(" Calling MPI_Abort() to shut down MPI...") 24 | logger.warning("******************************************") 25 | logging.shutdown() 26 | 27 | try: 28 | import mpi4py.MPI 29 | mpi4py.MPI.COMM_WORLD.Abort(1) 30 | except Exception as e: 31 | # Something is completely broken... 32 | # There's nothing we can do any more 33 | sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n") 34 | sys.stderr.flush() 35 | raise e 36 | 37 | 38 | def add_hook(): 39 | """ 40 | Add a global hook function that captures all unhandled exceptions. 41 | The function calls MPI_Abort() to force all processes abort. 42 | 43 | An MPI runtime is expected to kill all of its child processes 44 | if one of them exits abnormally or without calling `MPI_Finalize()`. 45 | However, when a Python program run on `mpi4py`, the MPI runtime 46 | often fails to detect a process failure, and the rest of the processes 47 | hang infinitely. 48 | 49 | See https://github.com/chainer/chainermn/issues/236 and 50 | https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more 51 | information. 52 | """ 53 | global _orig_except_hook 54 | 55 | if _orig_except_hook is not None: 56 | logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.") 57 | return 58 | 59 | logger.info("Adding global except hook for the distributed job to shutdown MPI if unhandled exception is raised on some of the ranks.") 60 | _orig_except_hook = sys.excepthook 61 | sys.excepthook = _global_except_hook 62 | -------------------------------------------------------------------------------- /trainer/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from typing import Dict 4 | 5 | 6 | class JSONEncoder(json.JSONEncoder): 7 | def default(self, obj): 8 | if isinstance(obj, np.integer): 9 | return int(obj) 10 | elif isinstance(obj, np.floating): 11 | return float(obj) 12 | elif isinstance(obj, np.ndarray): 13 | return obj.tolist() 14 | else: 15 | return super(JSONEncoder, self).default(obj) 16 | 17 | 18 | def is_jsonable(x, json_encoder=None): 19 | try: 20 | json.dumps(x, cls=json_encoder) 21 | return True 22 | except Exception: 23 | return False 24 | 25 | 26 | def filter_jsonable(data: Dict, json_encoder=None) -> Dict: 27 | return {k: v for k, v in data.items() if is_jsonable(k, json_encoder=json_encoder) and is_jsonable(v, json_encoder=json_encoder)} -------------------------------------------------------------------------------- /utils/Config.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode as _CfgNode 2 | 3 | 4 | class CfgNode(_CfgNode): 5 | """ 6 | The same as `fvcore.common.config.CfgNode`, but different in: 7 | 8 | 1. Use unsafe yaml loading by default. 9 | Note that this may lead to arbitrary code execution: you must not 10 | load a config file from untrusted sources before manually inspecting 11 | the content of the file. 12 | 2. Support config versioning. 13 | When attempting to merge an old config, it will convert the old config automatically. 14 | 15 | .. automethod:: clone 16 | .. automethod:: freeze 17 | .. automethod:: defrost 18 | .. automethod:: is_frozen 19 | .. automethod:: load_yaml_with_base 20 | .. automethod:: merge_from_list 21 | .. automethod:: merge_from_other_cfg 22 | """ 23 | 24 | def merge_from_dict(self, dict): 25 | pass 26 | 27 | node = CfgNode() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_engineering import * 2 | from .dataset import * -------------------------------------------------------------------------------- /utils/arguments.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import argparse 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def load_config_dict_to_opt(opt, config_dict): 10 | """ 11 | Load the key, value pairs from config_dict to opt, overriding existing values in opt 12 | if there is any. 13 | """ 14 | if not isinstance(config_dict, dict): 15 | raise TypeError("Config must be a Python dictionary") 16 | for k, v in config_dict.items(): 17 | k_parts = k.split('.') 18 | pointer = opt 19 | for k_part in k_parts[:-1]: 20 | if k_part not in pointer: 21 | pointer[k_part] = {} 22 | pointer = pointer[k_part] 23 | assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict." 24 | ori_value = pointer.get(k_parts[-1]) 25 | pointer[k_parts[-1]] = v 26 | if ori_value: 27 | logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}") 28 | 29 | 30 | def load_opt_from_config_files(conf_files): 31 | """ 32 | Load opt from the config files, settings in later files can override those in previous files. 33 | 34 | Args: 35 | conf_files (list): a list of config file paths 36 | 37 | Returns: 38 | dict: a dictionary of opt settings 39 | """ 40 | opt = {} 41 | for conf_file in conf_files: 42 | with open(conf_file, encoding='utf-8') as f: 43 | config_dict = yaml.safe_load(f) 44 | 45 | load_config_dict_to_opt(opt, config_dict) 46 | 47 | return opt 48 | 49 | 50 | def load_opt_command(args): 51 | parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.') 52 | parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate') 53 | parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).') 54 | parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.') 55 | parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.') 56 | parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER) 57 | 58 | cmdline_args = parser.parse_args() if not args else parser.parse_args(args) 59 | 60 | opt = load_opt_from_config_files(cmdline_args.conf_files) 61 | 62 | if cmdline_args.config_overrides: 63 | config_overrides_string = ' '.join(cmdline_args.config_overrides) 64 | logger.warning(f"Command line config overrides: {config_overrides_string}") 65 | config_dict = json.loads(config_overrides_string) 66 | load_config_dict_to_opt(opt, config_dict) 67 | 68 | if cmdline_args.overrides: 69 | assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments is not paired, required: key value" 70 | keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)] 71 | vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)] 72 | vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals] 73 | 74 | types = [] 75 | for key in keys: 76 | key = key.split('.') 77 | ele = opt.copy() 78 | while len(key) > 0: 79 | ele = ele[key.pop(0)] 80 | types.append(type(ele)) 81 | 82 | config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)} 83 | load_config_dict_to_opt(opt, config_dict) 84 | 85 | # combine cmdline_args into opt dictionary 86 | for key, val in cmdline_args.__dict__.items(): 87 | if val is not None: 88 | opt[key] = val 89 | 90 | return opt, cmdline_args -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | class Entity(object): 3 | def __init__(self, _id, _text, _mask, _interactive, _type, _start_idx, _end_idx, _image=None): 4 | self.id = _id 5 | self.text = _text 6 | self.mask = _mask 7 | self.interactive = _interactive 8 | self.type = _type 9 | self.start_idx = _start_idx 10 | self.end_idx = _end_idx 11 | 12 | self.image = _image -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import pickle 5 | import subprocess 6 | 7 | from mpi4py import MPI 8 | import torch.distributed as dist 9 | 10 | 11 | def apply_distributed(opt): 12 | if opt['rank'] == 0: 13 | hostname_cmd = ["hostname -I"] 14 | result = subprocess.check_output(hostname_cmd, shell=True) 15 | master_address = result.decode('utf-8').split()[0] 16 | master_port = opt['PORT'] 17 | else: 18 | master_address = None 19 | master_port = None 20 | 21 | master_address = MPI.COMM_WORLD.bcast(master_address, root=0) 22 | master_port = MPI.COMM_WORLD.bcast(master_port, root=0) 23 | 24 | if torch.distributed.is_available() and opt['world_size'] > 1: 25 | init_method_url = 'tcp://{}:{}'.format(master_address, master_port) 26 | backend = 'nccl' 27 | world_size = opt['world_size'] 28 | rank = opt['rank'] 29 | torch.distributed.init_process_group(backend=backend, 30 | init_method=init_method_url, 31 | world_size=world_size, 32 | rank=rank) 33 | 34 | def init_distributed(opt): 35 | opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available() 36 | comm = MPI.COMM_WORLD 37 | if 'OMPI_COMM_WORLD_SIZE' not in os.environ and comm.Get_size() == 1: 38 | # application was started without MPI 39 | # default to single node with single process 40 | opt['env_info'] = 'no MPI' 41 | opt['world_size'] = 1 42 | opt['local_size'] = 1 43 | opt['rank'] = 0 44 | opt['local_rank'] = 0 45 | opt['master_address'] = '127.0.0.1' 46 | opt['master_port'] = '8673' 47 | else: 48 | # application was started with MPI 49 | # get MPI parameters 50 | if 'OMPI_COMM_WORLD_SIZE' in os.environ: 51 | opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE']) 52 | opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) 53 | opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK']) 54 | opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 55 | else: 56 | opt['world_size'] = comm.Get_size() 57 | opt['local_size'] = opt['world_size'] 58 | opt['rank'] = comm.Get_rank() 59 | opt['local_rank'] = opt['rank'] 60 | 61 | # set up device 62 | if not opt['CUDA']: 63 | assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend' 64 | opt['device'] = torch.device("cpu") 65 | else: 66 | torch.cuda.set_device(opt['local_rank']) 67 | opt['device'] = torch.device("cuda", opt['local_rank']) 68 | 69 | apply_distributed(opt) 70 | return opt 71 | 72 | def is_main_process(): 73 | rank = 0 74 | if 'OMPI_COMM_WORLD_SIZE' in os.environ: 75 | rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 76 | else: 77 | comm = MPI.COMM_WORLD 78 | if comm.Get_size() > 1: 79 | rank = comm.Get_rank() 80 | return rank == 0 81 | 82 | def get_world_size(): 83 | if not dist.is_available(): 84 | return 1 85 | if not dist.is_initialized(): 86 | return 1 87 | return dist.get_world_size() 88 | 89 | def get_rank(): 90 | if not dist.is_available(): 91 | return 0 92 | if not dist.is_initialized(): 93 | return 0 94 | return dist.get_rank() 95 | 96 | 97 | def synchronize(): 98 | """ 99 | Helper function to synchronize (barrier) among all processes when 100 | using distributed training 101 | """ 102 | if not dist.is_available(): 103 | return 104 | if not dist.is_initialized(): 105 | return 106 | world_size = dist.get_world_size() 107 | rank = dist.get_rank() 108 | if world_size == 1: 109 | return 110 | 111 | def _send_and_wait(r): 112 | if rank == r: 113 | tensor = torch.tensor(0, device="cuda") 114 | else: 115 | tensor = torch.tensor(1, device="cuda") 116 | dist.broadcast(tensor, r) 117 | while tensor.item() == 1: 118 | time.sleep(1) 119 | 120 | _send_and_wait(0) 121 | # now sync on the main process 122 | _send_and_wait(1) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import math 8 | 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value.""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1, decay=0): 23 | self.val = val 24 | if decay: 25 | alpha = math.exp(-n / decay) # exponential decay over 100 updates 26 | self.sum = alpha * self.sum + (1 - alpha) * val * n 27 | self.count = alpha * self.count + (1 - alpha) * n 28 | else: 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import pickle 5 | import torch 6 | import torch.nn as nn 7 | 8 | from utils.distributed import is_main_process 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | NORM_MODULES = [ 14 | torch.nn.BatchNorm1d, 15 | torch.nn.BatchNorm2d, 16 | torch.nn.BatchNorm3d, 17 | torch.nn.SyncBatchNorm, 18 | # NaiveSyncBatchNorm inherits from BatchNorm2d 19 | torch.nn.GroupNorm, 20 | torch.nn.InstanceNorm1d, 21 | torch.nn.InstanceNorm2d, 22 | torch.nn.InstanceNorm3d, 23 | torch.nn.LayerNorm, 24 | torch.nn.LocalResponseNorm, 25 | ] 26 | 27 | def register_norm_module(cls): 28 | NORM_MODULES.append(cls) 29 | return cls 30 | 31 | def align_and_update_state_dicts(model_state_dict, ckpt_state_dict): 32 | model_keys = sorted(model_state_dict.keys()) 33 | ckpt_keys = sorted(ckpt_state_dict.keys()) 34 | result_dicts = {} 35 | matched_log = [] 36 | unmatched_log = [] 37 | unloaded_log = [] 38 | for model_key in model_keys: 39 | model_weight = model_state_dict[model_key] 40 | if model_key in ckpt_keys: 41 | ckpt_weight = ckpt_state_dict[model_key] 42 | if model_weight.shape == ckpt_weight.shape: 43 | result_dicts[model_key] = ckpt_weight 44 | ckpt_keys.pop(ckpt_keys.index(model_key)) 45 | matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 46 | else: 47 | unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 48 | else: 49 | unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape)) 50 | 51 | if is_main_process(): 52 | for info in matched_log: 53 | logger.info(info) 54 | for info in unloaded_log: 55 | logger.warning(info) 56 | for key in ckpt_keys: 57 | logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape)) 58 | for info in unmatched_log: 59 | logger.warning(info) 60 | return result_dicts -------------------------------------------------------------------------------- /utils/prompt_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_llm_prompt_templates(): 4 | prompt_templates_llama = [ 5 | "{}.", 6 | "Describe the concept of {}.", 7 | "An image of {}.", 8 | "Describe {}.", 9 | "Describe visually a {}.", 10 | "Describe an image of {}.", 11 | "Describe visually an image of {}.", 12 | "Describe specifically and visually an image of {}.", 13 | "How to determine visually an {}.", 14 | "Let's describe an image of {}.", 15 | "Let's explain an image of {}.", 16 | "I am seeing an image of {}." 17 | ] 18 | return prompt_templates_llama 19 | 20 | def get_prompt_templates(): 21 | prompt_templates = [ 22 | '{}.', 23 | 'a photo of a {}.', 24 | 'a bad photo of a {}.', 25 | 'a photo of many {}.', 26 | 'a sculpture of a {}.', 27 | 'a photo of the hard to see {}.', 28 | 'a low resolution photo of the {}.', 29 | 'a rendering of a {}.', 30 | 'graffiti of a {}.', 31 | 'a bad photo of the {}.', 32 | 'a cropped photo of the {}.', 33 | 'a tattoo of a {}.', 34 | 'the embroidered {}.', 35 | 'a photo of a hard to see {}.', 36 | 'a bright photo of a {}.', 37 | 'a photo of a clean {}.', 38 | 'a photo of a dirty {}.', 39 | 'a dark photo of the {}.', 40 | 'a drawing of a {}.', 41 | 'a photo of my {}.', 42 | 'the plastic {}.', 43 | 'a photo of the cool {}.', 44 | 'a close-up photo of a {}.', 45 | 'a black and white photo of the {}.', 46 | 'a painting of the {}.', 47 | 'a painting of a {}.', 48 | 'a pixelated photo of the {}.', 49 | 'a sculpture of the {}.', 50 | 'a bright photo of the {}.', 51 | 'a cropped photo of a {}.', 52 | 'a plastic {}.', 53 | 'a photo of the dirty {}.', 54 | 'a jpeg corrupted photo of a {}.', 55 | 'a blurry photo of the {}.', 56 | 'a photo of the {}.', 57 | 'a good photo of the {}.', 58 | 'a rendering of the {}.', 59 | 'a {} in a video game.', 60 | 'a photo of one {}.', 61 | 'a doodle of a {}.', 62 | 'a close-up photo of the {}.', 63 | 'the origami {}.', 64 | 'the {} in a video game.', 65 | 'a sketch of a {}.', 66 | 'a doodle of the {}.', 67 | 'a origami {}.', 68 | 'a low resolution photo of a {}.', 69 | 'the toy {}.', 70 | 'a rendition of the {}.', 71 | 'a photo of the clean {}.', 72 | 'a photo of a large {}.', 73 | 'a rendition of a {}.', 74 | 'a photo of a nice {}.', 75 | 'a photo of a weird {}.', 76 | 'a blurry photo of a {}.', 77 | 'a cartoon {}.', 78 | 'art of a {}.', 79 | 'a sketch of the {}.', 80 | 'a embroidered {}.', 81 | 'a pixelated photo of a {}.', 82 | 'itap of the {}.', 83 | 'a jpeg corrupted photo of the {}.', 84 | 'a good photo of a {}.', 85 | 'a plushie {}.', 86 | 'a photo of the nice {}.', 87 | 'a photo of the small {}.', 88 | 'a photo of the weird {}.', 89 | 'the cartoon {}.', 90 | 'art of the {}.', 91 | 'a drawing of the {}.', 92 | 'a photo of the large {}.', 93 | 'a black and white photo of a {}.', 94 | 'the plushie {}.', 95 | 'a dark photo of a {}.', 96 | 'itap of a {}.', 97 | 'graffiti of the {}.', 98 | 'a toy {}.', 99 | 'itap of my {}.', 100 | 'a photo of a cool {}.', 101 | 'a photo of a small {}.', 102 | 'a tattoo of the {}.', 103 | ] 104 | return prompt_templates 105 | 106 | def prompt_engineering(classnames, topk=1, suffix='.'): 107 | prompt_templates = get_prompt_templates() 108 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 109 | 110 | if isinstance(classnames, list): 111 | classname = random.choice(classnames) 112 | else: 113 | classname = classnames 114 | 115 | return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' ')) 116 | 117 | def prompt_engineering_llm(classnames, topk=1): 118 | prompt_templates = get_llm_prompt_templates() 119 | 120 | if isinstance(classnames, list): 121 | outputs = [] 122 | for cls_name in classnames: 123 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 124 | outputs += [prompt_templates[temp_idx].format(cls_name)] 125 | return outputs 126 | else: 127 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 128 | return prompt_templates[temp_idx].format(classnames) -------------------------------------------------------------------------------- /xy_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/xy_utils/__init__.py -------------------------------------------------------------------------------- /xy_utils/annotation/find_bench_stat.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # entity_val2017.json, entity_val2017_long.json, entity_train2017.json 4 | annot_root = '/nobackup3/xueyan-data/grin_data/coco/annotations/entity_val2017_long.json' 5 | annotations = json.load(open(annot_root, 'r')) 6 | 7 | print("image number: {}".format(len(annotations['images']))) 8 | print("caption number: {}".format(len(annotations['annotations']))) 9 | 10 | entity_count = 0 11 | for annot in annotations['annotations']: 12 | entity_count += len(annot['phrase']) 13 | 14 | print("entity number: {}".format(entity_count)) -------------------------------------------------------------------------------- /xy_utils/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/xy_utils/evaluation/__init__.py -------------------------------------------------------------------------------- /xy_utils/evaluation/compute_grin_visual_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | pth = '/'.join(sys.path[0].split('/')[:-2]) 5 | sys.path.insert(0, pth) 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode 11 | from utils.arguments import load_opt_command 12 | from trainer import XDecoder_Trainer as Trainer 13 | from trainer.utils.misc import move_batch_to_device, cast_batch_to_half 14 | from datasets.evaluation import GroundingEvaluator 15 | from modeling.modules import sem_seg_postprocess 16 | from modeling.language.loss import vl_similarity 17 | from tqdm import tqdm 18 | 19 | from utils.constants import COCO_PANOPTIC_CLASSES 20 | 21 | def main(args=None): 22 | ''' 23 | build args 24 | ''' 25 | opt, cmdline_args = load_opt_command(args) 26 | if cmdline_args.user_dir: 27 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 28 | opt['user_dir'] = absolute_user_dir 29 | 30 | # META DATA 31 | pretrained_pth = opt['RESUME_FROM'] 32 | # hard code interactive token number 33 | opt['DATASETS']['TEST'] = ['grounding_coco_entity_val', 'grounding_coco_entity_val_long'] 34 | 35 | trainer = Trainer(opt) 36 | raw_models = trainer.pipeline.initialize_model() 37 | model = raw_models['default'].from_pretrained(pretrained_pth).eval() 38 | model = model.cuda() 39 | model.model.sem_seg_head.predictor.lang_encoder.activate() 40 | model.model.get_class_embeddings(['default', 'default'], is_eval=True) 41 | 42 | dataset_name = 'grounding_coco_entity_val' 43 | dataloader = trainer.pipeline.get_dataloaders(trainer, dataset_name, is_evaluation=True) 44 | 45 | class_emb_dict = {} 46 | 47 | def inference_visual(entity, extra, _images, height, width): 48 | features = model.model.backbone(_images.tensor) 49 | mask_features, transformer_encoder_features, multi_scale_features = model.model.sem_seg_head.pixel_decoder.forward_features(features) 50 | 51 | extra = {} 52 | extra['spatial_query_pos_mask'] = entity.interactive.to(model.model.device)[None,] 53 | extra['spatial_query_neg_mask'] = entity.interactive.to(model.model.device)[None,].clone().detach() & False 54 | extra['spatial_query_indices'] = torch.arange(1, device=model.model.device)[None,] 55 | outputs = model.model.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=None, extra=extra, task='refimg_spatial') 56 | 57 | pred_sq_emb = outputs['pred_pspatials'] 58 | pred_sp_emb = outputs['pred_smaskembs'] 59 | pred_sc_emb = outputs['pred_spatials'] 60 | 61 | scores = (pred_sp_emb @ pred_sq_emb.transpose(1,2))[0,:,0] 62 | matched_id = scores.max(0)[1] 63 | class_emb = pred_sc_emb[0,matched_id,:] 64 | _outputs = {entity.text.item(): class_emb} 65 | return _outputs 66 | 67 | with torch.no_grad(): 68 | with torch.autocast(device_type='cuda', dtype=torch.float16): 69 | for idx, batched_inputs in enumerate(tqdm(dataloader)): 70 | entities = batched_inputs[0]['entities'] 71 | batched_input = batched_inputs[0] 72 | images = [x["image"].to(model.model.device) for x in batched_inputs] 73 | images = [(x - model.model.pixel_mean) / model.model.pixel_std for x in images] 74 | images = ImageList.from_tensors(images, model.model.size_divisibility) 75 | img_bs = images.tensor.shape[0] 76 | 77 | entity_masks = [] 78 | for entity in entities['entities']: 79 | if entity.type == 'visual': 80 | if len(entity.text) == 0: 81 | continue 82 | class_id = entity.text.item() 83 | if class_id in class_emb_dict: 84 | if len(class_emb_dict[class_id]) >= 30: 85 | continue 86 | else: 87 | class_emb_dict[class_id] = [] 88 | outputs = inference_visual(entity, model, images, batched_input['height'], batched_input['width']) 89 | class_id, class_emb = outputs.popitem() 90 | class_emb_dict[class_id].append(class_emb) 91 | 92 | # class_emb_dict = torch.load('class_emb_dict_focalt.da') 93 | class_embeddings = [] 94 | for i in range(len(class_emb_dict)): 95 | class_embeddings += [torch.stack(class_emb_dict[i], dim=0).mean(0)] 96 | class_embeddings = torch.stack(class_embeddings, dim=0) 97 | torch.save(class_embeddings.cpu(), 'class_embeddings_davitd5.da') 98 | 99 | if __name__ == "__main__": 100 | main() 101 | sys.exit(0) -------------------------------------------------------------------------------- /xy_utils/evaluation/eval_gsam_grounding_entity.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | from pycocotools import mask as coco_mask 4 | 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | 11 | pth = '/'.join(sys.path[0].split('/')[:-1]) 12 | sys.path.insert(0, pth) 13 | 14 | from utils.arguments import load_opt_command 15 | from trainer import XDecoder_Trainer as Trainer 16 | from datasets.evaluation.grounding_evaluation import GroundingEvaluator 17 | from trainer.utils.misc import move_batch_to_device, cast_batch_to_half 18 | 19 | 20 | json_path = "/nobackup3/xueyan-data/code/grin/vlcore_content/entity_val2017_gsam_bh_long.json" 21 | annotations = json.load(open(json_path, 'r')) 22 | 23 | def string_to_rle(rle_string): 24 | """ 25 | Converts a string representation of RLE to a dictionary. 26 | 27 | :param rle_string: RLE string. 28 | :return: RLE dictionary. 29 | """ 30 | try: 31 | rle_dict = ast.literal_eval(rle_string) 32 | if isinstance(rle_dict, dict) and 'counts' in rle_dict and 'size' in rle_dict: 33 | return rle_dict 34 | else: 35 | raise ValueError("String does not represent a valid RLE format.") 36 | except: 37 | raise ValueError("Error in converting string to RLE.") 38 | 39 | def rle_to_mask(rle, height, width): 40 | """ 41 | Converts a RLE (run length encoded) mask to a binary mask. 42 | 43 | :param rle: RLE dictionary. 44 | :param height: Height of the mask. 45 | :param width: Width of the mask. 46 | :return: Binary mask. 47 | """ 48 | if isinstance(rle, dict) and 'counts' in rle and 'size' in rle: 49 | rle = [rle] 50 | else: 51 | raise ValueError("RLE format not recognized.") 52 | 53 | mask_decoded = coco_mask.decode(rle) 54 | return mask_decoded 55 | 56 | def inverse_sigmoid(mask, epsilon=1e-6): 57 | """ 58 | Apply inverse sigmoid (logit) transformation to a mask. 59 | 60 | :param mask: Binary mask. 61 | :param epsilon: Small value to avoid division by zero or log of zero. 62 | :return: Transformed mask. 63 | """ 64 | # Ensure mask values are in the range (0, 1) 65 | mask = np.clip(mask, epsilon, 1 - epsilon) 66 | 67 | # Apply inverse sigmoid (logit) 68 | transformed_mask = np.log(mask / (1 - mask)) 69 | return transformed_mask 70 | 71 | def main(args=None): 72 | opt, cmdline_args = load_opt_command(args) 73 | if cmdline_args.user_dir: 74 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 75 | opt['user_dir'] = absolute_user_dir 76 | 77 | # META DATA 78 | trainer = Trainer(opt) 79 | dataset_name = 'grounding_coco_entity_val_long' 80 | opt['DATASETS']['TEST'] = [dataset_name] 81 | dataloader = trainer.pipeline.get_dataloaders(trainer, dataset_name, is_evaluation=True) 82 | 83 | evaluator = GroundingEvaluator(dataset_name) 84 | evaluator.reset() 85 | 86 | index = 0 87 | for annot, batched_inputs in zip(annotations, dataloader): 88 | batched_inputs = move_batch_to_device(batched_inputs, 'cuda') 89 | height = annot['height'] 90 | width = annot['width'] 91 | processed_results = [] 92 | acc_masks = [] 93 | for phrase in annot['phrase']: 94 | rle = phrase['gsam_output']['mask'] 95 | rle = string_to_rle(rle) 96 | mask = rle_to_mask(rle, height, width) 97 | mask = torch.from_numpy(mask).permute(2,0,1) 98 | acc_masks.append(mask.cuda()) 99 | 100 | processed_results.append({}) 101 | acc_masks = torch.cat(acc_masks, dim=0) 102 | processed_results[-1]['grounding_mask'] = acc_masks 103 | evaluator.process(batched_inputs, processed_results) 104 | index += 1 105 | print(index, len(annotations)) 106 | print(evaluator.evaluate()) 107 | 108 | if __name__ == "__main__": 109 | main() 110 | sys.exit(0) -------------------------------------------------------------------------------- /xy_utils/evaluation/eval_xdecoder_interleave_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | pth = '/'.join(sys.path[0].split('/')[:-2]) 5 | sys.path.insert(0, pth) 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode 11 | from utils.arguments import load_opt_command 12 | from trainer import XDecoder_Trainer as Trainer 13 | from trainer.utils.misc import move_batch_to_device, cast_batch_to_half 14 | from datasets.evaluation import RetrievalEvaluator 15 | from tqdm import tqdm 16 | 17 | def main(args=None): 18 | ''' 19 | build args 20 | ''' 21 | opt, cmdline_args = load_opt_command(args) 22 | if cmdline_args.user_dir: 23 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 24 | opt['user_dir'] = absolute_user_dir 25 | 26 | # META DATA 27 | pretrained_pth = opt['RESUME_FROM'] 28 | database_root = "../../data/output/database" 29 | coco_folders = ["/nobackup3/xueyan-data/grin_data/coco/train2017", "/nobackup3/xueyan-data/grin_data/coco/val2017", database_root] 30 | # paragraph_path = "/nobackup3/xueyan-data/grin_data/coco/annotations/entity_val2017_long.json" 31 | add_image_pths = [] 32 | add_image_id = 0 33 | 34 | # hard code interactive token number 35 | opt['DATASETS']['TEST'] = ['vlp_coco_interleave_val', 'vlp_coco_interleave_val_long'] 36 | 37 | trainer = Trainer(opt) 38 | raw_models = trainer.pipeline.initialize_model() 39 | model = raw_models['default'].from_pretrained(pretrained_pth).eval() 40 | model = model.cuda() 41 | 42 | dataset_name = 'vlp_coco_interleave_val' 43 | dataloader = trainer.pipeline.get_dataloaders(trainer, dataset_name, is_evaluation=True) 44 | 45 | #Remove retrieval evaluator input? 46 | evaluator = RetrievalEvaluator(dataset_name, None) 47 | evaluator.reset() 48 | 49 | with torch.no_grad(): 50 | with torch.autocast(device_type='cuda', dtype=torch.float16): 51 | for idx, batched_inputs in enumerate(tqdm(dataloader)): 52 | processed_results = [] 53 | processed_results.append({}) 54 | assert len(batched_inputs) == 1 55 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["default", "default"], is_eval=True) 56 | 57 | batched_input = batched_inputs[0] 58 | images = [x["image"].to(model.model.device) for x in batched_inputs] 59 | images = [(x - model.model.pixel_mean) / model.model.pixel_std for x in images] 60 | images = ImageList.from_tensors(images, model.model.size_divisibility) 61 | img_bs = images.tensor.shape[0] 62 | 63 | targets = targets_grounding = queries_grounding = None 64 | features = model.model.backbone(images.tensor) 65 | outputs = model.model.sem_seg_head(features, target_queries=queries_grounding) 66 | v_emb_it = outputs['pred_captions'][:,-1] 67 | image_embeds = [v_emb_it] 68 | 69 | lang_results = model.model.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_input['entities']['sentence']) 70 | t_emb_it = lang_results['class_emb'] 71 | caption_ids = [batched_input['image_id']] 72 | 73 | caption_results = { 74 | 'image_embeds': image_embeds, 75 | 'text_embeds': t_emb_it, 76 | 'caption_ids': caption_ids, 77 | 'image_ids': batched_input['image_id'], 78 | } 79 | processed_results[-1]["caption"] = caption_results 80 | evaluator.process(None, processed_results) 81 | 82 | print(f"{dataset_name} Results: {evaluator.evaluate()}") 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | sys.exit(0) -------------------------------------------------------------------------------- /xy_utils/gpt4/generate_class_description.py: -------------------------------------------------------------------------------- 1 | #Note: The openai-python library support for Azure OpenAI is in preview. 2 | import sys 3 | pth = '/'.join(sys.path[0].split('/')[:-1]) 4 | sys.path.insert(0, pth) 5 | 6 | import os 7 | import openai 8 | import time 9 | import functools 10 | import signal 11 | import torch 12 | 13 | from utils.constants import COCO_PANOPTIC_CLASSES 14 | 15 | 16 | def run(class_name, openai_version): 17 | 18 | if openai_version == 0: 19 | openai.api_type = "azure" 20 | openai.api_base = "" 21 | openai.api_version = "2023-03-15-preview" 22 | openai.api_key = os.getenv("OPENAI_API_KEY") 23 | deployment_id = "gpt4" 24 | elif openai_version == 1: 25 | openai.api_type = "azure" 26 | openai.api_base = "" 27 | openai.api_version = "2023-03-15-preview" 28 | openai.api_key = os.getenv("OPENAI_API_KEY") 29 | deployment_id = "gpt4a" 30 | elif openai_version == 2: 31 | openai.api_key = os.getenv("OPENAI_API_KEY_AZURE") 32 | openai.api_base ='' # your endpoint should look like the following https://YOUR_RESOURCE_NAME.openai.azure.com/ 33 | openai.api_type = 'azure' 34 | openai.api_version = '2023-03-15-preview' # this may change in the future 35 | deployment_id='gpt-4-32k-0314' #This will correspond to the custom name you chose for your deployment when you deployed a model. 36 | elif openai_version == 3: 37 | openai.api_base ='' # your endpoint should look like the following https://YOUR_RESOURCE_NAME.openai.azure.com/ 38 | openai.api_type = 'azure' 39 | openai.api_key = os.getenv("OPENAI_API_KEY") 40 | openai.api_version = '2023-07-01-preview' # this may change in the future 41 | deployment_id='gpt-4-32k-0613' #This will correspond to the custom name you chose for your deployment when you deployed a model. 42 | else: 43 | print(openai_version) 44 | assert False 45 | 46 | 47 | content = ''' 48 | Describe {} in a long sentence without any word contains its name. 49 | '''.format(class_name) 50 | 51 | def timeout(seconds, error_message = 'OpenAI call timed out'): 52 | def decorated(func): 53 | def _handle_timeout(signum, frame): 54 | raise TimeoutError(error_message) 55 | 56 | def wrapper(*args, **kwargs): 57 | signal.signal(signal.SIGALRM, _handle_timeout) 58 | signal.alarm(seconds) 59 | try: 60 | result = func(*args, **kwargs) 61 | finally: 62 | signal.alarm(0) 63 | return result 64 | 65 | return functools.wraps(func)(wrapper) 66 | return decorated 67 | 68 | @timeout(300) 69 | def openai_call(deployment_id, prompt): 70 | try: 71 | response = openai.ChatCompletion.create( 72 | engine=deployment_id, 73 | max_tokens=1500, 74 | temperature=0., 75 | messages=[{"role":"user", "content":prompt}]) 76 | return response 77 | except Exception as e: 78 | if 'triggering Azure OpenAI’s content management policy' in str(e): 79 | return 'continue' 80 | else: 81 | raise 82 | 83 | response = openai_call(deployment_id, content) 84 | # print(response) 85 | return response['choices'][0]['message']['content'] 86 | 87 | if __name__ == "__main__": 88 | coco_classes = [x.replace('-other','').replace('-merged','').replace('-stuff','') for x in COCO_PANOPTIC_CLASSES] 89 | coco_description = [] 90 | openai_version = 0 91 | for cnt, class_name in enumerate(coco_classes): 92 | print(cnt) 93 | success = False 94 | while not success: 95 | print('hhihi') 96 | try: 97 | response = run(class_name, openai_version) 98 | coco_description.append(response) 99 | success = True 100 | except: 101 | traceback.print_exc() 102 | openai_version += 1 103 | success = False 104 | 105 | torch.save(coco_description, 'coco_description.pt') -------------------------------------------------------------------------------- /xy_utils/image2html/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UX-Decoder/FIND/521c934a6985b8712ae7cb9a15b3156020c3a797/xy_utils/image2html/__init__.py -------------------------------------------------------------------------------- /xy_utils/image2html/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def writeHTML(file_name, im_paths, captions, height=200, width=200): 4 | f=open(file_name, 'w') 5 | html=[] 6 | f.write('\n') 7 | f.write('\n') 8 | f.write('\n') 9 | for row in range(len(im_paths)): 10 | f.write('\n') 11 | for col in range(len(im_paths[row])): 12 | f.write('') 15 | f.write(' ') 16 | f.write('\n\n') 17 | 18 | f.write('\n') 19 | for col in range(len(im_paths[row])): 20 | f.write('') 23 | f.write(' ') 24 | f.write('\n\n') 25 | f.write('

') 26 | f.write('
') 13 | f.write(captions[row][col]) 14 | f.write('
\n') 27 | f.close() 28 | 29 | def writeSeqHTML(file_name, im_paths, captions, col_n, height=200, width=200): 30 | total_n = len(im_paths) 31 | row_n = int(math.ceil(float(total_n) / col_n)) 32 | f=open(file_name, 'w') 33 | html=[] 34 | f.write('\n') 35 | f.write('\n') 36 | f.write('\n') 37 | for row in range(row_n): 38 | base_count = row * col_n 39 | f.write('\n') 40 | for col in range(col_n): 41 | if base_count + col < total_n: 42 | f.write('') 45 | f.write(' ') 46 | f.write('\n\n') 47 | 48 | f.write('\n') 49 | for col in range(col_n): 50 | if base_count + col < total_n: 51 | f.write('') 54 | f.write(' ') 55 | f.write('\n\n') 56 | f.write('

') 57 | f.write('
') 43 | f.write(captions[base_count + col]) 44 | f.write('
\n') 58 | f.close() --------------------------------------------------------------------------------