├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── __init__.py ├── assets ├── readmes │ ├── DATASET.md │ ├── EVAL.md │ ├── INFERENCE.md │ ├── INSTALL.md │ └── TRAIN.md ├── requirements │ ├── requirements.txt │ └── requirements_custom.txt └── scripts │ └── run_demo.sh ├── configs ├── seem │ ├── davitd3_unicl_lang_v1.yaml │ ├── davitd5_unicl_lang_v1.yaml │ ├── focall_unicl_lang_demo.yaml │ ├── focall_unicl_lang_v0.yaml │ ├── focall_unicl_lang_v1.yaml │ ├── focalt_unicl_lang_demo.yaml │ ├── focalt_unicl_lang_v0.yaml │ ├── focalt_unicl_lang_v1.yaml │ ├── samvitb_unicl_lang_v1.yaml │ └── samvitl_unicl_lang_v1.yaml └── xdecoder │ ├── davitd3_unicl_lang.yaml │ ├── davitd5_unicl_lang.yaml │ ├── focall_unicl_lang.yaml │ ├── focalt_unicl_lang.yaml │ ├── xdecoder_focall_lang.yaml │ └── xdecoder_focalt_lang.yaml ├── datasets ├── __init__.py ├── build.py ├── dataset_mappers │ ├── __init__.py │ ├── bdd_semseg_dataset_mapper.py │ ├── coco_instance_new_baseline_dataset_mapper.py │ ├── coco_panoptic_interactive_dataset_mapper.py │ ├── coco_panoptic_new_baseline_dataset_mapper.py │ ├── imagenet_dataset_mapper.py │ ├── mask_former_instance_dataset_mapper.py │ ├── mask_former_panoptic_dataset_mapper.py │ ├── mask_former_semantic_dataset_mapper.py │ ├── pascalvoc_dataset_mapper_ix.py │ ├── refcoco_dataset_mapper.py │ ├── scannet_dataset_mapper.py │ ├── scannet_pano_dataset_mapper.py │ ├── sunrgbd_dataset_mapper.py │ └── vlp_dataset_mapper.py ├── evaluation │ ├── __init__.py │ ├── captioning_evaluation.py │ ├── classification_evaluation.py │ ├── grounding_evaluation.py │ ├── instance_evaluation.py │ ├── interactive_evaluation.py │ ├── panoptic_evaluation.py │ ├── retrieval_evaluation.py │ └── segmentation_evaluation.py ├── refer.py ├── registration │ ├── __init__.py │ ├── register_ade20k_full.py │ ├── register_ade20k_instance.py │ ├── register_ade20k_panoptic.py │ ├── register_bdd100k_panoseg.py │ ├── register_bdd100k_semseg.py │ ├── register_coco_lvis_panoptic_annos_caption_grounding.py │ ├── register_coco_panoptic_annos_caption.py │ ├── register_coco_panoptic_annos_caption_grounding.py │ ├── register_coco_panoptic_annos_semseg.py │ ├── register_coco_stuff_10k.py │ ├── register_imagenet_cls.py │ ├── register_pascalvoc_eval.py │ ├── register_refcoco_dataset.py │ ├── register_scannet_panoptic.py │ ├── register_scannet_semseg.py │ ├── register_sunrgbd_semseg.py │ └── register_vlp_datasets.py ├── semseg_loader.py ├── utils │ ├── refcoco2json.py │ └── refer.py └── visual_sampler │ ├── __init__.py │ ├── circle.py │ ├── mask_generators.py │ ├── point.py │ ├── polygon.py │ ├── sampler.py │ ├── scribble.py │ └── simpleclick_sampler.py ├── demo ├── __init__.py └── seem │ ├── __init__.py │ ├── app.py │ ├── examples │ ├── corgi1.webp │ ├── corgi2.jpg │ ├── fries1.png │ ├── fries2.png │ ├── minecraft1.jpg │ ├── placeholder.png │ ├── ref_vase.JPG │ ├── river1.png │ ├── river1.wav │ ├── river1_mask.png │ ├── river2.png │ ├── vasedeck.mp4 │ ├── zebras1.jpg │ └── zebras2.jpg │ └── tasks │ ├── __init__.py │ └── interactive.py ├── entry.py ├── inference ├── __init__.py ├── images │ ├── animals.png │ ├── apples.jpg │ ├── coco │ │ ├── 000.jpg │ │ ├── 001.jpg │ │ ├── 002.jpg │ │ └── 003.jpg │ ├── fruit.jpg │ ├── landscape.jpg │ ├── mountain.jpeg │ ├── owls.jpeg │ ├── penguin.jpeg │ ├── region_retrieval.png │ ├── rose.webp │ ├── street.jpg │ └── teaser_new.png └── xdecoder │ ├── infer_captioning.py │ ├── infer_instseg.py │ ├── infer_panoseg.py │ ├── infer_refseg.py │ ├── 
infer_region_retrieval.py │ └── infer_semseg.py ├── modeling ├── BaseModel.py ├── __init__.py ├── architectures │ ├── __init__.py │ ├── build.py │ ├── registry.py │ ├── seem_model_demo.py │ ├── seem_model_v0.py │ ├── seem_model_v1.py │ └── xdecoder_model.py ├── backbone │ ├── __init__.py │ ├── backbone.py │ ├── build.py │ ├── focal.py │ ├── focal_dw.py │ └── registry.py ├── body │ ├── __init__.py │ ├── build.py │ ├── decoder │ │ ├── __init__.py │ │ ├── build.py │ │ ├── modules.py │ │ ├── registry.py │ │ └── xdecoder.py │ ├── encoder │ │ ├── __init__.py │ │ ├── build.py │ │ ├── registry.py │ │ └── transformer_encoder_fpn.py │ ├── registry.py │ ├── transformer_blocks.py │ └── xdecoder_head.py ├── interface │ ├── __init__.py │ ├── build.py │ ├── modules.py │ ├── prototype │ │ ├── __init__.py │ │ ├── attention_data_struct_seemdemo.py │ │ ├── attention_data_struct_seemv0.py │ │ └── attention_data_struct_seemv1.py │ ├── seem_demo.py │ ├── seem_v0.py │ ├── seem_v1.py │ └── xdecoder.py ├── language │ ├── LangEncoder │ │ ├── __init__.py │ │ ├── build.py │ │ ├── registry.py │ │ └── transformer.py │ ├── __init__.py │ ├── build.py │ ├── loss.py │ ├── misc.py │ ├── registry.py │ └── vlpencoder.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── criterion.py │ ├── matcher.py │ ├── point_features.py │ ├── position_encoding.py │ └── postprocessing.py ├── utils │ ├── __init__.py │ ├── attention.py │ ├── box_ops.py │ ├── config.py │ ├── interactive.py │ └── misc.py └── vision │ ├── backbone │ ├── __init__.py │ ├── backbone.py │ ├── build.py │ ├── common.py │ ├── davit.py │ ├── focal.py │ ├── focal_dw.py │ └── vit.py │ └── encoder │ ├── __init__.py │ ├── build.py │ ├── ops │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn_cpu.h │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── ms_deform_attn.h │ │ └── vision.cpp │ └── test.py │ ├── transformer_blocks.py │ ├── transformer_encoder_deform.py │ └── transformer_encoder_fpn.py ├── pipeline ├── XDecoderPipeline.py ├── __init__.py └── utils │ └── misc.py ├── trainer ├── __init__.py ├── default_trainer.py ├── distributed_trainer.py ├── utils │ ├── __init__.py │ ├── hook.py │ ├── misc.py │ ├── mpi_adapter.py │ └── serialization.py ├── utils_trainer.py └── xdecoder_trainer.py └── utils ├── Config.py ├── __init__.py ├── arguments.py ├── constants.py ├── dataset.py ├── distributed.py ├── misc.py ├── model.py ├── prompt_engineering.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ project files 2 | .idea 3 | *.iml 4 | out 5 | gen 6 | 7 | ### Vim template 8 | [._]*.s[a-w][a-z] 9 | [._]s[a-w][a-z] 10 | *.un~ 11 | Session.vim 12 | .netrwhist 13 | *~ 14 | 15 | ### IPythonNotebook template 16 | # Temporary data 17 | .ipynb_checkpoints/ 18 | 19 | ### Python template 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | env/ 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | #lib/ 38 | #lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # 
before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *,cover 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | *.ipynb 80 | *.params 81 | # *.json 82 | .vscode/ 83 | *.code-workspace/ 84 | 85 | lib/pycocotools/_mask.c 86 | lib/nms/cpu_nms.c 87 | 88 | OUTPUT 89 | OUTPUT/* 90 | models/* 91 | DATASET 92 | DATASET/* 93 | external/ 94 | MODELS 95 | MODELS/* 96 | 97 | kill.sh 98 | train.sh 99 | 100 | draws/ 101 | plot/ 102 | 103 | *run.sh 104 | exps/* 105 | amlt/* 106 | 107 | *venv/* 108 | *.pt 109 | *.pth 110 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/__init__.py -------------------------------------------------------------------------------- /assets/readmes/DATASET.md: -------------------------------------------------------------------------------- 1 | # Preparing Dataset 2 | 3 | :bangbang: The dataset preparation contains many details, welcome community contribution to fix any bug, Thanks! 4 | 5 | Our dataloader follows [Detectron2](https://github.com/facebookresearch/detectron2) that contains:
6 | (1) [A dataset registration module](datasets/registration)
7 | (2) [A dataset mapper](datasets/dataset_mappers)
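Both pieces follow the standard Detectron2 pattern: the registration function returns a list of per-image dicts, and the mapper turns one such dict into model-ready tensors. The sketch below is only illustrative — the `my_semseg_train` split name, file paths, and class list are hypothetical and not part of this codebase; the real code lives in [datasets/registration](datasets/registration) and [datasets/dataset_mappers](datasets/dataset_mappers).

```python
# Minimal sketch of the registration + mapper pattern; "my_semseg_train",
# the paths, and the class list below are hypothetical placeholders.
import copy
import numpy as np
import torch
from PIL import Image
from detectron2.data import DatasetCatalog, MetadataCatalog


def load_my_semseg():
    # (1) Registration: return a list of per-image dicts in Detectron2 format.
    return [
        {"file_name": "images/0001.jpg", "sem_seg_file_name": "annotations/0001.png"},
    ]


DatasetCatalog.register("my_semseg_train", load_my_semseg)
MetadataCatalog.get("my_semseg_train").set(stuff_classes=["background", "foreground"])


class MySemSegDatasetMapper:
    # (2) Mapper: turn one registered dict into tensors, mirroring the
    # pattern used by the mappers under datasets/dataset_mappers.
    def __call__(self, dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)
        image = Image.open(dataset_dict["file_name"]).convert("RGB")
        dataset_dict["width"], dataset_dict["height"] = image.size
        dataset_dict["image"] = torch.from_numpy(np.asarray(image).copy()).permute(2, 0, 1)
        semseg = np.asarray(Image.open(dataset_dict["sem_seg_file_name"]))
        dataset_dict["semseg"] = torch.from_numpy(semseg.astype(np.int32))
        return dataset_dict
```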
8 | We modify the dataset registration and mapper for custom datasets. 9 | 10 | ## Training Dataset 11 | We assume all the datasets are stored under: 12 | ``` 13 | .xdecoder_data 14 | ``` 15 | 16 | ### COCO (SEEM & X-Decoder) 17 | 18 | ```sh 19 | # Prepare panoptic_train2017, panoptic_semseg_train2017 exactly the same as [Mask2Fomer](https://github.com/facebookresearch/Mask2Former/tree/main/datasets) 20 | 21 | # (SEEM & X-Decoder) Download additional logistic and custom annotation files to .xdecoder_data/coco/annotations 22 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/caption_class_similarity.pth 23 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/captions_train2017_filtrefgumdval_filtvlp.json 24 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/grounding_train2017_filtrefgumdval_filtvlp.json 25 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/panoptic_train2017_filtrefgumdval_filtvlp.json 26 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/refcocog_umd_val.json 27 | wget https://github.com/peteanderson80/coco-caption/blob/master/annotations/captions_val2014.json 28 | 29 | # (SEEM) Download LVIS annotations for mask preparation 30 | wget https://huggingface.co/xdecoder/SEEM/resolve/main/coco_train2017_filtrefgumdval_lvis.json 31 | ``` 32 | 33 | After dataset preparation, the dataset structure would be: 34 | ``` 35 | .xdecoder_data 36 | └── coco/ 37 | ├── train2017/ 38 | ├── val2017/ 39 | ├── panoptic_train2017/ 40 | ├── panoptic_semseg_train2017/ 41 | ├── panoptic_val2017/ 42 | ├── panoptic_semseg_val2017/ 43 | └── annotations/ 44 | ├── refcocog_umd_val.json 45 | ├── captions_val2014.json 46 | ├── panoptic_val2017.json 47 | ├── caption_class_similarity.pth 48 | ├── panoptic_train2017_filtrefgumdval_filtvlp.json 49 | └── grounding_train2017_filtrefgumdval_filtvlp.json 50 | └── lvis/ 51 | └── coco_train2017_filtrefgumdval_lvis.json 52 | ``` 53 | 54 | #### 4M Image Text Pairs (X-Decoder) 55 | We follow the exact data preparation for the image text pairs data with [ViLT](https://github.com/dandelin/ViLT/blob/master/DATA.md). 56 | ``` 57 | # The pretrained arrow file are put under .xdecoder_data/pretrain_arrows_code224 with the following list of files. 58 | ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] + ["code224_vg.arrow"] + [f"code224_sbu_{i}.arrow" for i in range(9)] + [f"code224_conceptual_caption_train_{i}.arrow" for i in range(31)] 59 | # ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] are originated from ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] with deletion of coco val2017 overlapped images to avoid information leakage. 60 | ``` 61 | 62 | To get quick started: 63 | ```sh 64 | # Download coco karparthy test set (we hack the training data to be coco_caption_karpathy_test.arrow only for quick start in the codebase) 65 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/coco_caption_karpathy_test.arrow 66 | ``` 67 | 68 | After dataset preparation, the dataset structure would be: 69 | ``` 70 | .xdecoder_data 71 | └── pretrain_arrows_code224/ 72 | ├── coco_caption_karpathy_test.arrow 73 | ├── *filtcoco2017val_caption_karpathy_train.arrow 74 | ├── ... 75 | ├── *code224_vg.arrow 76 | ├── *code224_sbu_0.arrow 77 | ├── ... 
78 | ├── *code224_conceptual_caption_train_0.arrow 79 | └── ... 80 | * Those datasets are optional for debugging the pipeline. ! NEED to add back when you are training the model. 81 | ``` 82 | 83 | ***NOTE:*** 84 | 85 | 86 | There are overlap between COCO2017, COCO-Karpathy and REF-COCO dataset, and ref-coco is all overlapped with the COCO2017 training data, we have exclude the refcocog-umd validation, coco-karpathy test split during training. 87 | 88 | ## Evaluation Dataset 89 | 90 | ### RefCOCO (SEEM & X-Decoder) 91 | Please refer to COCO Preparation on [line](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/blob/v1.0/assets/readmes/DATASET.md#coco-seem--x-decoder). 92 | 93 | ### ADE20K, Cityscapes (X-Decoder) 94 | Please Refer to [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/main/datasets). 95 | 96 | ### BDD100K (X-Decoder) 97 | Please download the 10k split of BDD100k at https://doc.bdd100k.com/download.html#id1 98 | 99 | ### PascalVOC and all other interactive evaluation datasets (SEEM) 100 | Please follow the instruction on [RITM](https://github.com/SamsungLabs/ritm_interactive_segmentation) 101 | 102 | After dataset preparation, the dataset structure would be: 103 | ``` 104 | .xdecoder_data 105 | └── PascalVOC/ 106 | ├── Annotations/ 107 | ├── ImageSets 108 | ├── JPEGImages/ 109 | ├── SegmentationClass/ 110 | └── SegmentationObject/ 111 | ``` 112 | 113 | -------------------------------------------------------------------------------- /assets/readmes/INFERENCE.md: -------------------------------------------------------------------------------- 1 | ### Demo 2 | 3 | OpenVocab Semantic Segmentation 4 | ```sh 5 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_semseg.py evaluate \ 6 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \ 7 | --overrides \ 8 | RESUME_FROM /pth/to/xdecoder_focall_best_openseg.pt 9 | ``` 10 | 11 | OpenVocab Instance Segmentation 12 | ```sh 13 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_instseg.py evaluate \ 14 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \ 15 | --overrides \ 16 | RESUME_FROM /pth/to/xdecoder_focall_best_openseg.pt 17 | ``` 18 | 19 | OpenVocab Panoptic Segmentation 20 | ```sh 21 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_panoseg.py evaluate \ 22 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \ 23 | --overrides \ 24 | RESUME_FROM /pth/to/xdecoder_focall_best_openseg.pt 25 | ``` 26 | 27 | OpenVocab Referring Segmentation 28 | ```sh 29 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_refseg.py evaluate \ 30 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \ 31 | --overrides \ 32 | RESUME_FROM /pth/to/xdecoder_focall_last.pt 33 | ``` 34 | 35 | Region Retrieval 36 | ```sh 37 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_region_retrieval.py evaluate \ 38 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \ 39 | --overrides \ 40 | RESUME_FROM /pth/to/xdecoder_focall_last.pt 41 | ``` 42 | 43 | Image Captioning 44 | ```sh 45 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_captioning.py evaluate \ 46 | --conf_files configs/xdecoder/xdecoder_focalt_lang.yaml \ 47 | --overrides \ 48 | RESUME_FROM /pth/to/xdecoder_focalt_last_novg.pt 49 | ``` -------------------------------------------------------------------------------- /assets/readmes/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation Guide 2 | 3 | **General Environment** 4 | * Linux System 5 | 
* CUDA enabled GPU with Memory > 8GB (Evaluation) 6 | * CUDA enabled GPU with Memory > 12GB (Training) 7 | 8 | **Installation** 9 | 10 | ```sh 11 | # Python Package Installation 12 | pip install -r assets/requirements/requirements.txt 13 | pip install -r assets/requirements/requirements_custom.txt 14 | 15 | # Customer Operator [only need training deformable vision encoder] 16 | cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../ 17 | 18 | # System Package [only need for demo in SEEM] 19 | sudo apt update 20 | sudo apt install ffmpeg 21 | ``` 22 | 23 | **Dataset Preparation** 24 | 25 | Please refer to [DATASET.md](assets/readmes/DATASET.md). 26 | 27 | **Evaluation Tool** 28 | ```sh 29 | # save coco_caption.zip to .xdecoder_data 30 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/coco_caption.zip 31 | unzip coco_caption.zip 32 | ``` 33 | 34 | **Environment Variables** 35 | ```sh 36 | export DETECTRON2_DATASETS=/pth/to/xdecoder_data 37 | export DATASET=/pth/to/xdecoder_data 38 | export DATASET2=/pth/to/xdecoder_data 39 | export VLDATASET=/pth/to/xdecoder_data 40 | export PATH=$PATH:/pth/to/xdecoder_data/coco_caption/jre1.8.0_321/bin 41 | export PYTHONPATH=$PYTHONPATH:/pth/to/xdecoder_data/coco_caption 42 | ``` 43 | 44 | **Pretrained Checkpoint** 45 | 46 | X-Decoder: 47 | ```sh 48 | # Focal-T UniCL 49 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focalt_in21k_yfcc_gcc_xdecoder_unicl.pt 50 | 51 | # Focal-L UniCL 52 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focall_vision_focalb_lang_unicl.pt 53 | ``` 54 | 55 | SEEM: 56 | ``` 57 | # Focal-T X-Decoder 58 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_last.pt 59 | 60 | # Focal-L X-Decoder 61 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focall_last_oq101.pt 62 | 63 | # Focal-B UniCL Language 64 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focalb_lang_unicl.pt 65 | 66 | # ViT-B SAM 67 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 68 | 69 | # ViT-L SAM 70 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 71 | 72 | ``` 73 | 74 | -------------------------------------------------------------------------------- /assets/requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.0 2 | torchvision==0.16.0 3 | pillow==9.4.0 4 | opencv-python==4.8.1.78 5 | pyyaml==6.0.1 6 | json_tricks==3.17.3 7 | yacs==0.1.8 8 | scikit-learn==1.3.1 9 | pandas==2.0.3 10 | timm==0.4.12 11 | numpy==1.23.1 12 | einops==0.7.0 13 | fvcore==0.1.5.post20221221 14 | transformers==4.34.0 15 | sentencepiece==0.1.99 16 | ftfy==6.1.1 17 | regex==2023.10.3 18 | nltk==3.8.1 19 | mpi4py==3.1.5 20 | vision-datasets==0.2.2 21 | cython==3.0.2 22 | pycocotools==2.0.7 23 | diffdist==0.1 24 | pyarrow==13.0.0 25 | cityscapesscripts==2.2.2 26 | shapely==1.8.0 27 | scikit-image==0.21.0 28 | mup==1.0.0 29 | accelerate==0.23.0 30 | kornia==0.7.0 31 | deepspeed==0.10.3 32 | wandb==0.15.12 33 | infinibatch==0.1.1 34 | gradio==3.42.0 -------------------------------------------------------------------------------- /assets/requirements/requirements_custom.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/arogozhnikov/einops.git 2 | git+https://github.com/MaureenZOU/detectron2-xyz.git 3 | git+https://github.com/openai/whisper.git 
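One note on the custom operator build step in INSTALL.md above (`cd modeling/vision/encoder/ops && sh make.sh`): the `ops` package also ships a `test.py` next to `setup.py` and `make.sh`. Assuming it follows the standard Deformable-DETR forward/backward check script (an assumption, not something documented in this repo), a quick post-build sanity check on a CUDA machine would be:

```sh
# Hedged sanity check for the deformable-attention build; assumes test.py is the
# standard Deformable-DETR forward/gradient check and that a CUDA GPU is available.
cd modeling/vision/encoder/ops
python test.py        # the checks should all report True, with no CUDA errors
cd ../../../../
```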
-------------------------------------------------------------------------------- /assets/scripts/run_demo.sh: -------------------------------------------------------------------------------- 1 | sudo apt update 2 | sudo apt install ffmpeg 3 | pip install -r assets/requirements/requirements.txt 4 | pip install -r assets/requirements/requirements_custom.txt 5 | python demo/seem/demo.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import registration 2 | from .build import build_train_dataloader, build_eval_dataloader, build_evaluator -------------------------------------------------------------------------------- /datasets/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco_panoptic_interactive_dataset_mapper import COCOPanopticInteractiveDatasetMapper 2 | from .coco_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper 3 | from .coco_panoptic_new_baseline_dataset_mapper import COCOPanopticNewBaselineDatasetMapper 4 | from .mask_former_instance_dataset_mapper import MaskFormerInstanceDatasetMapper 5 | from .mask_former_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapper 6 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper 7 | from .imagenet_dataset_mapper import ImageNetDatasetMapper 8 | from .vlp_dataset_mapper import VLPreDatasetMapper 9 | from .sunrgbd_dataset_mapper import SunRGBDSegDatasetMapper 10 | from .scannet_dataset_mapper import ScanNetSegDatasetMapper 11 | from .bdd_semseg_dataset_mapper import BDDSemDatasetMapper 12 | from .scannet_pano_dataset_mapper import ScanNetPanoDatasetMapper 13 | from .refcoco_dataset_mapper import RefCOCODatasetMapper 14 | from .pascalvoc_dataset_mapper_ix import PascalVOCSegDatasetMapperIX -------------------------------------------------------------------------------- /datasets/dataset_mappers/bdd_semseg_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["BDDSemDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class BDDSemDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. 
Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def read_semseg(self, file_name): 76 | if '.png' in file_name: 77 | semseg = np.asarray(Image.open(file_name)) 78 | elif '.mat' in file_name: 79 | semseg = scipy.io.loadmat(file_name)['LabelMap'] 80 | return semseg 81 | 82 | def __call__(self, dataset_dict): 83 | """ 84 | Args: 85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 86 | 87 | Returns: 88 | dict: a format that builtin models in detectron2 accept 89 | """ 90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 91 | file_name = dataset_dict['file_name'] 92 | semseg_name = dataset_dict['sem_seg_file_name'] 93 | image = Image.open(file_name).convert('RGB') 94 | 95 | dataset_dict['width'] = image.size[0] 96 | dataset_dict['height'] = image.size[1] 97 | 98 | if self.is_train == False: 99 | image = self.transform(image) 100 | image = torch.from_numpy(np.asarray(image).copy()) 101 | image = image.permute(2,0,1) 102 | 103 | semseg = self.read_semseg(semseg_name) 104 | semseg = torch.from_numpy(semseg.astype(np.int32)) 105 | dataset_dict['image'] = image 106 | dataset_dict['semseg'] = semseg 107 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/imagenet_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | from PIL import Image 10 | # import logging 11 | 12 | import cv2 13 | import numpy as np 14 | 15 | import torch 16 | from torchvision import transforms 17 | 18 | from modeling.utils import configurable 19 | 20 | __all__ = ["ImageNetDatasetMapper"] 21 | 22 | 23 | # This is specifically designed for the COCO dataset. 24 | class ImageNetDatasetMapper: 25 | """ 26 | A callable which takes a dataset dict in Detectron2 Dataset format, 27 | and map it into a format used by MaskFormer. 
28 | 29 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 30 | 31 | The callable currently does the following: 32 | 33 | 1. Read the image from "file_name" 34 | 2. Applies geometric transforms to the image and annotation 35 | 3. Find and applies suitable cropping to the image and annotation 36 | 4. Prepare image and annotation to Tensors 37 | """ 38 | 39 | @configurable 40 | def __init__( 41 | self, 42 | is_train=True, 43 | size_train=None, 44 | size_test=None, 45 | size_crop=None, 46 | ): 47 | """ 48 | NOTE: this interface is experimental. 49 | Args: 50 | is_train: for training or inference 51 | augmentations: a list of augmentations or deterministic transforms to apply 52 | tfm_gens: data augmentation 53 | image_format: an image format supported by :func:`detection_utils.read_image`. 54 | """ 55 | self.is_train = is_train 56 | self.size_train = size_train 57 | self.size_test = size_test 58 | self.size_crop = size_crop 59 | 60 | t = [] 61 | t.append(transforms.Resize(size_crop, interpolation=Image.BICUBIC)) 62 | t.append(transforms.CenterCrop(size_test)) 63 | self.transform = transforms.Compose(t) 64 | 65 | @classmethod 66 | def from_config(cls, cfg, is_train=True): 67 | ret = { 68 | "is_train": is_train, 69 | "size_train": cfg['INPUT']['SIZE_TRAIN'], 70 | "size_test": cfg['INPUT']['SIZE_TEST'], 71 | "size_crop": cfg['INPUT']['SIZE_CROP'] 72 | } 73 | return ret 74 | 75 | def __call__(self, dataset_dict): 76 | """ 77 | Args: 78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 79 | 80 | Returns: 81 | dict: a format that builtin models in detectron2 accept 82 | """ 83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 84 | file_name = dataset_dict['file_name'] 85 | image = Image.open(file_name).convert('RGB') 86 | 87 | if self.is_train == False: 88 | image = self.transform(image) 89 | image = torch.from_numpy(np.asarray(image).copy()) 90 | image = image.permute(2,0,1) 91 | 92 | dataset_dict['image'] = image 93 | dataset_dict['height'] = image.shape[1] 94 | dataset_dict['width'] = image.shape[2] 95 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/scannet_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["ScanNetSegDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class ScanNetSegDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. 
Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def read_semseg(self, file_name): 76 | if '.png' in file_name: 77 | semseg = np.asarray(Image.open(file_name)) 78 | elif '.mat' in file_name: 79 | semseg = scipy.io.loadmat(file_name)['LabelMap'] 80 | return semseg 81 | 82 | def __call__(self, dataset_dict): 83 | """ 84 | Args: 85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 86 | 87 | Returns: 88 | dict: a format that builtin models in detectron2 accept 89 | """ 90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 91 | file_name = dataset_dict['file_name'] 92 | semseg_name = dataset_dict['sem_seg_file_name'] 93 | image = Image.open(file_name).convert('RGB') 94 | 95 | dataset_dict['width'] = image.size[0] 96 | dataset_dict['height'] = image.size[1] 97 | 98 | if self.is_train == False: 99 | image = self.transform(image) 100 | image = torch.from_numpy(np.asarray(image).copy()) 101 | image = image.permute(2,0,1) 102 | 103 | semseg = self.read_semseg(semseg_name) 104 | semseg = torch.from_numpy(semseg.astype(np.int32)) 105 | dataset_dict['image'] = image 106 | dataset_dict['semseg'] = semseg 107 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/scannet_pano_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["ScanNetPanoDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class ScanNetPanoDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 
26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def __call__(self, dataset_dict): 76 | """ 77 | Args: 78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 79 | 80 | Returns: 81 | dict: a format that builtin models in detectron2 accept 82 | """ 83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 84 | file_name = dataset_dict['file_name'] 85 | image = Image.open(file_name).convert('RGB') 86 | 87 | dataset_dict['file_name'] = '_'.join(file_name.split('/')[-3:]) # HACK for /tmp file storage on predictions. 88 | dataset_dict['width'] = image.size[0] 89 | dataset_dict['height'] = image.size[1] 90 | 91 | image = self.transform(image) 92 | image = torch.from_numpy(np.asarray(image).copy()) 93 | image = image.permute(2,0,1) 94 | dataset_dict['image'] = image 95 | return dataset_dict -------------------------------------------------------------------------------- /datasets/dataset_mappers/sunrgbd_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | import copy 9 | 10 | import scipy.io 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from modeling.utils import configurable 17 | 18 | __all__ = ["SunRGBDSegDatasetMapper"] 19 | 20 | 21 | # This is specifically designed for the COCO dataset. 22 | class SunRGBDSegDatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by MaskFormer. 26 | 27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 
28 | 29 | The callable currently does the following: 30 | 31 | 1. Read the image from "file_name" 32 | 2. Applies geometric transforms to the image and annotation 33 | 3. Find and applies suitable cropping to the image and annotation 34 | 4. Prepare image and annotation to Tensors 35 | """ 36 | 37 | @configurable 38 | def __init__( 39 | self, 40 | is_train=True, 41 | min_size_test=None, 42 | max_size_test=None, 43 | mean=None, 44 | std=None, 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: for training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | tfm_gens: data augmentation 52 | image_format: an image format supported by :func:`detection_utils.read_image`. 53 | """ 54 | self.is_train = is_train 55 | self.min_size_test = min_size_test 56 | self.max_size_test = max_size_test 57 | self.pixel_mean = torch.tensor(mean)[:,None,None] 58 | self.pixel_std = torch.tensor(std)[:,None,None] 59 | 60 | t = [] 61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC)) 62 | self.transform = transforms.Compose(t) 63 | 64 | @classmethod 65 | def from_config(cls, cfg, is_train=True): 66 | ret = { 67 | "is_train": is_train, 68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'], 69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'], 70 | "mean": cfg['INPUT']['PIXEL_MEAN'], 71 | "std": cfg['INPUT']['PIXEL_STD'], 72 | } 73 | return ret 74 | 75 | def read_semseg(self, file_name): 76 | if '.png' in file_name: 77 | semseg = np.asarray(Image.open(file_name)) 78 | elif '.mat' in file_name: 79 | semseg = scipy.io.loadmat(file_name)['LabelMap'] 80 | return semseg 81 | 82 | def __call__(self, dataset_dict): 83 | """ 84 | Args: 85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 86 | 87 | Returns: 88 | dict: a format that builtin models in detectron2 accept 89 | """ 90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 91 | file_name = dataset_dict['file_name'] 92 | semseg_name = dataset_dict['sem_seg_file_name'] 93 | image = Image.open(file_name).convert('RGB') 94 | 95 | dataset_dict['width'] = image.size[0] 96 | dataset_dict['height'] = image.size[1] 97 | 98 | if self.is_train == False: 99 | image = self.transform(image) 100 | image = torch.from_numpy(np.asarray(image).copy()) 101 | image = image.permute(2,0,1) 102 | 103 | semseg = self.read_semseg(semseg_name) 104 | semseg = torch.from_numpy(semseg.astype(np.int32)) 105 | dataset_dict['image'] = image 106 | dataset_dict['semseg'] = semseg 107 | return dataset_dict -------------------------------------------------------------------------------- /datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .instance_evaluation import * 2 | from .classification_evaluation import * 3 | from .segmentation_evaluation import * 4 | from .retrieval_evaluation import * 5 | from .captioning_evaluation import * 6 | from .panoptic_evaluation import * 7 | from .grounding_evaluation import * 8 | from .interactive_evaluation import * -------------------------------------------------------------------------------- /datasets/evaluation/classification_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import logging 11 | 12 | from detectron2.evaluation.evaluator import DatasetEvaluator 13 | 14 | from utils.misc import AverageMeter 15 | from utils.distributed import get_world_size 16 | 17 | 18 | @torch.no_grad() 19 | def accuracy(output, target, topk=(1,)): 20 | """Computes the precision@k for the specified values of k""" 21 | if isinstance(output, list): 22 | output = output[-1] 23 | 24 | n_classes = output.size()[1] 25 | maxk = min(max(topk), n_classes) 26 | batch_size = target.size(0) 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 34 | res.append(correct_k.mul_(100.0 / batch_size).item()) 35 | return res 36 | 37 | class ClassificationEvaluator(DatasetEvaluator): 38 | def __init__(self, *args): 39 | self.top1 = AverageMeter() 40 | self.top5 = AverageMeter() 41 | self._logger = logging.getLogger(__name__) 42 | 43 | def reset(self): 44 | self.top1.reset() 45 | self.top5.reset() 46 | 47 | def process(self, inputs, outputs): 48 | logits = torch.stack([o['pred_class'] for o in outputs]) 49 | y = torch.tensor([t['class_id'] for t in inputs], device=logits.device) 50 | prec1, prec5 = accuracy(logits, y, (1, 5)) 51 | self.top1.update(prec1, y.size(0)) 52 | self.top5.update(prec5, y.size(0)) 53 | 54 | def evaluate(self): 55 | if get_world_size() > 1: 56 | tmp_tensor = torch.tensor( 57 | [self.top1.sum, self.top5.sum, self.top1.count], 58 | device=torch.cuda.current_device() 59 | ) 60 | torch.distributed.all_reduce( 61 | tmp_tensor, torch.distributed.ReduceOp.SUM 62 | ) 63 | top1_sum, top5_sum, count = tmp_tensor.tolist() 64 | else: 65 | top1_sum = self.top1.sum 66 | top5_sum = self.top5.sum 67 | count = self.top1.count 68 | 69 | results = {} 70 | scores = { 71 | 'top1': top1_sum / count, 72 | "top5": top5_sum / count 73 | } 74 | results['class'] = scores 75 | self._logger.info(results) 76 | return results 77 | -------------------------------------------------------------------------------- /datasets/evaluation/grounding_evaluation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import logging 8 | import torch 9 | from torchvision.ops import box_iou 10 | 11 | from detectron2.structures import BoxMode 12 | from detectron2.data import MetadataCatalog 13 | from detectron2.utils.comm import all_gather, is_main_process, synchronize 14 | from detectron2.evaluation.evaluator import DatasetEvaluator 15 | 16 | 17 | class GroundingEvaluator(DatasetEvaluator): 18 | """ 19 | Evaluate grounding segmentation metrics. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | dataset_name, 25 | compute_box=False, 26 | distributed=True, 27 | ): 28 | self._logger = logging.getLogger(__name__) 29 | self._dataset_name = dataset_name 30 | self._distributed = distributed 31 | self._cpu_device = torch.device("cpu") 32 | self._compute_box = compute_box 33 | meta = MetadataCatalog.get(dataset_name) 34 | 35 | def reset(self): 36 | self.cum_I = 0 37 | self.cum_U = 0 38 | self.mIoU = 0 39 | self.eval_seg_iou_list = [.5, .6, .7, .8, .9] 40 | self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device) 41 | self.seg_total = 0 42 | if self._compute_box: 43 | self.mIoU_box = 0 44 | self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device) 45 | 46 | @staticmethod 47 | def computeIoU(pred_seg, gd_seg): 48 | I = (pred_seg & gd_seg) 49 | U = (pred_seg | gd_seg) 50 | return I, U 51 | 52 | def process(self, inputs, outputs): 53 | for input, output in zip(inputs, outputs): 54 | pred = output['grounding_mask'].sigmoid() > 0.5 55 | gt = input['groundings']['masks'].bool() 56 | bsi = len(pred) 57 | I, U = self.computeIoU(pred, gt) 58 | self.cum_I += I.sum().cpu() 59 | self.cum_U += U.sum().cpu() 60 | IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6) 61 | self.mIoU += IoU.sum().cpu() 62 | 63 | if self._compute_box: 64 | pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) 65 | gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu() 66 | IoU_box = box_iou(pred_box, gt_box).diagonal() 67 | self.mIoU_box += IoU_box.sum() 68 | 69 | for idx in range(len(self.eval_seg_iou_list)): 70 | eval_seg_iou = self.eval_seg_iou_list[idx] 71 | self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu() 72 | if self._compute_box: 73 | self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu() 74 | self.seg_total += bsi 75 | 76 | def evaluate(self): 77 | if self._distributed: 78 | synchronize() 79 | self.cum_I = torch.stack(all_gather(self.cum_I)).sum() 80 | self.cum_U = torch.stack(all_gather(self.cum_U)).sum() 81 | self.mIoU = torch.stack(all_gather(self.mIoU)).sum() 82 | self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0) 83 | self.seg_total = sum(all_gather(self.seg_total)) 84 | 85 | if self._compute_box: 86 | self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum() 87 | self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0) 88 | if not is_main_process(): 89 | return 90 | 91 | results = {} 92 | for idx in range(len(self.eval_seg_iou_list)): 93 | result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx]) 94 | results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item() 95 | results['cIoU'] = (self.cum_I*100./self.cum_U).item() 96 | results['mIoU'] = (self.mIoU*100./self.seg_total).item() 97 | 98 | if self._compute_box: 99 | for idx in range(len(self.eval_seg_iou_list)): 100 | result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx]) 101 | results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item() 102 | results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item() 103 | 104 | self._logger.info(results) 105 | return {'grounding': results} -------------------------------------------------------------------------------- /datasets/evaluation/instance_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import contextlib 3 | import copy 4 | import io 5 | import itertools 6 | import json 7 | import logging 8 | import numpy as np 9 | import os 10 | import pickle 11 | from collections import OrderedDict 12 | import pycocotools.mask as mask_util 13 | import torch 14 | from pycocotools.coco import COCO 15 | from pycocotools.cocoeval import COCOeval 16 | from tabulate import tabulate 17 | 18 | import detectron2.utils.comm as comm 19 | from detectron2.config import CfgNode 20 | from detectron2.data import MetadataCatalog 21 | from detectron2.data.datasets.coco import convert_to_coco_json 22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco 23 | from detectron2.evaluation.fast_eval_api import COCOeval_opt 24 | from detectron2.structures import Boxes, BoxMode, pairwise_iou 25 | from detectron2.utils.file_io import PathManager 26 | from detectron2.utils.logger import create_small_table 27 | 28 | 29 | # modified from COCOEvaluator for instance segmetnat 30 | class InstanceSegEvaluator(COCOEvaluator): 31 | """ 32 | Evaluate AR for object proposals, AP for instance detection/segmentation, AP 33 | for keypoint detection outputs using COCO's metrics. 34 | See http://cocodataset.org/#detection-eval and 35 | http://cocodataset.org/#keypoints-eval to understand its metrics. 36 | The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means 37 | the metric cannot be computed (e.g. due to no predictions made). 38 | 39 | In addition to COCO, this evaluator is able to support any bounding box detection, 40 | instance segmentation, or keypoint detection dataset. 41 | """ 42 | 43 | def _eval_predictions(self, predictions, img_ids=None): 44 | """ 45 | Evaluate predictions. Fill self._results with the metrics of the tasks. 46 | """ 47 | self._logger.info("Preparing results for COCO format ...") 48 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) 49 | tasks = self._tasks or self._tasks_from_predictions(coco_results) 50 | 51 | # unmap the category ids for COCO 52 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): 53 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id 54 | # all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) 55 | # num_classes = len(all_contiguous_ids) 56 | # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 57 | 58 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} 59 | for result in coco_results: 60 | category_id = result["category_id"] 61 | # assert category_id < num_classes, ( 62 | # f"A prediction has class={category_id}, " 63 | # f"but the dataset only has {num_classes} classes and " 64 | # f"predicted class id should be in [0, {num_classes - 1}]." 65 | # ) 66 | assert category_id in reverse_id_mapping, ( 67 | f"A prediction has class={category_id}, " 68 | f"but the dataset only has class ids in {dataset_id_to_contiguous_id}." 
69 | ) 70 | result["category_id"] = reverse_id_mapping[category_id] 71 | 72 | if self._output_dir: 73 | file_path = os.path.join(self._output_dir, "coco_instances_results.json") 74 | self._logger.info("Saving results to {}".format(file_path)) 75 | with PathManager.open(file_path, "w") as f: 76 | f.write(json.dumps(coco_results)) 77 | f.flush() 78 | 79 | if not self._do_evaluation: 80 | self._logger.info("Annotations are not available for evaluation.") 81 | return 82 | 83 | self._logger.info( 84 | "Evaluating predictions with {} COCO API...".format( 85 | "unofficial" if self._use_fast_impl else "official" 86 | ) 87 | ) 88 | for task in sorted(tasks): 89 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" 90 | coco_eval = ( 91 | _evaluate_predictions_on_coco( 92 | self._coco_api, 93 | coco_results, 94 | task, 95 | kpt_oks_sigmas=self._kpt_oks_sigmas, 96 | use_fast_impl=self._use_fast_impl, 97 | img_ids=img_ids, 98 | max_dets_per_image=self._max_dets_per_image, 99 | ) 100 | if len(coco_results) > 0 101 | else None # cocoapi does not handle empty results very well 102 | ) 103 | 104 | res = self._derive_coco_results( 105 | coco_eval, task, class_names=self._metadata.get("thing_classes") 106 | ) 107 | self._results[task] = res 108 | -------------------------------------------------------------------------------- /datasets/evaluation/interactive_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision.ops import box_iou 8 | 9 | from detectron2.structures import BoxMode 10 | from detectron2.data import MetadataCatalog 11 | from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize 12 | from detectron2.evaluation.evaluator import DatasetEvaluator 13 | 14 | 15 | class InteractiveEvaluator(DatasetEvaluator): 16 | """ 17 | Evaluate point interactive IoU metrics. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | dataset_name, 23 | output_dir, 24 | max_clicks=20, 25 | iou_iter=1, 26 | compute_box=False, 27 | distributed=True, 28 | ): 29 | self._logger = logging.getLogger(__name__) 30 | self._dataset_name = dataset_name 31 | self._distributed = distributed 32 | self._cpu_device = torch.device("cpu") 33 | self._output_dir = output_dir 34 | 35 | self.max_clicks = max_clicks 36 | self.iou_iter = iou_iter 37 | meta = MetadataCatalog.get(dataset_name) 38 | 39 | def reset(self): 40 | self.iou_list = [] 41 | self.num_samples = 0 42 | self.all_ious = [0.5, 0.8, 0.85, 0.9] 43 | 44 | def process(self, inputs, outputs): 45 | self.iou_list += [o['mask_iou'] for o in outputs] 46 | self.num_samples += len(outputs) 47 | 48 | def compute_noc(self): 49 | def _get_noc(iou_arr, iou_thr): 50 | vals = iou_arr >= iou_thr 51 | return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks 52 | 53 | noc_list = {} 54 | for iou_thr in self.all_ious: 55 | scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list] 56 | noc_list[str(iou_thr)] = scores_arr 57 | 58 | iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1] 59 | noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()} 60 | 61 | if self._distributed: 62 | num_samples = sum(all_gather(self.num_samples)) 63 | noc_list_sum_gather = all_gather(noc_list_sum) 64 | iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu()) 65 | 66 | noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]} 67 | for nlg in noc_list_sum_gather: 68 | for key, value in nlg.items(): 69 | noc_list_sum[key] += value 70 | 71 | pred_noc = {} 72 | if self._distributed and (not is_main_process()): 73 | return pred_noc 74 | 75 | for key, value in noc_list_sum.items(): 76 | pred_noc[key] = value / num_samples 77 | 78 | pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples 79 | return pred_noc 80 | 81 | def evaluate(self): 82 | pred_noc = self.compute_noc() 83 | 84 | if self._distributed and (not is_main_process()): 85 | return 86 | 87 | def draw_iou_curve(iou_list, save_dir): 88 | iou_list = torch.stack(iou_list, dim=0) 89 | iou_list = iou_list.mean(dim=0).cpu().numpy() 90 | # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib 91 | import matplotlib.pyplot as plt 92 | plt.figure() 93 | plt.plot(range(1, self.max_clicks+1), iou_list) 94 | plt.xlabel('Number of clicks') 95 | plt.ylabel('IoU') 96 | 97 | 98 | # create directory if not exist 99 | import os 100 | output_dir = os.path.join(save_dir, 'iou_by_clicks') 101 | if not os.path.exists(output_dir): 102 | os.makedirs(output_dir) 103 | 104 | # get current time and format in 10 digits 105 | import time 106 | current_time = time.time() 107 | current_time = int(current_time) 108 | current_time = str(current_time) 109 | 110 | # save iou curve 111 | plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time))) 112 | 113 | draw_iou_curve(self.iou_list, self._output_dir) 114 | results = {} 115 | for idx in range(len(self.all_ious)): 116 | result_str = 'noc@{}'.format(self.all_ious[idx]) 117 | results[result_str] = pred_noc[str(self.all_ious[idx])] 118 | 119 | results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter'] 120 | 121 | self._logger.info(results) 122 | return {'interactive': results} -------------------------------------------------------------------------------- /datasets/registration/__init__.py: 
-------------------------------------------------------------------------------- 1 | from . import ( 2 | register_refcoco_dataset, 3 | register_ade20k_full, 4 | register_ade20k_panoptic, 5 | register_coco_stuff_10k, 6 | register_coco_panoptic_annos_semseg, 7 | register_coco_panoptic_annos_caption, 8 | register_coco_panoptic_annos_caption_grounding, 9 | register_coco_lvis_panoptic_annos_caption_grounding, 10 | register_ade20k_instance, 11 | register_vlp_datasets, 12 | register_sunrgbd_semseg, 13 | register_scannet_semseg, 14 | register_bdd100k_semseg, 15 | register_scannet_panoptic, 16 | register_bdd100k_panoseg, 17 | register_pascalvoc_eval, 18 | ) -------------------------------------------------------------------------------- /datasets/registration/register_ade20k_instance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import json 3 | import logging 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | 8 | from detectron2.data import DatasetCatalog, MetadataCatalog 9 | from detectron2.data.datasets.coco import load_coco_json, register_coco_instances 10 | from detectron2.utils.file_io import PathManager 11 | 12 | ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, 
{'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}] 13 | 14 | 15 | _PREDEFINED_SPLITS = { 16 | # point annotations without masks 17 | "ade20k_instance_train": ( 18 | "ADEChallengeData2016/images/training", 19 | "ADEChallengeData2016/ade20k_instance_train.json", 20 | ), 21 | "ade20k_instance_val": ( 22 | "ADEChallengeData2016/images/validation", 23 | "ADEChallengeData2016/ade20k_instance_val.json", 24 | ), 25 | } 26 | 27 | 28 | def _get_ade_instances_meta(): 29 | thing_ids = [k["id"] for k in ADE_CATEGORIES] 30 | assert len(thing_ids) == 100, len(thing_ids) 31 | # Mapping from the incontiguous ADE category id to an id in [0, 99] 32 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 33 | thing_classes = [k["name"] for k in ADE_CATEGORIES] 34 | ret = { 35 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 36 | "thing_classes": thing_classes, 37 | } 38 | return ret 39 | 40 | 41 | def register_all_ade20k_instance(root): 42 | for key, (image_root, json_file) in _PREDEFINED_SPLITS.items(): 43 | # Assume pre-defined datasets live in `./datasets`. 44 | register_coco_instances( 45 | key, 46 | _get_ade_instances_meta(), 47 | os.path.join(root, json_file) if "://" not in json_file else json_file, 48 | os.path.join(root, image_root), 49 | ) 50 | 51 | 52 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 53 | register_all_ade20k_instance(_root) 54 | -------------------------------------------------------------------------------- /datasets/registration/register_bdd100k_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import BDD_SEM 17 | 18 | __all__ = ["load_scannet_instances", "register_scannet_context"] 19 | 20 | 21 | def load_bdd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load BDD annotations to Detectron2 format. 
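Images are read from <dirname>/images/10k/<split> and paired, in sorted filename order, with semantic masks from <dirname>/labels/sem_seg/masks/<split>.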
24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | img_folder = os.path.join(dirname, 'images', '10k', split) 31 | img_pths = sorted(glob.glob(os.path.join(img_folder, '*.jpg'))) 32 | 33 | sem_folder = os.path.join(dirname, 'labels', 'sem_seg', 'masks', split) 34 | sem_pths = sorted(glob.glob(os.path.join(sem_folder, '*.png'))) 35 | 36 | assert len(img_pths) == len(sem_pths) 37 | 38 | dicts = [] 39 | for img_pth, sem_pth in zip(img_pths, sem_pths): 40 | r = { 41 | "file_name": img_pth, 42 | "sem_seg_file_name": sem_pth, 43 | "image_id": img_pth.split('/')[-1].split('.')[0], 44 | } 45 | dicts.append(r) 46 | return dicts 47 | 48 | 49 | def register_bdd_context(name, dirname, split, class_names=BDD_SEM): 50 | DatasetCatalog.register(name, lambda: load_bdd_instances(name, dirname, split, class_names)) 51 | MetadataCatalog.get(name).set( 52 | stuff_classes=class_names, 53 | dirname=dirname, 54 | split=split, 55 | ignore_label=[255], 56 | thing_dataset_id_to_contiguous_id={}, 57 | class_offset=0, 58 | keep_sem_bgd=False 59 | ) 60 | 61 | 62 | def register_all_sunrgbd_seg(root): 63 | SPLITS = [ 64 | ("bdd10k_val_sem_seg", "bdd100k", "val"), 65 | ] 66 | 67 | for name, dirname, split in SPLITS: 68 | register_bdd_context(name, os.path.join(root, dirname), split) 69 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 70 | 71 | 72 | _root = os.getenv("DATASET", "datasets") 73 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_imagenet_cls.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import glob 10 | from typing import List, Tuple, Union 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import IMAGENET_CLASSES, IMAGENET_FOLDER_NAMES 17 | 18 | __all__ = ["load_imagenet_images", "register_imagenet"] 19 | 20 | 21 | def load_imagenet_images(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load ImageNet annotations to Detectron2 format. 
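Scans the synset folders <dirname>/<split>/n* and emits one record per JPEG, with the class name and class id looked up from IMAGENET_CLASSES / IMAGENET_FOLDER_NAMES.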
24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | image_folders = sorted(glob.glob(os.path.join(dirname, split, 'n*'))) 31 | 32 | dicts = [] 33 | for image_folder in image_folders: 34 | folder_name = image_folder.split('/')[-1] 35 | image_pths = sorted(glob.glob(os.path.join(image_folder, "*.JPEG"))) 36 | for img_pth in image_pths: 37 | r = { 38 | "file_name": img_pth, 39 | "class_name": IMAGENET_CLASSES[IMAGENET_FOLDER_NAMES.index(folder_name)], 40 | "class_id": IMAGENET_FOLDER_NAMES.index(folder_name), 41 | } 42 | dicts.append(r) 43 | return dicts 44 | 45 | 46 | def register_imagenet(name, dirname, split, year, class_names=IMAGENET_CLASSES): 47 | DatasetCatalog.register(name, lambda: load_imagenet_images(dirname, split, class_names)) 48 | MetadataCatalog.get(name).set( 49 | thing_classes=list(class_names), dirname=dirname, year=year, split=split 50 | ) 51 | 52 | 53 | def register_all_imagenet(root): 54 | SPLITS = [ 55 | ("imagenet_val", "imagenet", "val", "2012"), 56 | ] 57 | for name, dirname, split, year in SPLITS: 58 | register_imagenet(name, os.path.join(root, dirname), split, year) 59 | MetadataCatalog.get(name).evaluator_type = "classification" 60 | 61 | 62 | _root = os.getenv("DATASET", "datasets") 63 | register_all_imagenet(_root) -------------------------------------------------------------------------------- /datasets/registration/register_pascalvoc_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | import os 4 | import glob 5 | from typing import List, Tuple, Union 6 | import xml.etree.ElementTree as ET 7 | 8 | import cv2 9 | import numpy as np 10 | from scipy.io import loadmat 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | 17 | __all__ = ["load_pascalvoc_instances", "register_pascalvoc_context"] 18 | 19 | def get_labels_with_sizes(x): 20 | obj_sizes = np.bincount(x.flatten()) 21 | labels = np.nonzero(obj_sizes)[0].tolist() 22 | labels = [x for x in labels if x != 0] 23 | return labels, obj_sizes[labels].tolist() 24 | 25 | def load_pascalvoc_instances(name: str, dirname: str, mode: str, split: str): 26 | """ 27 | Load Pascal VOC detection annotations to Detectron2 format. 
28 | 29 | Args: 30 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 31 | split (str): one of "train", "test", "val", "trainval" 32 | class_names: list or tuple of class names 33 | """ 34 | with PathManager.open(os.path.join(dirname, 'ImageSets', 'Segmentation', split + ".txt")) as f: 35 | fileids = np.loadtxt(f, dtype=np.str) 36 | 37 | dicts = [] 38 | for field in fileids: 39 | anno_path = os.path.join(dirname, "Annotations", "{}.xml".format(field)) 40 | image_path = os.path.join(dirname, "JPEGImages", "{}.jpg".format(field)) 41 | inst_path = os.path.join(dirname, "SegmentationObject", "{}.png".format(field)) 42 | semseg_path = os.path.join(dirname, "SegmentationClass", "{}.png".format(field)) 43 | 44 | instances_mask = cv2.imread(inst_path) 45 | instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) 46 | 47 | objects_ids = np.unique(instances_mask) 48 | objects_ids = [x for x in objects_ids if x != 0 and x != 220] 49 | 50 | slice_size = 5 51 | for i in range(0, len(objects_ids), slice_size): 52 | r = { 53 | "file_name": image_path, 54 | "inst_name": inst_path, 55 | "semseg_name": semseg_path, 56 | "objects_ids": objects_ids[i:i+slice_size], 57 | } 58 | dicts.append(r) 59 | return dicts 60 | 61 | def register_pascalvoc_context(name, dirname, mode, split): 62 | DatasetCatalog.register("{}_{}".format(name, mode), lambda: load_pascalvoc_instances(name, dirname, mode, split)) 63 | MetadataCatalog.get("{}_{}".format(name, mode)).set( 64 | dirname=dirname, 65 | thing_dataset_id_to_contiguous_id={}, 66 | ) 67 | 68 | def register_all_sbd(root): 69 | SPLITS = [ 70 | ("pascalvoc_val", "PascalVOC", "Point", "val"), 71 | ("pascalvoc_val", "PascalVOC", "Scribble", "val"), 72 | ("pascalvoc_val", "PascalVOC", "Polygon", "val"), 73 | ("pascalvoc_val", "PascalVOC", "Circle", "val"), 74 | ("pascalvoc_val", "PascalVOC", "Box", "val"), 75 | ] 76 | 77 | for name, dirname, mode, split in SPLITS: 78 | register_pascalvoc_context(name, os.path.join(root, dirname), mode, split) 79 | MetadataCatalog.get("{}_{}".format(name, mode)).evaluator_type = "interactive" 80 | 81 | _root = os.getenv("DATASET", "datasets") 82 | register_all_sbd(_root) -------------------------------------------------------------------------------- /datasets/registration/register_refcoco_dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import json 8 | import os 9 | import collections 10 | 11 | from detectron2.data import DatasetCatalog, MetadataCatalog 12 | from detectron2.data.datasets import load_sem_seg 13 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 14 | from detectron2.utils.file_io import PathManager 15 | 16 | 17 | _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = { 18 | # "refcocog_train_umd": ( 19 | # "coco/train2017", # image_root 20 | # "coco/annotations/refcocog_umd_train.json", # annot_root 21 | # ), 22 | # "refcocog_val_google": ( 23 | # "coco/train2017", # image_root 24 | # "coco/annotations/refcocog_google.json", # annot_root 25 | # ), 26 | # "refcocop_val_unc": ( 27 | # "coco/train2017", # image_root 28 | # "coco/annotations/refcocop_unc.json", # annot_root 29 | # ), 30 | # 
"refcoco_val_unc": ( 31 | # "coco/train2017", # image_root 32 | # "coco/annotations/refcoco_unc.json", # annot_root 33 | # ), 34 | "refcocog_val_umd": ( 35 | "coco/train2017", # image_root 36 | "coco/annotations/refcocog_umd_val.json", # annot_root 37 | ), 38 | } 39 | 40 | 41 | def get_metadata(): 42 | meta = {} 43 | return meta 44 | 45 | 46 | def load_refcoco_json(image_root, annot_json, metadata): 47 | """ 48 | Args: 49 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". 50 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". 51 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". 52 | Returns: 53 | list[dict]: a list of dicts in Detectron2 standard format. (See 54 | `Using Custom Datasets `_ ) 55 | """ 56 | 57 | with PathManager.open(annot_json) as f: 58 | json_info = json.load(f) 59 | 60 | # build dictionary for grounding 61 | grd_dict = collections.defaultdict(list) 62 | for grd_ann in json_info['annotations']: 63 | image_id = int(grd_ann["image_id"]) 64 | grd_dict[image_id].append(grd_ann) 65 | 66 | ret = [] 67 | for image in json_info["images"]: 68 | image_id = int(image["id"]) 69 | image_file = os.path.join(image_root, image['file_name']) 70 | grounding_anno = grd_dict[image_id] 71 | ret.append( 72 | { 73 | "file_name": image_file, 74 | "image_id": image_id, 75 | "grounding_info": grounding_anno, 76 | } 77 | ) 78 | assert len(ret), f"No images found in {image_root}!" 79 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] 80 | return ret 81 | 82 | 83 | def register_refcoco( 84 | name, metadata, image_root, annot_json): 85 | DatasetCatalog.register( 86 | name, 87 | lambda: load_refcoco_json(image_root, annot_json, metadata), 88 | ) 89 | MetadataCatalog.get(name).set( 90 | image_root=image_root, 91 | json_file=annot_json, 92 | evaluator_type="grounding_refcoco", 93 | ignore_label=255, 94 | label_divisor=1000, 95 | **metadata, 96 | ) 97 | 98 | 99 | def register_all_refcoco(root): 100 | for ( 101 | prefix, 102 | (image_root, annot_root), 103 | ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items(): 104 | register_refcoco( 105 | prefix, 106 | get_metadata(), 107 | os.path.join(root, image_root), 108 | os.path.join(root, annot_root), 109 | ) 110 | 111 | 112 | _root = os.getenv("DATASET", "datasets") 113 | register_all_refcoco(_root) 114 | -------------------------------------------------------------------------------- /datasets/registration/register_scannet_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import os 9 | import glob 10 | from typing import List, Tuple, Union 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from utils.constants import SCAN_37, SCAN_40, SCAN_20 17 | 18 | __all__ = ["load_scannet_instances", "register_scannet_context"] 19 | 20 | name2folder = {"scannet_41_val_seg": "label41", 21 | "scannet_38_val_seg": "label38", 22 | "scannet_21_val_seg": "label21",} 23 | 24 | name2class = {"scannet_41_val_seg": SCAN_40, 25 | 
"scannet_38_val_seg": SCAN_37, 26 | "scannet_21_val_seg": SCAN_20} 27 | 28 | 29 | def load_scannet_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 30 | """ 31 | Load ScanNet annotations to Detectron2 format. 32 | 33 | Args: 34 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 35 | split (str): one of "train", "test", "val", "trainval" 36 | class_names: list or tuple of class names 37 | """ 38 | with PathManager.open(os.path.join(dirname, "meta", split + ".txt")) as f: 39 | fileids = np.loadtxt(f, dtype=np.str) 40 | 41 | dicts = [] 42 | for field in fileids: 43 | image_dir = os.path.join(dirname, 'images', field[0]) 44 | semseg_dir = image_dir.replace('color', name2folder[name]).replace('jpg', 'png') 45 | r = { 46 | "file_name": image_dir, 47 | "sem_seg_file_name": semseg_dir, 48 | "image_id": semseg_dir.split('/')[-3] + semseg_dir.split('/')[-1].split('.')[0], 49 | } 50 | dicts.append(r) 51 | return dicts 52 | 53 | 54 | def register_scannet_context(name, dirname, split, class_names=name2class): 55 | DatasetCatalog.register(name, lambda: load_scannet_instances(name, dirname, split, class_names)) 56 | MetadataCatalog.get(name).set( 57 | stuff_classes=class_names[name], 58 | dirname=dirname, 59 | split=split, 60 | ignore_label=[0], 61 | thing_dataset_id_to_contiguous_id={}, 62 | class_offset=1, 63 | keep_sem_bgd=False 64 | ) 65 | 66 | 67 | def register_all_sunrgbd_seg(root): 68 | SPLITS = [ 69 | ("scannet_41_val_seg", "scannet_frames_25k", "val"), 70 | ("scannet_38_val_seg", "scannet_frames_25k", "val"), 71 | ("scannet_21_val_seg", "scannet_frames_25k", "val"), 72 | ] 73 | 74 | for name, dirname, split in SPLITS: 75 | register_scannet_context(name, os.path.join(root, dirname), split) 76 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 77 | 78 | 79 | _root = os.getenv("DATASET", "datasets") 80 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/registration/register_sunrgbd_semseg.py: -------------------------------------------------------------------------------- 1 | 2 | # -------------------------------------------------------- 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 7 | # -------------------------------------------------------- 8 | import numpy as np 9 | import os 10 | import glob 11 | from typing import List, Tuple, Union 12 | 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.structures import BoxMode 15 | from detectron2.utils.file_io import PathManager 16 | 17 | from utils.constants import SUN_RGBD_37 18 | 19 | __all__ = ["load_sunrgbd_instances", "register_sunrgbd_context"] 20 | 21 | def load_sunrgbd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 22 | """ 23 | Load SUN-RGBD detection annotations to Detectron2 format. 24 | 25 | Args: 26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 27 | split (str): one of "train", "test", "val", "trainval" 28 | class_names: list or tuple of class names 29 | """ 30 | if split == 'val': 31 | split = 'test' 32 | 33 | # Needs to read many small annotation files. 
Makes sense to load them locally. 34 | image_pths = sorted(glob.glob(os.path.join(dirname, 'image', split, '*.jpg'))) 35 | semseg_pths = sorted(glob.glob(os.path.join(dirname, 'label37', split, '*.png'))) 36 | 37 | assert len(image_pths) == len(semseg_pths) 38 | 39 | dicts = [] 40 | for image_dir, semseg_dir in zip(image_pths, semseg_pths): 41 | r = { 42 | "file_name": image_dir, 43 | "sem_seg_file_name": semseg_dir, 44 | "image_id": semseg_dir.split('/')[-1].split('.')[0], 45 | } 46 | dicts.append(r) 47 | return dicts 48 | 49 | 50 | def register_sun_context(name, dirname, split, class_names=SUN_RGBD_37): 51 | DatasetCatalog.register(name, lambda: load_sunrgbd_instances(name, dirname, split, class_names)) 52 | MetadataCatalog.get(name).set( 53 | stuff_classes=class_names, 54 | dirname=dirname, 55 | split=split, 56 | ignore_label=[0], 57 | thing_dataset_id_to_contiguous_id={}, 58 | class_offset=1, 59 | keep_sem_bgd=False 60 | ) 61 | 62 | 63 | def register_all_sunrgbd_seg(root): 64 | SPLITS = [ 65 | ("sunrgbd_37_val_seg", "sun_rgbd", "val"), 66 | ] 67 | 68 | for name, dirname, split in SPLITS: 69 | register_sun_context(name, os.path.join(root, dirname), split) 70 | MetadataCatalog.get(name).evaluator_type = "sem_seg" 71 | 72 | 73 | _root = os.getenv("DATASET", "datasets") 74 | register_all_sunrgbd_seg(_root) -------------------------------------------------------------------------------- /datasets/semseg_loader.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import scipy.io 3 | import numpy as np 4 | 5 | def load_semseg(filename, loader_type): 6 | if loader_type == 'PIL': 7 | semseg = np.array(Image.open(filename), dtype=int) # np.int was removed from NumPy; the builtin int keeps the original behavior 8 | elif loader_type == 'MAT': 9 | semseg = scipy.io.loadmat(filename)['LabelMap'] 10 | return semseg -------------------------------------------------------------------------------- /datasets/utils/refcoco2json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from refer import REFER 4 | 5 | coco_root = '/pth/to/coco' 6 | ref_root = '/pth/to/refcocoseg' 7 | 8 | coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json'))) 9 | coco_train_id = [] 10 | image_annot = {} 11 | for i in range(len(coco_train_annot['images'])): 12 | coco_train_id.append(coco_train_annot['images'][i]['id']) 13 | image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i] 14 | 15 | refg = REFER(data_root=ref_root, 16 | dataset='refcocog', splitBy='umd') 17 | refg_val_ids = refg.getRefIds(split='val') 18 | 19 | full_anno = [] 20 | for ref_id in refg_val_ids: 21 | ref = refg.loadRefs(ref_id)[0] 22 | anno = refg.refToAnn[ref_id] 23 | anno.update(ref) 24 | full_anno.append(anno) 25 | 26 | imageid_list = [] 27 | final_anno = {} 28 | for anno in full_anno: 29 | imageid_list += [anno['image_id']] 30 | final_anno[anno['ann_id']] = anno 31 | 32 | annotations = [value for key, value in final_anno.items()] 33 | 34 | images = [] 35 | for image_id in list(set(imageid_list)): 36 | images += [image_annot[image_id]] 37 | 38 | outputs = {'images': images, 'annotations': annotations} 39 | print(len(images)) 40 | print(len(annotations)) 41 | json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_train.json'), 'w')) 42 | -------------------------------------------------------------------------------- /datasets/visual_sampler/__init__.py:
-------------------------------------------------------------------------------- 1 | from .sampler import ShapeSampler 2 | from .simpleclick_sampler import SimpleClickSampler 3 | 4 | 5 | def build_shape_sampler(cfg, **kwargs): 6 | sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE'] 7 | if sampler_name == 'random': 8 | return ShapeSampler(cfg, **kwargs) 9 | elif sampler_name in ['best', 'best_random']: 10 | return SimpleClickSampler(cfg, **kwargs) 11 | else: 12 | assert False, "not implemented" -------------------------------------------------------------------------------- /datasets/visual_sampler/circle.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from .mask_generators import get_mask_by_input_strokes 5 | 6 | class Circle: 7 | def __init__(self, cfg, is_train=True): 8 | self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES'] 9 | self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET'] 10 | self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB'] 11 | self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'] 12 | self.is_train = is_train 13 | 14 | @staticmethod 15 | def get_stroke_preset(stroke_preset): 16 | if stroke_preset == 'object_like': 17 | return { 18 | "nVertexBound": [5, 30], 19 | "maxHeadSpeed": 15, 20 | "maxHeadAcceleration": (10, 1.5), 21 | "brushWidthBound": (20, 50), 22 | "nMovePointRatio": 0.5, 23 | "maxPiontMove": 10, 24 | "maxLineAcceleration": (5, 0.5), 25 | "boarderGap": None, 26 | "maxInitSpeed": 10, 27 | } 28 | elif stroke_preset == 'object_like_middle': 29 | return { 30 | "nVertexBound": [5, 15], 31 | "maxHeadSpeed": 8, 32 | "maxHeadAcceleration": (4, 1.5), 33 | "brushWidthBound": (20, 50), 34 | "nMovePointRatio": 0.5, 35 | "maxPiontMove": 5, 36 | "maxLineAcceleration": (5, 0.5), 37 | "boarderGap": None, 38 | "maxInitSpeed": 10, 39 | } 40 | elif stroke_preset == 'object_like_small': 41 | return { 42 | "nVertexBound": [5, 20], 43 | "maxHeadSpeed": 7, 44 | "maxHeadAcceleration": (3.5, 1.5), 45 | "brushWidthBound": (10, 30), 46 | "nMovePointRatio": 0.5, 47 | "maxPiontMove": 5, 48 | "maxLineAcceleration": (3, 0.5), 49 | "boarderGap": None, 50 | "maxInitSpeed": 4, 51 | } 52 | else: 53 | raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.') 54 | 55 | def get_random_points_from_mask(self, mask, n=5): 56 | h,w = mask.shape 57 | view_mask = mask.reshape(h*w) 58 | non_zero_idx = view_mask.nonzero()[:,0] 59 | selected_idx = torch.randperm(len(non_zero_idx))[:n] 60 | non_zero_idx = non_zero_idx[selected_idx] 61 | y = (non_zero_idx // w)*1.0 62 | x = (non_zero_idx % w)*1.0 63 | return torch.cat((x[:,None], y[:,None]), dim=1).numpy() 64 | 65 | def draw(self, mask=None, box=None): 66 | if mask.sum() < 10: # if mask is nearly empty 67 | return torch.zeros(mask.shape).bool() 68 | if not self.is_train: 69 | return self.draw_eval(mask=mask, box=box) 70 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use 71 | preset = Circle.get_stroke_preset(stroke_preset_name) 72 | nStroke = min(random.randint(1, self.num_stroke), mask.sum().item()) 73 | h,w = mask.shape 74 | points = self.get_random_points_from_mask(mask, n=nStroke) 75 | rand_mask = get_mask_by_input_strokes( 76 | init_points=points, 77 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset) 78 | rand_mask = (~torch.from_numpy(rand_mask)) * mask 79 | return rand_mask 80 | 81 | def 
draw_eval(self, mask=None, box=None): 82 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use 83 | preset = Circle.get_stroke_preset(stroke_preset_name) 84 | nStroke = min(self.max_eval, mask.sum().item()) 85 | h,w = mask.shape 86 | points = self.get_random_points_from_mask(mask, n=nStroke) 87 | rand_masks = [] 88 | for i in range(len(points)): 89 | rand_mask = get_mask_by_input_strokes( 90 | init_points=points[:i+1], 91 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset) 92 | rand_masks += [(~torch.from_numpy(rand_mask)) * mask] 93 | return torch.stack(rand_masks) 94 | 95 | @staticmethod 96 | def draw_by_points(points, mask, h, w): 97 | stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use 98 | preset = Circle.get_stroke_preset(stroke_preset_name) 99 | rand_mask = get_mask_by_input_strokes( 100 | init_points=points, 101 | imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,] 102 | rand_masks = (~torch.from_numpy(rand_mask)) * mask 103 | return rand_masks 104 | 105 | def __repr__(self,): 106 | return 'circle' -------------------------------------------------------------------------------- /datasets/visual_sampler/point.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from scipy import ndimage 6 | 7 | 8 | class Point: 9 | def __init__(self, cfg, is_train=True): 10 | self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS'] 11 | self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'] 12 | self.is_train = is_train 13 | 14 | def draw(self, mask=None, box=None): 15 | if mask.sum() < 10: 16 | return torch.zeros(mask.shape).bool() # if mask is empty 17 | if not self.is_train: 18 | return self.draw_eval(mask=mask, box=box) 19 | max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number 20 | num_points = random.randint(1, max_points) # get a random number of points 21 | h,w = mask.shape 22 | view_mask = mask.view(-1) 23 | non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask 24 | selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id 25 | non_zero_idx = non_zero_idx[selected_idx] # select non-zero index 26 | rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask 27 | rand_mask[non_zero_idx] = True # get non zero place to zero 28 | # dilate 29 | # struct = ndimage.generate_binary_structure(2, 2) 30 | # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype))) 31 | # return rand_mask 32 | return rand_mask.reshape(h, w) 33 | 34 | def draw_eval(self, mask=None, box=None): 35 | background = ~mask 36 | neg_num = min(self.max_eval // 2, background.sum().item()) 37 | pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1 38 | 39 | h,w = mask.shape 40 | view_mask = mask.view(-1) 41 | non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask 42 | selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id 43 | non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index 44 | pos_idx = torch.ones(non_zero_idx_pos.shape) 45 | 46 | view_background = background.view(-1) 47 | non_zero_idx_neg = 
view_background.nonzero()[:,0] # get non-zero index of mask 48 | selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id 49 | non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index 50 | neg_idx = torch.ones(non_zero_idx_neg.shape) * -1 51 | 52 | non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg]) 53 | idx = torch.cat([pos_idx, neg_idx]) 54 | rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long() 55 | non_zero_idx = non_zero_idx[rand_idx] 56 | idx = idx[rand_idx] 57 | 58 | rand_masks = [] 59 | for i in range(0, len(non_zero_idx)): 60 | rand_mask = torch.zeros(view_mask.shape) # init rand mask 61 | rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero 62 | # struct = ndimage.generate_binary_structure(2, 2) 63 | # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype))) 64 | rand_masks += [rand_mask.reshape(h, w)] 65 | 66 | # kernel_size = 3 67 | rand_masks = torch.stack(rand_masks) 68 | # rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0] 69 | # rand_masks[rand_masks>0] = 1 70 | # rand_masks[rand_masks<0] = -1 71 | return rand_masks 72 | 73 | def __repr__(self,): 74 | return 'point' -------------------------------------------------------------------------------- /datasets/visual_sampler/sampler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .point import Point 8 | from .polygon import Polygon 9 | from .scribble import Scribble 10 | from .circle import Circle 11 | 12 | from modeling.utils import configurable 13 | 14 | 15 | class ShapeSampler(nn.Module): 16 | @configurable 17 | def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True): 18 | super().__init__() 19 | self.max_candidate = max_candidate 20 | self.shape_prob = shape_prob 21 | self.shape_candidate = shape_candidate 22 | self.is_train = is_train 23 | 24 | @classmethod 25 | def from_config(cls, cfg, is_train=True, mode=None): 26 | max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE'] 27 | candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS'] 28 | candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES'] 29 | 30 | if mode == 'hack_train': 31 | candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names] 32 | else: 33 | # overwrite condidate_prob 34 | if not is_train: 35 | candidate_probs = [0.0 for x in range(len(candidate_names))] 36 | candidate_probs[candidate_names.index(mode)] = 1.0 37 | candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names] 38 | 39 | # Build augmentation 40 | return { 41 | "max_candidate": max_candidate, 42 | "shape_prob": candidate_probs, 43 | "shape_candidate": candidate_classes, 44 | "is_train": is_train, 45 | } 46 | 47 | def forward(self, instances): 48 | masks = instances.gt_masks.tensor 49 | boxes = instances.gt_boxes.tensor 50 | 51 | if len(masks) == 0: 52 | gt_masks = torch.zeros(masks.shape[-2:]).bool() 53 | rand_masks = torch.zeros(masks.shape[-2:]).bool() 54 | return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']} 55 | indices = [x for x in range(len(masks))] 56 | 57 | if self.is_train: 58 | random.shuffle(indices) 59 | 
candidate_mask = masks[indices[:self.max_candidate]] 60 | candidate_box = boxes[indices[:self.max_candidate]] 61 | else: 62 | candidate_mask = masks 63 | candidate_box = boxes 64 | 65 | draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask)) 66 | rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)] 67 | types = [repr(x) for x in draw_funcs] 68 | for i in range(0, len(rand_shapes)): 69 | if rand_shapes[i].sum() == 0: 70 | candidate_mask[i] = candidate_mask[i] * 0 71 | types[i] = 'none' 72 | 73 | # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c) 74 | return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self} 75 | 76 | def build_shape_sampler(cfg, **kwargs): 77 | return ShapeSampler(cfg, **kwargs) -------------------------------------------------------------------------------- /datasets/visual_sampler/scribble.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from .mask_generators import get_mask_by_input_strokes 6 | 7 | class Scribble: 8 | def __init__(self, cfg, is_train): 9 | self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES'] 10 | self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET'] 11 | self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB'] 12 | self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'] 13 | self.is_train = is_train 14 | 15 | @staticmethod 16 | def get_stroke_preset(stroke_preset): 17 | if stroke_preset == 'rand_curve': 18 | return { 19 | "nVertexBound": [10, 30], 20 | "maxHeadSpeed": 20, 21 | "maxHeadAcceleration": (15, 0.5), 22 | "brushWidthBound": (3, 10), 23 | "nMovePointRatio": 0.5, 24 | "maxPiontMove": 3, 25 | "maxLineAcceleration": (5, 0.5), 26 | "boarderGap": None, 27 | "maxInitSpeed": 6 28 | } 29 | elif stroke_preset == 'rand_curve_small': 30 | return { 31 | "nVertexBound": [6, 22], 32 | "maxHeadSpeed": 12, 33 | "maxHeadAcceleration": (8, 0.5), 34 | "brushWidthBound": (2.5, 5), 35 | "nMovePointRatio": 0.5, 36 | "maxPiontMove": 1.5, 37 | "maxLineAcceleration": (3, 0.5), 38 | "boarderGap": None, 39 | "maxInitSpeed": 3 40 | } 41 | else: 42 | raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.') 43 | 44 | def get_random_points_from_mask(self, mask, n=5): 45 | h,w = mask.shape 46 | view_mask = mask.reshape(h*w) 47 | non_zero_idx = view_mask.nonzero()[:,0] 48 | selected_idx = torch.randperm(len(non_zero_idx))[:n] 49 | non_zero_idx = non_zero_idx[selected_idx] 50 | y = (non_zero_idx // w)*1.0 51 | x = (non_zero_idx % w)*1.0 52 | return torch.cat((x[:,None], y[:,None]), dim=1).numpy() 53 | 54 | def draw(self, mask=None, box=None): 55 | if mask.sum() < 10: 56 | return torch.zeros(mask.shape).bool() # if mask is empty 57 | if not self.is_train: 58 | return self.draw_eval(mask=mask, box=box) 59 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] 60 | preset = Scribble.get_stroke_preset(stroke_preset_name) 61 | nStroke = random.randint(1, min(self.num_stroke, mask.sum().item())) 62 | h,w = mask.shape 63 | points = self.get_random_points_from_mask(mask, n=nStroke) 64 | rand_mask = get_mask_by_input_strokes( 65 | init_points=points, 66 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset) 67 | rand_mask = (~torch.from_numpy(rand_mask)) * mask 68 | return rand_mask 69 | 70 | def 
draw_eval(self, mask=None, box=None): 71 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] 72 | preset = Scribble.get_stroke_preset(stroke_preset_name) 73 | nStroke = min(self.eval_stroke, mask.sum().item()) 74 | h,w = mask.shape 75 | points = self.get_random_points_from_mask(mask, n=nStroke) 76 | rand_masks = [] 77 | for i in range(len(points)): 78 | rand_mask = get_mask_by_input_strokes( 79 | init_points=points[:i+1], 80 | imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset) 81 | rand_mask = (~torch.from_numpy(rand_mask)) * mask 82 | rand_masks += [rand_mask] 83 | return torch.stack(rand_masks) 84 | 85 | @staticmethod 86 | def draw_by_points(points, mask, h, w): 87 | stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0] 88 | preset = Scribble.get_stroke_preset(stroke_preset_name) 89 | rand_mask = get_mask_by_input_strokes( 90 | init_points=points, 91 | imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,] 92 | rand_masks = (~torch.from_numpy(rand_mask)) * mask 93 | return rand_masks 94 | 95 | def __repr__(self,): 96 | return 'scribble' -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/__init__.py -------------------------------------------------------------------------------- /demo/seem/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/__init__.py -------------------------------------------------------------------------------- /demo/seem/examples/corgi1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/corgi1.webp -------------------------------------------------------------------------------- /demo/seem/examples/corgi2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/corgi2.jpg -------------------------------------------------------------------------------- /demo/seem/examples/fries1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/fries1.png -------------------------------------------------------------------------------- /demo/seem/examples/fries2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/fries2.png -------------------------------------------------------------------------------- /demo/seem/examples/minecraft1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/minecraft1.jpg -------------------------------------------------------------------------------- /demo/seem/examples/placeholder.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/placeholder.png -------------------------------------------------------------------------------- /demo/seem/examples/ref_vase.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/ref_vase.JPG -------------------------------------------------------------------------------- /demo/seem/examples/river1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river1.png -------------------------------------------------------------------------------- /demo/seem/examples/river1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river1.wav -------------------------------------------------------------------------------- /demo/seem/examples/river1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river1_mask.png -------------------------------------------------------------------------------- /demo/seem/examples/river2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river2.png -------------------------------------------------------------------------------- /demo/seem/examples/vasedeck.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/vasedeck.mp4 -------------------------------------------------------------------------------- /demo/seem/examples/zebras1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/zebras1.jpg -------------------------------------------------------------------------------- /demo/seem/examples/zebras2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/zebras2.jpg -------------------------------------------------------------------------------- /demo/seem/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive import interactive_infer_video, interactive_infer_image -------------------------------------------------------------------------------- /entry.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # 
-------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import torch 11 | import logging 12 | import wandb 13 | 14 | from utils.arguments import load_opt_command 15 | 16 | logging.basicConfig(level=logging.INFO) 17 | logger = logging.getLogger(__name__) 18 | 19 | def init_wandb(args, job_dir, entity='xueyanz', project='xdecoder', job_name='tmp'): 20 | wandb_dir = os.path.join(job_dir, 'wandb') 21 | os.makedirs(wandb_dir, exist_ok=True) 22 | runid = None 23 | if os.path.exists(f"{wandb_dir}/runid.txt"): 24 | runid = open(f"{wandb_dir}/runid.txt").read() 25 | 26 | wandb.init(project=project, 27 | name=job_name, 28 | dir=wandb_dir, 29 | entity=entity, 30 | resume="allow", 31 | id=runid, 32 | config={"hierarchical": True},) 33 | 34 | open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id) 35 | wandb.config.update({k: args[k] for k in args if k not in wandb.config}) 36 | 37 | def main(args=None): 38 | ''' 39 | [Main function for the entry point] 40 | 1. Set environment variables for distributed training. 41 | 2. Load the config file and set up the trainer. 42 | ''' 43 | 44 | opt, cmdline_args = load_opt_command(args) 45 | command = cmdline_args.command 46 | 47 | if cmdline_args.user_dir: 48 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 49 | opt['base_path'] = absolute_user_dir 50 | 51 | # update_opt(opt, command) 52 | world_size = 1 53 | if 'OMPI_COMM_WORLD_SIZE' in os.environ: 54 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 55 | 56 | if opt['TRAINER'] == 'xdecoder': 57 | from trainer import XDecoder_Trainer as Trainer 58 | else: 59 | assert False, "The trainer type: {} is not defined!".format(opt['TRAINER']) 60 | 61 | trainer = Trainer(opt) 62 | os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL' 63 | 64 | if command == "train": 65 | if opt['rank'] == 0 and opt['WANDB']: 66 | wandb.login(key=os.environ['WANDB_KEY']) 67 | init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder) 68 | trainer.train() 69 | elif command == "evaluate": 70 | trainer.eval() 71 | else: 72 | raise ValueError(f"Unknown command: {command}") 73 | 74 | if __name__ == "__main__": 75 | main() 76 | sys.exit(0) 77 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/__init__.py -------------------------------------------------------------------------------- /inference/images/animals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/animals.png -------------------------------------------------------------------------------- /inference/images/apples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/apples.jpg -------------------------------------------------------------------------------- /inference/images/coco/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/000.jpg -------------------------------------------------------------------------------- /inference/images/coco/001.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/001.jpg -------------------------------------------------------------------------------- /inference/images/coco/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/002.jpg -------------------------------------------------------------------------------- /inference/images/coco/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/003.jpg -------------------------------------------------------------------------------- /inference/images/fruit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/fruit.jpg -------------------------------------------------------------------------------- /inference/images/landscape.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/landscape.jpg -------------------------------------------------------------------------------- /inference/images/mountain.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/mountain.jpeg -------------------------------------------------------------------------------- /inference/images/owls.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/owls.jpeg -------------------------------------------------------------------------------- /inference/images/penguin.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/penguin.jpeg -------------------------------------------------------------------------------- /inference/images/region_retrieval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/region_retrieval.png -------------------------------------------------------------------------------- /inference/images/rose.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/rose.webp -------------------------------------------------------------------------------- /inference/images/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/street.jpg -------------------------------------------------------------------------------- /inference/images/teaser_new.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/teaser_new.png -------------------------------------------------------------------------------- /inference/xdecoder/infer_captioning.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | 12 | pth = '/'.join(sys.path[0].split('/')[:-1]) 13 | sys.path.insert(0, pth) 14 | 15 | from PIL import Image 16 | import numpy as np 17 | np.random.seed(0) 18 | import cv2 19 | 20 | import torch 21 | from torchvision import transforms 22 | 23 | from utils.arguments import load_opt_command 24 | from detectron2.data import MetadataCatalog 25 | from detectron2.structures import BitMasks 26 | from modeling.BaseModel import BaseModel 27 | from modeling import build_model 28 | from detectron2.utils.colormap import random_color 29 | from utils.visualizer import Visualizer 30 | from utils.distributed import init_distributed 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def main(args=None): 36 | ''' 37 | Main execution point for PyLearn. 38 | ''' 39 | 40 | opt, cmdline_args = load_opt_command(args) 41 | if cmdline_args.user_dir: 42 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 43 | opt['base_path'] = absolute_user_dir 44 | opt = init_distributed(opt) 45 | 46 | # META DATA 47 | pretrained_pth = os.path.join(opt['RESUME_FROM']) 48 | if 'novg' not in pretrained_pth: 49 | assert False, "Using the ckpt without visual genome training data will be much better." 
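    # (Editor's note, not in the original source) The assert above requires that
    # RESUME_FROM point to a checkpoint whose filename contains 'novg', i.e. the
    # variant trained without Visual Genome grounding data; per the message above,
    # that checkpoint produces noticeably better captions.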
50 | output_root = './output' 51 | image_pth = 'inference/images/mountain.jpeg' 52 | 53 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda() 54 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background"], is_eval=False) 55 | 56 | t = [] 57 | t.append(transforms.Resize(224, interpolation=Image.BICUBIC)) 58 | transform = transforms.Compose(t) 59 | 60 | with torch.no_grad(): 61 | image_ori = Image.open(image_pth).convert("RGB") 62 | width = image_ori.size[0] 63 | height = image_ori.size[1] 64 | image = transform(image_ori) 65 | image = np.asarray(image) 66 | image_ori = np.asarray(image_ori) 67 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() 68 | 69 | batch_inputs = [{'image': images, 'height': height, 'width': width, 'image_id': 0}] 70 | outputs = model.model.evaluate_captioning(batch_inputs) 71 | text = outputs[-1]['captioning_text'] 72 | 73 | image_ori = image_ori[:,:,::-1].copy() 74 | cv2.rectangle(image_ori, (0, 0), (width, 60), (0,0,0), -1) 75 | font = cv2.FONT_HERSHEY_DUPLEX 76 | fontScale = 1.2 77 | thickness = 2 78 | lineType = 2 79 | bottomLeftCornerOfText = (10, 40) 80 | fontColor = [255,255,255] 81 | cv2.putText(image_ori, text, 82 | bottomLeftCornerOfText, 83 | font, 84 | fontScale, 85 | fontColor, 86 | thickness, 87 | lineType) 88 | 89 | if not os.path.exists(output_root): 90 | os.makedirs(output_root) 91 | cv2.imwrite(os.path.join(output_root, 'captioning.png'), image_ori) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | sys.exit(0) -------------------------------------------------------------------------------- /inference/xdecoder/infer_instseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | 12 | pth = '/'.join(sys.path[0].split('/')[:-1]) 13 | sys.path.insert(0, pth) 14 | 15 | from PIL import Image 16 | import numpy as np 17 | np.random.seed(2) 18 | 19 | import torch 20 | from torchvision import transforms 21 | 22 | from utils.arguments import load_opt_command 23 | 24 | from detectron2.data import MetadataCatalog 25 | from detectron2.structures import BitMasks 26 | from modeling.BaseModel import BaseModel 27 | from modeling import build_model 28 | from detectron2.utils.colormap import random_color 29 | from utils.visualizer import Visualizer 30 | from utils.distributed import init_distributed 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def main(args=None): 36 | ''' 37 | Main execution point for PyLearn. 
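    (Editor's note) Concretely, this script runs open-vocabulary instance
    segmentation on inference/images/owls.jpeg with the single class "owl" and
    writes the visualization to ./output/inst.png.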
38 | ''' 39 | opt, cmdline_args = load_opt_command(args) 40 | if cmdline_args.user_dir: 41 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 42 | opt['base_path'] = absolute_user_dir 43 | 44 | opt = init_distributed(opt) 45 | 46 | # META DATA 47 | pretrained_pth = os.path.join(opt['RESUME_FROM']) 48 | output_root = './output' 49 | image_pth = 'inference/images/owls.jpeg' 50 | 51 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda() 52 | 53 | t = [] 54 | t.append(transforms.Resize(800, interpolation=Image.BICUBIC)) 55 | transform = transforms.Compose(t) 56 | 57 | thing_classes = ["owl"] 58 | thing_colors = [random_color(rgb=True, maximum=255).astype(np.int).tolist() for _ in range(len(thing_classes))] 59 | thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))} 60 | 61 | MetadataCatalog.get("demo").set( 62 | thing_colors=thing_colors, 63 | thing_classes=thing_classes, 64 | thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id, 65 | ) 66 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=False) 67 | metadata = MetadataCatalog.get('demo') 68 | model.model.metadata = metadata 69 | model.model.sem_seg_head.num_classes = len(thing_classes) 70 | 71 | with torch.no_grad(): 72 | image_ori = Image.open(image_pth).convert('RGB') 73 | width = image_ori.size[0] 74 | height = image_ori.size[1] 75 | image = transform(image_ori) 76 | image = np.asarray(image) 77 | image_ori = np.asarray(image_ori) 78 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() 79 | 80 | batch_inputs = [{'image': images, 'height': height, 'width': width}] 81 | outputs = model.forward(batch_inputs) 82 | visual = Visualizer(image_ori, metadata=metadata) 83 | 84 | inst_seg = outputs[-1]['instances'] 85 | inst_seg.pred_masks = inst_seg.pred_masks.cpu() 86 | inst_seg.pred_boxes = BitMasks(inst_seg.pred_masks > 0).get_bounding_boxes() 87 | demo = visual.draw_instance_predictions(inst_seg) # rgb Image 88 | 89 | if not os.path.exists(output_root): 90 | os.makedirs(output_root) 91 | demo.save(os.path.join(output_root, 'inst.png')) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | sys.exit(0) -------------------------------------------------------------------------------- /inference/xdecoder/infer_panoseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | 12 | pth = '/'.join(sys.path[0].split('/')[:-1]) 13 | sys.path.insert(0, pth) 14 | 15 | from PIL import Image 16 | import numpy as np 17 | np.random.seed(1) 18 | 19 | import torch 20 | from torchvision import transforms 21 | 22 | from utils.arguments import load_opt_command 23 | 24 | from detectron2.data import MetadataCatalog 25 | from detectron2.utils.colormap import random_color 26 | from modeling.BaseModel import BaseModel 27 | from modeling import build_model 28 | from utils.visualizer import Visualizer 29 | from utils.distributed import init_distributed 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def main(args=None): 35 | ''' 36 | Main execution point for PyLearn. 
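    (Editor's note) Concretely, this script runs open-vocabulary panoptic
    segmentation on inference/images/street.jpg with a small custom list of
    thing and stuff classes and writes the visualization to ./output/pano.png.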
37 | ''' 38 | opt, cmdline_args = load_opt_command(args) 39 | if cmdline_args.user_dir: 40 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 41 | opt['base_path'] = absolute_user_dir 42 | 43 | opt = init_distributed(opt) 44 | 45 | # META DATA 46 | pretrained_pth = os.path.join(opt['RESUME_FROM']) 47 | output_root = './output' 48 | image_pth = 'inference/images/street.jpg' 49 | 50 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda() 51 | 52 | t = [] 53 | t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) 54 | transform = transforms.Compose(t) 55 | 56 | thing_classes = ['car','person','traffic light', 'truck', 'motorcycle'] 57 | stuff_classes = ['building','sky','street','tree','rock','sidewalk'] 58 | thing_colors = [random_color(rgb=True, maximum=255).astype(np.int).tolist() for _ in range(len(thing_classes))] 59 | stuff_colors = [random_color(rgb=True, maximum=255).astype(np.int).tolist() for _ in range(len(stuff_classes))] 60 | thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))} 61 | stuff_dataset_id_to_contiguous_id = {x+len(thing_classes):x for x in range(len(stuff_classes))} 62 | 63 | MetadataCatalog.get("demo").set( 64 | thing_colors=thing_colors, 65 | thing_classes=thing_classes, 66 | thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id, 67 | stuff_colors=stuff_colors, 68 | stuff_classes=stuff_classes, 69 | stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id, 70 | ) 71 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + stuff_classes + ["background"], is_eval=False) 72 | metadata = MetadataCatalog.get('demo') 73 | model.model.metadata = metadata 74 | model.model.sem_seg_head.num_classes = len(thing_classes + stuff_classes) 75 | 76 | with torch.no_grad(): 77 | image_ori = Image.open(image_pth).convert("RGB") 78 | width = image_ori.size[0] 79 | height = image_ori.size[1] 80 | image = transform(image_ori) 81 | image = np.asarray(image) 82 | image_ori = np.asarray(image_ori) 83 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() 84 | 85 | batch_inputs = [{'image': images, 'height': height, 'width': width}] 86 | outputs = model.forward(batch_inputs) 87 | visual = Visualizer(image_ori, metadata=metadata) 88 | 89 | pano_seg = outputs[-1]['panoptic_seg'][0] 90 | pano_seg_info = outputs[-1]['panoptic_seg'][1] 91 | 92 | for i in range(len(pano_seg_info)): 93 | if pano_seg_info[i]['category_id'] in metadata.thing_dataset_id_to_contiguous_id.keys(): 94 | pano_seg_info[i]['category_id'] = metadata.thing_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']] 95 | else: 96 | pano_seg_info[i]['isthing'] = False 97 | pano_seg_info[i]['category_id'] = metadata.stuff_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']] 98 | 99 | demo = visual.draw_panoptic_seg(pano_seg.cpu(), pano_seg_info) # rgb Image 100 | 101 | if not os.path.exists(output_root): 102 | os.makedirs(output_root) 103 | demo.save(os.path.join(output_root, 'pano.png')) 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | sys.exit(0) -------------------------------------------------------------------------------- /inference/xdecoder/infer_refseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import json 11 | import logging 12 | 13 | pth = '/'.join(sys.path[0].split('/')[:-1]) 14 | sys.path.insert(0, pth) 15 | 16 | from PIL import Image 17 | import numpy as np 18 | np.random.seed(27) 19 | 20 | import torch 21 | from torchvision import transforms 22 | 23 | from utils.arguments import load_opt_command 24 | 25 | from detectron2.data import MetadataCatalog 26 | from detectron2.utils.colormap import random_color 27 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 28 | from modeling.BaseModel import BaseModel 29 | from modeling import build_model 30 | from utils.visualizer import Visualizer 31 | from utils.distributed import init_distributed 32 | 33 | # logging.basicConfig(level = logging.INFO) 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def main(args=None): 38 | ''' 39 | Main execution point for PyLearn. 40 | ''' 41 | opt, cmdline_args = load_opt_command(args) 42 | if cmdline_args.user_dir: 43 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 44 | opt['base_path'] = absolute_user_dir 45 | 46 | opt = init_distributed(opt) 47 | 48 | # META DATA 49 | pretrained_pth = os.path.join(opt['RESUME_FROM']) 50 | output_root = './output' 51 | image_pth = 'inference/images/fruit.jpg' 52 | 53 | text = [['The larger watermelon.'], ['The front white flower.'], ['White tea pot.'], ['Flower bunch.'], ['white vase.'], ['The left peach.'], ['The brown knife.']] 54 | 55 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda() 56 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=False) 57 | 58 | t = [] 59 | t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) 60 | transform = transforms.Compose(t) 61 | 62 | metadata = MetadataCatalog.get('ade20k_panoptic_train') 63 | model.model.metadata = metadata 64 | 65 | with torch.no_grad(): 66 | image_ori = Image.open(image_pth) 67 | width = image_ori.size[0] 68 | height = image_ori.size[1] 69 | image = transform(image_ori) 70 | image = np.asarray(image) 71 | image_ori = np.asarray(image_ori) 72 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() 73 | 74 | batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': text}}] 75 | outputs = model.model.evaluate_grounding(batch_inputs, None) 76 | visual = Visualizer(image_ori, metadata=metadata) 77 | 78 | grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy() 79 | for idx, mask in enumerate(grd_mask): 80 | demo = visual.draw_binary_mask(mask, color=random_color(rgb=True, maximum=1).astype(np.int).tolist(), text=text[idx], alpha=0.3) 81 | 82 | output_folder = os.path.join(os.path.join(output_root)) 83 | if not os.path.exists(output_folder): 84 | os.makedirs(output_folder) 85 | demo.save(os.path.join(output_folder, 'refseg.png')) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | sys.exit(0) -------------------------------------------------------------------------------- /inference/xdecoder/infer_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # 
-------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | 12 | pth = '/'.join(sys.path[0].split('/')[:-1]) 13 | sys.path.insert(0, pth) 14 | 15 | from PIL import Image 16 | import numpy as np 17 | np.random.seed(1) 18 | 19 | import torch 20 | from torchvision import transforms 21 | 22 | from utils.arguments import load_opt_command 23 | 24 | from detectron2.data import MetadataCatalog 25 | from detectron2.utils.colormap import random_color 26 | from modeling.BaseModel import BaseModel 27 | from modeling import build_model 28 | from utils.visualizer import Visualizer 29 | from utils.distributed import init_distributed 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def main(args=None): 35 | ''' 36 | Main execution point for PyLearn. 37 | ''' 38 | opt, cmdline_args = load_opt_command(args) 39 | if cmdline_args.user_dir: 40 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir) 41 | opt['base_path'] = absolute_user_dir 42 | opt = init_distributed(opt) 43 | 44 | # META DATA 45 | pretrained_pth = os.path.join(opt['RESUME_FROM']) 46 | output_root = './output' 47 | image_pth = 'inference/images/animals.png' 48 | 49 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda() 50 | 51 | t = [] 52 | t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) 53 | transform = transforms.Compose(t) 54 | 55 | stuff_classes = ['zebra','antelope','giraffe','ostrich','sky','water','grass','sand','tree'] 56 | stuff_colors = [random_color(rgb=True, maximum=255).astype(np.int).tolist() for _ in range(len(stuff_classes))] 57 | stuff_dataset_id_to_contiguous_id = {x:x for x in range(len(stuff_classes))} 58 | 59 | MetadataCatalog.get("demo").set( 60 | stuff_colors=stuff_colors, 61 | stuff_classes=stuff_classes, 62 | stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id, 63 | ) 64 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(stuff_classes + ["background"], is_eval=True) 65 | metadata = MetadataCatalog.get('demo') 66 | model.model.metadata = metadata 67 | model.model.sem_seg_head.num_classes = len(stuff_classes) 68 | 69 | with torch.no_grad(): 70 | image_ori = Image.open(image_pth).convert("RGB") 71 | width = image_ori.size[0] 72 | height = image_ori.size[1] 73 | image = transform(image_ori) 74 | image = np.asarray(image) 75 | image_ori = np.asarray(image_ori) 76 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() 77 | 78 | batch_inputs = [{'image': images, 'height': height, 'width': width}] 79 | outputs = model.forward(batch_inputs) 80 | visual = Visualizer(image_ori, metadata=metadata) 81 | 82 | sem_seg = outputs[-1]['sem_seg'].max(0)[1] 83 | demo = visual.draw_sem_seg(sem_seg.cpu(), alpha=0.5) # rgb Image 84 | 85 | if not os.path.exists(output_root): 86 | os.makedirs(output_root) 87 | demo.save(os.path.join(output_root, 'sem.png')) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | sys.exit(0) -------------------------------------------------------------------------------- /modeling/BaseModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from utils.model import align_and_update_state_dicts 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BaseModel(nn.Module): 13 | def __init__(self, opt, module: nn.Module): 14 | super(BaseModel, self).__init__() 15 | self.opt = opt 16 | self.model = module 17 | 18 | def 
forward(self, *inputs, **kwargs): 19 | outputs = self.model(*inputs, **kwargs) 20 | return outputs 21 | 22 | def save_pretrained(self, save_dir): 23 | torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt")) 24 | 25 | def from_pretrained(self, load_dir): 26 | state_dict = torch.load(load_dir, map_location=self.opt['device']) 27 | state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict) 28 | self.model.load_state_dict(state_dict, strict=False) 29 | return self -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .architectures import build_model -------------------------------------------------------------------------------- /modeling/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder_model import * 2 | from .seem_model_v0 import * 3 | from .seem_model_v1 import * 4 | from .seem_model_demo import * 5 | from .build import build_model -------------------------------------------------------------------------------- /modeling/architectures/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def build_model(config, **kwargs): 5 | model_name = config['MODEL']['NAME'] 6 | 7 | if not is_model(model_name): 8 | raise ValueError(f'Unkown model: {model_name}') 9 | 10 | return model_entrypoints(model_name)(config, **kwargs) 11 | 12 | def register_model(fn): 13 | module_name_split = fn.__module__.split('.') 14 | model_name = module_name_split[-1] 15 | _model_entrypoints[model_name] = fn 16 | return fn 17 | 18 | def model_entrypoints(model_name): 19 | return _model_entrypoints[model_name] 20 | 21 | def is_model(model_name): 22 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/architectures/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_model(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone 2 | 3 | from .focal import * 4 | from .focal_dw import * 5 | from .backbone import * -------------------------------------------------------------------------------- /modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch.nn as nn 3 | 4 | from detectron2.modeling import ShapeSpec 5 | 6 | # from ..layers import ShapeSpec 7 | 8 | __all__ = ["Backbone"] 9 | 10 | 11 | class Backbone(nn.Module): 12 | """ 13 | Abstract base class for network backbones. 14 | """ 15 | 16 | def __init__(self): 17 | """ 18 | The `__init__` method of any subclass can specify its own set of arguments. 
19 | """ 20 | super().__init__() 21 | 22 | def forward(self): 23 | """ 24 | Subclasses must override this method, but adhere to the same return type. 25 | 26 | Returns: 27 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor 28 | """ 29 | pass 30 | 31 | @property 32 | def size_divisibility(self) -> int: 33 | """ 34 | Some backbones require the input height and width to be divisible by a 35 | specific integer. This is typically true for encoder / decoder type networks 36 | with lateral connection (e.g., FPN) for which feature maps need to match 37 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific 38 | input size divisibility is required. 39 | """ 40 | return 0 41 | 42 | def output_shape(self): 43 | """ 44 | Returns: 45 | dict[str->ShapeSpec] 46 | """ 47 | # this is a backward-compatible default 48 | return { 49 | name: ShapeSpec( 50 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 51 | ) 52 | for name in self._out_features 53 | } 54 | -------------------------------------------------------------------------------- /modeling/backbone/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | from .backbone import * 5 | 6 | def build_backbone(config, **kwargs): 7 | model_name = config['MODEL']['BACKBONE']['NAME'] 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /modeling/backbone/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_backbone(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints 15 | -------------------------------------------------------------------------------- /modeling/body/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder_head import * 2 | from .build import * 3 | 4 | def build_xdecoder_head(config, *args, **kwargs): 5 | model_name = config['MODEL']['HEAD'] 6 | if not is_model(model_name): 7 | raise ValueError(f'Unkown model: {model_name}') 8 | 9 | body = model_entrypoints(model_name)(config, *args, **kwargs) 10 | return body -------------------------------------------------------------------------------- /modeling/body/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_body(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/body/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_decoder 2 | from .xdecoder import * 
-------------------------------------------------------------------------------- /modeling/body/decoder/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | 5 | def build_decoder(config, *args, **kwargs): 6 | model_name = config['MODEL']['DECODER']['NAME'] 7 | 8 | if not is_model(model_name): 9 | raise ValueError(f'Unkown model: {model_name}') 10 | 11 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /modeling/body/decoder/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_decoder(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/body/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_encoder -------------------------------------------------------------------------------- /modeling/body/encoder/build.py: -------------------------------------------------------------------------------- 1 | from .registry import model_entrypoints 2 | from .registry import is_model 3 | 4 | from .transformer_encoder_fpn import * 5 | 6 | def build_encoder(config, *args, **kwargs): 7 | model_name = config['MODEL']['ENCODER']['NAME'] 8 | 9 | if not is_model(model_name): 10 | raise ValueError(f'Unkown model: {model_name}') 11 | 12 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /modeling/body/encoder/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_encoder(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints 14 | -------------------------------------------------------------------------------- /modeling/body/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_body(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/body/xdecoder_head.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # 
-------------------------------------------------------- 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | from typing import Dict 9 | 10 | from torch import nn 11 | 12 | from detectron2.layers import ShapeSpec 13 | 14 | from .build import register_body 15 | from ..vision.encoder import build_encoder 16 | from ..interface import build_decoder 17 | from ..utils import configurable 18 | 19 | 20 | class XdecoderHead(nn.Module): 21 | 22 | @configurable 23 | def __init__( 24 | self, 25 | input_shape: Dict[str, ShapeSpec], 26 | *, 27 | num_classes: int, 28 | pixel_decoder: nn.Module, 29 | loss_weight: float = 1.0, 30 | ignore_value: int = -1, 31 | # extra parameters 32 | transformer_predictor: nn.Module, 33 | transformer_in_feature: str, 34 | ): 35 | """ 36 | NOTE: this interface is experimental. 37 | Args: 38 | input_shape: shapes (channels and stride) of the input features 39 | num_classes: number of classes to predict 40 | pixel_decoder: the pixel decoder module 41 | loss_weight: loss weight 42 | ignore_value: category id to be ignored during training. 43 | transformer_predictor: the transformer decoder that makes prediction 44 | transformer_in_feature: input feature name to the transformer_predictor 45 | """ 46 | super().__init__() 47 | 48 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 49 | self.in_features = [k for k, v in input_shape] 50 | feature_strides = [v.stride for k, v in input_shape] 51 | feature_channels = [v.channels for k, v in input_shape] 52 | 53 | self.ignore_value = ignore_value 54 | self.common_stride = 4 55 | self.loss_weight = loss_weight 56 | 57 | self.pixel_decoder = pixel_decoder 58 | self.predictor = transformer_predictor 59 | self.transformer_in_feature = transformer_in_feature 60 | 61 | self.num_classes = num_classes 62 | 63 | @classmethod 64 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict): 65 | 66 | in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE'] 67 | enc_cfg = cfg['MODEL']['ENCODER'] 68 | dec_cfg = cfg['MODEL']['DECODER'] 69 | 70 | # figure out in_channels to transformer predictor 71 | if in_features_type == "transformer_encoder": 72 | transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] 73 | elif in_features_type == "pixel_embedding": 74 | transformer_predictor_in_channels = enc_cfg['MASK_DIM'] 75 | elif in_features_type == "multi_scale_pixel_decoder": 76 | transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] 77 | else: 78 | transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels 79 | 80 | return { 81 | "input_shape": { 82 | k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] 83 | }, 84 | "ignore_value": enc_cfg['IGNORE_VALUE'], 85 | "num_classes": enc_cfg.get('NUM_CLASSES', None), 86 | "pixel_decoder": build_encoder(cfg, input_shape), 87 | "loss_weight": enc_cfg['LOSS_WEIGHT'], 88 | "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'], 89 | "transformer_predictor": build_decoder( 90 | cfg, 91 | transformer_predictor_in_channels, 92 | lang_encoder, 93 | mask_classification=True, 94 | extra=extra, 95 | ), 96 | } 97 | 98 | def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): 99 | return self.layers(features, mask, target_queries, target_vlp, task, extra) 100 | 101 | def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): 102 | mask_features, transformer_encoder_features, multi_scale_features = 
self.pixel_decoder.forward_features(features) 103 | 104 | if self.transformer_in_feature == "multi_scale_pixel_decoder": 105 | predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra) 106 | else: 107 | if self.transformer_in_feature == "transformer_encoder": 108 | assert ( 109 | transformer_encoder_features is not None 110 | ), "Please use the TransformerEncoderPixelDecoder." 111 | predictions = self.predictor(transformer_encoder_features, mask_features, mask) 112 | elif self.transformer_in_feature == "pixel_embedding": 113 | predictions = self.predictor(mask_features, mask_features, mask) 114 | else: 115 | predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask) 116 | return predictions 117 | 118 | 119 | @register_body 120 | def get_xdecoder_head(cfg, input_shape, lang_encoder, extra): 121 | return XdecoderHead(cfg, input_shape, lang_encoder, extra) -------------------------------------------------------------------------------- /modeling/interface/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder import * 2 | from .seem_v0 import * 3 | from .seem_v1 import * 4 | from .seem_demo import * 5 | from .build import * 6 | 7 | def build_decoder(config, *args, **kwargs): 8 | model_name = config['MODEL']['DECODER']['NAME'] 9 | 10 | if not is_model(model_name): 11 | raise ValueError(f'Unkown model: {model_name}') 12 | 13 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /modeling/interface/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_decoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/interface/prototype/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/modeling/interface/prototype/__init__.py -------------------------------------------------------------------------------- /modeling/language/LangEncoder/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTokenizer, CLIPTokenizerFast 2 | from transformers import AutoTokenizer 3 | 4 | from .transformer import * 5 | from .build import * 6 | 7 | 8 | def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): 9 | model_name = config_encoder['NAME'] 10 | 11 | if not is_lang_encoder(model_name): 12 | raise ValueError(f'Unkown model: {model_name}') 13 | 14 | return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) 15 | 16 | def build_tokenizer(config_encoder): 17 | tokenizer = None 18 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 19 | if config_encoder['TOKENIZER'] == 'clip': 20 | pretrained_tokenizer = config_encoder.get( 21 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 22 | ) 23 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) 24 | tokenizer.add_special_tokens({'cls_token': 
tokenizer.eos_token}) 25 | elif config_encoder['TOKENIZER'] == 'clip-fast': 26 | pretrained_tokenizer = config_encoder.get( 27 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' 28 | ) 29 | tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True) 30 | else: 31 | tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) 32 | 33 | return tokenizer -------------------------------------------------------------------------------- /modeling/language/LangEncoder/build.py: -------------------------------------------------------------------------------- 1 | _lang_encoders = {} 2 | 3 | 4 | def register_lang_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | 8 | _lang_encoders[model_name] = fn 9 | 10 | return fn 11 | 12 | def lang_encoders(model_name): 13 | return _lang_encoders[model_name] 14 | 15 | def is_lang_encoder(model_name): 16 | return model_name in _lang_encoders 17 | -------------------------------------------------------------------------------- /modeling/language/LangEncoder/registry.py: -------------------------------------------------------------------------------- 1 | _lang_encoders = {} 2 | 3 | 4 | def register_lang_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | 8 | _lang_encoders[model_name] = fn 9 | 10 | return fn 11 | 12 | 13 | def lang_encoders(model_name): 14 | return _lang_encoders[model_name] 15 | 16 | 17 | def is_lang_encoder(model_name): 18 | return model_name in _lang_encoders 19 | -------------------------------------------------------------------------------- /modeling/language/__init__.py: -------------------------------------------------------------------------------- 1 | from .vlpencoder import * 2 | from .build import * 3 | 4 | def build_language_encoder(config, **kwargs): 5 | model_name = config['MODEL']['TEXT']['ARCH'] 6 | 7 | if not is_model(model_name): 8 | raise ValueError(f'Unkown model: {model_name}') 9 | 10 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /modeling/language/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_model(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/language/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import nltk 5 | nltk.data.path.append('/mnt/data/nltk_data') 6 | import numpy as np 7 | 8 | from utils.constants import IMAGENET_DEFAULT_TEMPLATES 9 | 10 | 11 | def get_tag(tokenized, tags): 12 | if not isinstance(tags, (list, tuple)): 13 | tags = [tags] 14 | ret = [] 15 | for (word, pos) in nltk.pos_tag(tokenized): 16 | for tag in tags: 17 | if pos == tag: 18 | ret.append(word) 19 | return ret 20 | 21 | def get_noun_phrase(tokenized): 22 | # Taken from Su Nam Kim Paper... 23 | grammar = r""" 24 | NBAR: 25 | {*} # Nouns and Adjectives, terminated with Nouns 26 | 27 | NP: 28 | {} 29 | {} # Above, connected with in/of/etc... 
30 | """ 31 | chunker = nltk.RegexpParser(grammar) 32 | 33 | chunked = chunker.parse(nltk.pos_tag(tokenized)) 34 | continuous_chunk = [] 35 | current_chunk = [] 36 | 37 | for subtree in chunked: 38 | if isinstance(subtree, nltk.Tree): 39 | current_chunk.append(' '.join([token for token, pos in subtree.leaves()])) 40 | elif current_chunk: 41 | named_entity = ' '.join(current_chunk) 42 | if named_entity not in continuous_chunk: 43 | continuous_chunk.append(named_entity) 44 | current_chunk = [] 45 | else: 46 | continue 47 | 48 | return continuous_chunk 49 | 50 | def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True): 51 | tokenized = nltk.word_tokenize(text) 52 | 53 | if random.random() >= phrase_prob: 54 | nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP']) 55 | else: 56 | nouns = get_noun_phrase(tokenized) 57 | 58 | 59 | prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns] 60 | 61 | if append_text: 62 | prompt_texts += [text] 63 | nouns += [text] 64 | 65 | return prompt_texts, nouns -------------------------------------------------------------------------------- /modeling/language/registry.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | def register_model(fn): 4 | module_name_split = fn.__module__.split('.') 5 | model_name = module_name_split[-1] 6 | _model_entrypoints[model_name] = fn 7 | return fn 8 | 9 | def model_entrypoints(model_name): 10 | return _model_entrypoints[model_name] 11 | 12 | def is_model(model_name): 13 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .point_features import * 2 | from .position_encoding import * 3 | from .postprocessing import * 4 | from .attention import * 5 | from .criterion import * 6 | from .matcher import * -------------------------------------------------------------------------------- /modeling/modules/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 
16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=x.dtype) 34 | x_embed = not_mask.cumsum(2, dtype=x.dtype) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | def __repr__(self, _repr_indent=4): 55 | head = "Positional encoding " + self.__class__.__name__ 56 | body = [ 57 | "num_pos_feats: {}".format(self.num_pos_feats), 58 | "temperature: {}".format(self.temperature), 59 | "normalize: {}".format(self.normalize), 60 | "scale: {}".format(self.scale), 61 | ] 62 | # _repr_indent = 4 63 | lines = [head] + [" " * _repr_indent + line for line in body] 64 | return "\n".join(lines) 65 | -------------------------------------------------------------------------------- /modeling/modules/postprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from detectron2.structures import Instances, ROIMasks 6 | 7 | 8 | # perhaps should rename to "resize_instance" 9 | def detector_postprocess( 10 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 11 | ): 12 | """ 13 | Resize the output instances. 14 | The input images are often resized when entering an object detector. 15 | As a result, we often need the outputs of the detector in a different 16 | resolution from its inputs. 17 | 18 | This function will resize the raw outputs of an R-CNN detector 19 | to produce outputs according to the desired output resolution. 20 | 21 | Args: 22 | results (Instances): the raw outputs from the detector. 23 | `results.image_size` contains the input image resolution the detector sees. 24 | This object might be modified in-place. 25 | output_height, output_width: the desired output resolution. 26 | 27 | Returns: 28 | Instances: the resized output from the model, based on the output resolution 29 | """ 30 | if isinstance(output_width, torch.Tensor): 31 | # This shape might (but not necessarily) be tensors during tracing. 32 | # Converts integer tensors to float temporaries to ensure true 33 | # division is performed when computing scale_x and scale_y. 
34 | output_width_tmp = output_width.float() 35 | output_height_tmp = output_height.float() 36 | new_size = torch.stack([output_height, output_width]) 37 | else: 38 | new_size = (output_height, output_width) 39 | output_width_tmp = output_width 40 | output_height_tmp = output_height 41 | 42 | scale_x, scale_y = ( 43 | output_width_tmp / results.image_size[1], 44 | output_height_tmp / results.image_size[0], 45 | ) 46 | results = Instances(new_size, **results.get_fields()) 47 | 48 | if results.has("pred_boxes"): 49 | output_boxes = results.pred_boxes 50 | elif results.has("proposal_boxes"): 51 | output_boxes = results.proposal_boxes 52 | else: 53 | output_boxes = None 54 | assert output_boxes is not None, "Predictions must contain boxes!" 55 | 56 | output_boxes.scale(scale_x, scale_y) 57 | output_boxes.clip(results.image_size) 58 | 59 | results = results[output_boxes.nonempty()] 60 | 61 | if results.has("pred_masks"): 62 | if isinstance(results.pred_masks, ROIMasks): 63 | roi_masks = results.pred_masks 64 | else: 65 | # pred_masks is a tensor of shape (N, 1, M, M) 66 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) 67 | results.pred_masks = roi_masks.to_bitmasks( 68 | results.pred_boxes, output_height, output_width, mask_threshold 69 | ).tensor # TODO return ROIMasks/BitMask object in the future 70 | 71 | if results.has("pred_keypoints"): 72 | results.pred_keypoints[:, :, 0] *= scale_x 73 | results.pred_keypoints[:, :, 1] *= scale_y 74 | 75 | return results 76 | 77 | def bbox_postprocess(result, input_size, img_size, output_height, output_width): 78 | """ 79 | result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h] 80 | """ 81 | if result is None: 82 | return None 83 | 84 | scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device) 85 | result = result.sigmoid() * scale 86 | x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2 87 | h,w = img_size 88 | 89 | x1 = x1.clamp(min=0, max=w) 90 | y1 = y1.clamp(min=0, max=h) 91 | x2 = x2.clamp(min=0, max=w) 92 | y2 = y2.clamp(min=0, max=h) 93 | 94 | box = torch.stack([x1,y1,x2,y2]).permute(1,0) 95 | scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device) 96 | box = box*scale 97 | return box 98 | 99 | def sem_seg_postprocess(result, img_size, output_height, output_width): 100 | """ 101 | Return semantic segmentation predictions in the original resolution. 102 | 103 | The input images are often resized when entering semantic segmentor. Moreover, in same 104 | cases, they also padded inside segmentor to be divisible by maximum network stride. 105 | As a result, we often need the predictions of the segmentor in a different 106 | resolution from its inputs. 107 | 108 | Args: 109 | result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W), 110 | where C is the number of classes, and H, W are the height and width of the prediction. 111 | img_size (tuple): image size that segmentor is taking as input. 112 | output_height, output_width: the desired output resolution. 113 | 114 | Returns: 115 | semantic segmentation prediction (Tensor): A tensor of the shape 116 | (C, output_height, output_width) that contains per-pixel soft predictions. 
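        Example (editor's sketch): for logits of shape (C, 512, 704) predicted from
        a 768x1024 image that was resized to 512x683 and padded to 512x704, calling
        this function with img_size=(512, 683), output_height=768 and
        output_width=1024 first crops away the padding and then resizes the logits
        back to (C, 768, 1024).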
117 | """ 118 | result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) 119 | result = F.interpolate( 120 | result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True 121 | )[0] 122 | return result 123 | -------------------------------------------------------------------------------- /modeling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .misc import * 3 | from .interactive import * 4 | from .attention import * -------------------------------------------------------------------------------- /modeling/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | def box_xywh_to_xyxy(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [x0, y0, (x0 + x1), (y0 + y1)] 25 | return torch.stack(b, dim=-1) 26 | 27 | 28 | # modified from torchvision to also return the union 29 | def box_iou(boxes1, boxes2): 30 | area1 = box_area(boxes1) 31 | area2 = box_area(boxes2) 32 | 33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clamp(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | iou = inter / (union+1e-6) 42 | return iou, union 43 | 44 | 45 | def generalized_box_iou(boxes1, boxes2): 46 | """ 47 | Generalized IoU from https://giou.stanford.edu/ 48 | 49 | The boxes should be in [x0, y0, x1, y1] format 50 | 51 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 52 | and M = len(boxes2) 53 | """ 54 | # degenerate boxes gives inf / nan results 55 | # so do an early check 56 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 57 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 58 | iou, union = box_iou(boxes1, boxes2) 59 | 60 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 61 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 62 | 63 | wh = (rb - lt).clamp(min=0) # [N,M,2] 64 | area = wh[:, :, 0] * wh[:, :, 1] 65 | 66 | return iou - (area - union) / (area+1e-6) 67 | 68 | 69 | def masks_to_boxes(masks): 70 | """Compute the bounding boxes around the provided masks 71 | 72 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
73 | 74 | Returns a [N, 4] tensors, with the boxes in xyxy format 75 | """ 76 | if masks.numel() == 0: 77 | return torch.zeros((0, 4), device=masks.device) 78 | 79 | h, w = masks.shape[-2:] 80 | 81 | y = torch.arange(0, h, dtype=torch.float) 82 | x = torch.arange(0, w, dtype=torch.float) 83 | y, x = torch.meshgrid(y, x) 84 | 85 | x_mask = (masks * x.unsqueeze(0)) 86 | x_max = x_mask.flatten(1).max(-1)[0] 87 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 88 | 89 | y_mask = (masks * y.unsqueeze(0)) 90 | y_max = y_mask.flatten(1).max(-1)[0] 91 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 92 | 93 | return torch.stack([x_min, y_min, x_max, y_max], 1) -------------------------------------------------------------------------------- /modeling/utils/interactive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import math 4 | 5 | import torch 6 | from torch import nn, Tensor 7 | import torch.nn.functional as F 8 | 9 | 10 | def rand_sample(x, divisor, max_len): 11 | # non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']] 12 | if len(x.nonzero()) == 0: 13 | return x.nonzero().t() 14 | 15 | non_zero_point_index = (x.nonzero()/divisor).t() 16 | mask_ids = non_zero_point_index[0].unique().long() 17 | 18 | # compute probability for each samle 19 | probs = torch.zeros_like(non_zero_point_index[0]) 20 | for idx in mask_ids: 21 | prob = 1./(len(mask_ids)*((non_zero_point_index[0:1]==idx).sum())) 22 | probs[non_zero_point_index[0]==idx] = prob 23 | 24 | indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0] 25 | non_zero_point_index = non_zero_point_index[:,indices] 26 | return non_zero_point_index # [n, 512] 27 | 28 | def rand_sample_plain(x, max_len): 29 | if x.shape[1] <= max_len: 30 | return x 31 | else: 32 | rand_idx = torch.randperm(x.shape[1])[:max_len] 33 | return x[:,rand_idx] 34 | 35 | def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed): 36 | src = [] 37 | pos = [] 38 | size_list = [] 39 | 40 | # disable mask, it does not affect performance 41 | for i in range(num_feature_levels): 42 | size_list.append(x[i].shape[-2:]) 43 | pos.append(pe_layer(x[i], None).flatten(2)) 44 | src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None]) 45 | 46 | # flatten NxCxHxW to HWxNxC 47 | pos[-1] = pos[-1].permute(2, 0, 1) 48 | src[-1] = src[-1].permute(2, 0, 1) 49 | return src, pos, size_list -------------------------------------------------------------------------------- /modeling/vision/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .focal import * 2 | from .focal_dw import * 3 | from .davit import * 4 | from .vit import * 5 | from .backbone import * 6 | from .build import * 7 | 8 | 9 | def build_backbone(config, **kwargs): 10 | model_name = config['MODEL']['BACKBONE']['NAME'] 11 | if not is_model(model_name): 12 | raise ValueError(f'Unkown model: {model_name}') 13 | 14 | return model_entrypoints(model_name)(config, **kwargs) -------------------------------------------------------------------------------- /modeling/vision/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import torch.nn as nn 3 | 4 | from detectron2.modeling import ShapeSpec 5 | 6 | # from ..layers import ShapeSpec 7 | 8 | __all__ = ["Backbone"] 9 | 10 | 11 | class Backbone(nn.Module): 12 | """ 13 | Abstract base class for network backbones. 14 | """ 15 | 16 | def __init__(self): 17 | """ 18 | The `__init__` method of any subclass can specify its own set of arguments. 19 | """ 20 | super().__init__() 21 | 22 | def forward(self): 23 | """ 24 | Subclasses must override this method, but adhere to the same return type. 25 | 26 | Returns: 27 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor 28 | """ 29 | pass 30 | 31 | @property 32 | def size_divisibility(self) -> int: 33 | """ 34 | Some backbones require the input height and width to be divisible by a 35 | specific integer. This is typically true for encoder / decoder type networks 36 | with lateral connection (e.g., FPN) for which feature maps need to match 37 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific 38 | input size divisibility is required. 39 | """ 40 | return 0 41 | 42 | def output_shape(self): 43 | """ 44 | Returns: 45 | dict[str->ShapeSpec] 46 | """ 47 | # this is a backward-compatible default 48 | return { 49 | name: ShapeSpec( 50 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 51 | ) 52 | for name in self._out_features 53 | } 54 | -------------------------------------------------------------------------------- /modeling/vision/backbone/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_backbone(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/vision/backbone/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /modeling/vision/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_encoder_fpn import * 2 | try: 3 | from .transformer_encoder_deform import * 4 | except: 5 | print('Deformable Transformer Encoder is not available.') 6 | from .build import * 7 | 8 | 9 | def build_encoder(config, *args, **kwargs): 10 | model_name = config['MODEL']['ENCODER']['NAME'] 11 | 12 | if not is_model(model_name): 13 | raise ValueError(f'Unkown model: {model_name}') 14 | 15 | return model_entrypoints(model_name)(config, *args, **kwargs) -------------------------------------------------------------------------------- /modeling/vision/encoder/build.py: -------------------------------------------------------------------------------- 1 | _model_entrypoints = {} 2 | 3 | 4 | def register_encoder(fn): 5 | module_name_split = fn.__module__.split('.') 6 | model_name = module_name_split[-1] 7 | _model_entrypoints[model_name] = fn 8 | return fn 9 | 10 | def model_entrypoints(model_name): 11 | return _model_entrypoints[model_name] 12 | 13 | def is_model(model_name): 14 | return model_name in _model_entrypoints -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
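# Usage sketch (illustrative): LayerNorm2d above normalizes an NCHW tensor over
# the channel dimension. The same result can be obtained by permuting to NHWC
# and applying torch.nn.LayerNorm, which this quick numerical check confirms
# (eps matched to 1e-6; a freshly constructed LayerNorm has weight=1, bias=0,
# so the affine terms are omitted from the reference computation).
import torch
import torch.nn as nn

x = torch.randn(2, 8, 4, 4)                      # N, C, H, W

u = x.mean(1, keepdim=True)                      # channels-first mean/var,
s = (x - u).pow(2).mean(1, keepdim=True)         # exactly as in LayerNorm2d.forward
ref = (x - u) / torch.sqrt(s + 1e-6)

ln = nn.LayerNorm(8, eps=1e-6)
out = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

assert torch.allclose(out, ref, atol=1e-5)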
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd modeling/vision/encoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) from e 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ =
sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install --user 14 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
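# Usage sketch (illustrative): the tensor shapes expected by MSDeformAttnFunction
# and ms_deform_attn_core_pytorch above, following the same convention exercised
# in ops/test.py further below. Sampling locations live in [0, 1] and attention
# weights are normalized over the levels x points axes; either implementation
# returns a tensor of shape [N, Len_q, n_heads * d_head]. Note that importing
# ms_deform_attn_func only succeeds after the CUDA extension has been built
# (see make.sh above), because MultiScaleDeformableAttention is imported at
# module level.
import torch

N, n_heads, d_head = 1, 2, 4             # batch size, attention heads, per-head dim
n_levels, n_points, Len_q = 2, 3, 5      # feature levels, sampled points, queries

spatial_shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
Len_in = int(spatial_shapes.prod(1).sum())       # 6*4 + 3*2 = 30 flattened keys

value = torch.rand(N, Len_in, n_heads, d_head)
sampling_locations = torch.rand(N, Len_q, n_heads, n_levels, n_points, 2)
attention_weights = torch.rand(N, Len_q, n_heads, n_levels, n_points)
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)

print(value.shape, sampling_locations.shape, attention_weights.shape)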
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /modeling/vision/encoder/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, 
M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 | print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/pipeline/__init__.py -------------------------------------------------------------------------------- /pipeline/utils/misc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | def hook_opt(opt): 7 | 8 | try: 9 | grounding_flag = opt['REF']['INPUT']['SPATIAL'] 10 | except: 11 | grounding_flag = False 12 | 13 | if grounding_flag: 14 | opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['grounding'] = ['queries_grounding', 'tokens_grounding', 'tokens_spatial'] 15 | 16 | try: 17 | spatial_flag = opt['STROKE_SAMPLER']['EVAL']['GROUNDING'] 18 | except: 19 | spatial_flag = False 20 | 21 | if spatial_flag: 22 | opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['spatial'] = ['queries_spatial', 'tokens_spatial', 'memories_spatial', 'tokens_grounding'] 23 | 24 | return opt 25 | 26 | # HACK for evalution 27 | def hook_metadata(metadata, name): 28 | return metadata 29 | 30 | # HACK for evalution 31 | def hook_switcher(model, name): 32 | mappings = {} 33 | if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'context_59_val_seg', 'context_459_val_seg', 'voc_2012_val_seg', 
'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']: 34 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False} 35 | elif name in ['cityscapes_fine_instance_seg_val'] or 'seginw' in name: 36 | mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False} 37 | elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']: 38 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True} 39 | elif name in ['coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev']: 40 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True} 41 | else: 42 | if name not in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd", "pascalvoc_val_Point", "grounding_coco_entity_val", "vlp_coco_entity_val"]: 43 | assert False, "dataset switcher is not defined" 44 | 45 | for key, value in mappings.items(): 46 | if key == 'SEMANTIC_ON': 47 | model.model.semantic_on = value 48 | if key == 'INSTANCE_ON': 49 | model.model.instance_on = value 50 | if key == 'PANOPTIC_ON': 51 | model.model.panoptic_on = value -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .xdecoder_trainer import * -------------------------------------------------------------------------------- /trainer/distributed_trainer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import logging 10 | from mpi4py import MPI 11 | 12 | import torch 13 | 14 | from .utils.hook import add_hook 15 | from .utils.mpi_adapter import MPIAdapter 16 | from .utils.misc import save_opt_to_yaml 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class DistributedTrainer: 22 | def __init__(self, opt): 23 | self.opt = opt 24 | 25 | # parse environment information for distributed training 26 | adapter = MPIAdapter(self.opt['PORT']) 27 | self.opt['world_size'] = adapter.world_size 28 | self.opt['local_size'] = adapter.local_size 29 | self.opt['rank'] = adapter.rank 30 | self.opt['local_rank'] = adapter.local_rank 31 | 32 | self.set_opt_hook() 33 | 34 | # set up device 35 | if not self.opt['CUDA']: 36 | self.opt['device'] = torch.device("cpu") 37 | logger.info("Using CPU") 38 | else: 39 | torch.cuda.set_device(self.opt['local_rank']) 40 | self.opt['device'] = torch.device("cuda", self.opt['local_rank']) 41 | logger.info("Using CUDA") 42 | 43 | # init distributed training 44 | adapter.log_info() 45 | if torch.distributed.is_available() and self.opt['world_size'] > 1: 46 | adapter.init_process_group(backend='nccl') 47 | 48 | # save config file 49 | self.save_folder = self.opt['SAVE_DIR'] 50 | 51 | if self.opt['world_size'] > 1: 52 | torch.distributed.barrier() 53 | 54 | if self.opt['rank'] == 0: 55 | os.makedirs(self.save_folder, exist_ok=True) 56 | 57 | logger.info(f"Save config file to {os.path.join(self.save_folder, 'conf_copy.yaml')}") 58 | 
save_opt_to_yaml(self.opt, os.path.join(self.save_folder, 'conf_copy.yaml')) 59 | 60 | # ddp: log stats and update learning rate 61 | self.grad_acc_steps = self.opt['GRADIENT_ACCUMULATE_STEP'] 62 | logger.info(f"Base learning rate: {self.opt['SOLVER']['BASE_LR']}") 63 | logger.info(f"Number of GPUs: {self.opt['world_size']}") 64 | logger.info(f"Gradient accumulation steps: {self.grad_acc_steps}") 65 | 66 | if self.opt['world_size'] > 1: 67 | add_hook() 68 | 69 | # prepare metadata for save folder 70 | conf_file = self.opt['conf_files'][0] 71 | if 'BASENAME' not in self.opt: 72 | self.opt['BASENAME'] = os.path.basename(conf_file) 73 | 74 | self.init_save_folder() 75 | 76 | def set_opt_hook(self): 77 | # Fill in the default values for required keywords 78 | self.opt['CUDA'] = self.opt.get('CUDA', True) and torch.cuda.is_available() 79 | self.opt['FP16'] = self.opt.get('FP16', False) and self.opt['CUDA'] 80 | self.opt['GRADIENT_ACCUMULATE_STEP'] = int(self.opt.get('GRADIENT_ACCUMULATE_STEP', 1)) 81 | self.opt['EVAL_PER_UPDATE_NUM'] = int(self.opt.get('EVAL_PER_UPDATE_NUM', 0)) 82 | self.opt['LR_SCHEDULER_PARAMS'] = self.opt.get('LR_SCHEDULER_PARAMS', {}) 83 | 84 | if 'SAVE_DIR' not in self.opt: 85 | assert False, "Please initialize SAVE_DIR in your config file." 86 | self.opt['SAVE_DIR'] = os.path.normpath(self.opt['SAVE_DIR']) 87 | logger.info(f"Setting SAVE_DIR as {self.opt['SAVE_DIR']}") 88 | 89 | def init_save_folder(self): 90 | """ 91 | Initialize the save folder for logs, model, checkpoint, and evaluation. 92 | """ 93 | runid = 1 94 | 95 | if self.opt['world_size'] > 1: 96 | torch.distributed.barrier() 97 | 98 | if self.opt['rank'] == 0: 99 | while True: 100 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}") 101 | try: 102 | os.makedirs(save_folder, exist_ok=False) 103 | break 104 | except FileExistsError: 105 | runid = runid + 1 106 | 107 | if self.opt['world_size'] > 1: 108 | torch.distributed.barrier() 109 | 110 | if self.opt['world_size'] > 1: 111 | runid = 1 112 | while True: 113 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}") 114 | if not os.path.exists(save_folder): 115 | break 116 | else: 117 | runid += 1 118 | 119 | runid -= 1 120 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}") 121 | # this second os.makedirs() call on all ranks is to force sync the save_folder creation between blobFuse and local fs 122 | os.makedirs(save_folder, exist_ok=True) 123 | 124 | self.save_folder = save_folder -------------------------------------------------------------------------------- /trainer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/trainer/utils/__init__.py -------------------------------------------------------------------------------- /trainer/utils/hook.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | _orig_except_hook = None 7 | 8 | 9 | def _global_except_hook(exctype, value, traceback): 10 | """Catches an unhandled exception and call MPI_Abort().""" 11 | try: 12 | if _orig_except_hook: 13 | _orig_except_hook(exctype, value, traceback) 14 | else: 15 | sys.__excepthook__(exctype, value, traceback) 16 | 17 | finally: 18 | import mpi4py.MPI 19 | rank = 
mpi4py.MPI.COMM_WORLD.Get_rank() 20 | logger.warning("******************************************") 21 | logger.warning("DefaultTrainer:") 22 | logger.warning(f" Uncaught exception on rank {rank}.") 23 | logger.warning(" Calling MPI_Abort() to shut down MPI...") 24 | logger.warning("******************************************") 25 | logging.shutdown() 26 | 27 | try: 28 | import mpi4py.MPI 29 | mpi4py.MPI.COMM_WORLD.Abort(1) 30 | except Exception as e: 31 | # Something is completely broken... 32 | # There's nothing we can do any more 33 | sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n") 34 | sys.stderr.flush() 35 | raise e 36 | 37 | 38 | def add_hook(): 39 | """ 40 | Add a global hook function that captures all unhandled exceptions. 41 | The function calls MPI_Abort() to force all processes abort. 42 | 43 | An MPI runtime is expected to kill all of its child processes 44 | if one of them exits abnormally or without calling `MPI_Finalize()`. 45 | However, when a Python program run on `mpi4py`, the MPI runtime 46 | often fails to detect a process failure, and the rest of the processes 47 | hang infinitely. 48 | 49 | See https://github.com/chainer/chainermn/issues/236 and 50 | https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more 51 | information. 52 | """ 53 | global _orig_except_hook 54 | 55 | if _orig_except_hook is not None: 56 | logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.") 57 | return 58 | 59 | logger.info("Adding global except hook for the distributed job to shutdown MPI if unhandled exception is raised on some of the ranks.") 60 | _orig_except_hook = sys.excepthook 61 | sys.excepthook = _global_except_hook 62 | -------------------------------------------------------------------------------- /trainer/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from typing import Dict 4 | 5 | 6 | class JSONEncoder(json.JSONEncoder): 7 | def default(self, obj): 8 | if isinstance(obj, np.integer): 9 | return int(obj) 10 | elif isinstance(obj, np.floating): 11 | return float(obj) 12 | elif isinstance(obj, np.ndarray): 13 | return obj.tolist() 14 | else: 15 | return super(JSONEncoder, self).default(obj) 16 | 17 | 18 | def is_jsonable(x, json_encoder=None): 19 | try: 20 | json.dumps(x, cls=json_encoder) 21 | return True 22 | except Exception: 23 | return False 24 | 25 | 26 | def filter_jsonable(data: Dict, json_encoder=None) -> Dict: 27 | return {k: v for k, v in data.items() if is_jsonable(k, json_encoder=json_encoder) and is_jsonable(v, json_encoder=json_encoder)} -------------------------------------------------------------------------------- /utils/Config.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode as _CfgNode 2 | 3 | 4 | class CfgNode(_CfgNode): 5 | """ 6 | The same as `fvcore.common.config.CfgNode`, but different in: 7 | 8 | 1. Use unsafe yaml loading by default. 9 | Note that this may lead to arbitrary code execution: you must not 10 | load a config file from untrusted sources before manually inspecting 11 | the content of the file. 12 | 2. Support config versioning. 13 | When attempting to merge an old config, it will convert the old config automatically. 14 | 15 | .. automethod:: clone 16 | .. automethod:: freeze 17 | .. automethod:: defrost 18 | .. automethod:: is_frozen 19 | .. automethod:: load_yaml_with_base 20 | .. 
automethod:: merge_from_list 21 | .. automethod:: merge_from_other_cfg 22 | """ 23 | 24 | def merge_from_dict(self, dict): 25 | pass 26 | 27 | node = CfgNode() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_engineering import * 2 | from .dataset import * -------------------------------------------------------------------------------- /utils/arguments.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import argparse 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def load_config_dict_to_opt(opt, config_dict): 10 | """ 11 | Load the key, value pairs from config_dict to opt, overriding existing values in opt 12 | if there is any. 13 | """ 14 | if not isinstance(config_dict, dict): 15 | raise TypeError("Config must be a Python dictionary") 16 | for k, v in config_dict.items(): 17 | k_parts = k.split('.') 18 | pointer = opt 19 | for k_part in k_parts[:-1]: 20 | if k_part not in pointer: 21 | pointer[k_part] = {} 22 | pointer = pointer[k_part] 23 | assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict." 24 | ori_value = pointer.get(k_parts[-1]) 25 | pointer[k_parts[-1]] = v 26 | if ori_value: 27 | logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}") 28 | 29 | 30 | def load_opt_from_config_files(conf_files): 31 | """ 32 | Load opt from the config files, settings in later files can override those in previous files. 33 | 34 | Args: 35 | conf_files (list): a list of config file paths 36 | 37 | Returns: 38 | dict: a dictionary of opt settings 39 | """ 40 | opt = {} 41 | for conf_file in conf_files: 42 | with open(conf_file, encoding='utf-8') as f: 43 | config_dict = yaml.safe_load(f) 44 | 45 | load_config_dict_to_opt(opt, config_dict) 46 | 47 | return opt 48 | 49 | 50 | def load_opt_command(args): 51 | parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.') 52 | parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate') 53 | parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).') 54 | parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.') 55 | parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. 
Remember to escape " in command line.') 56 | parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER) 57 | 58 | cmdline_args = parser.parse_args() if not args else parser.parse_args(args) 59 | 60 | opt = load_opt_from_config_files(cmdline_args.conf_files) 61 | 62 | if cmdline_args.config_overrides: 63 | config_overrides_string = ' '.join(cmdline_args.config_overrides) 64 | logger.warning(f"Command line config overrides: {config_overrides_string}") 65 | config_dict = json.loads(config_overrides_string) 66 | load_config_dict_to_opt(opt, config_dict) 67 | 68 | if cmdline_args.overrides: 69 | assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments is not paired, required: key value" 70 | keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)] 71 | vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)] 72 | vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals] 73 | 74 | types = [] 75 | for key in keys: 76 | key = key.split('.') 77 | ele = opt.copy() 78 | while len(key) > 0: 79 | ele = ele[key.pop(0)] 80 | types.append(type(ele)) 81 | 82 | config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)} 83 | load_config_dict_to_opt(opt, config_dict) 84 | 85 | # combine cmdline_args into opt dictionary 86 | for key, val in cmdline_args.__dict__.items(): 87 | if val is not None: 88 | opt[key] = val 89 | 90 | return opt, cmdline_args -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | class Entity(object): 3 | def __init__(self, _id, _text, _mask, _interactive, _type, _start_idx, _end_idx, _image=None): 4 | self.id = _id 5 | self.text = _text 6 | self.mask = _mask 7 | self.interactive = _interactive 8 | self.type = _type 9 | self.start_idx = _start_idx 10 | self.end_idx = _end_idx 11 | 12 | self.image = _image 13 | 14 | def split_by_ordered_substrings(sentence, substrings): 15 | results = [] 16 | substring_indices = [] 17 | 18 | start_index = 0 19 | for i, substring in enumerate(substrings): 20 | # Find the start of the substring in the remaining part of the sentence 21 | index = sentence[start_index:].find(substring) 22 | 23 | if index == -1: 24 | continue 25 | 26 | # Append any text before the substring to the results, including spaces 27 | if index > 0: 28 | results.append(sentence[start_index:start_index+index]) 29 | substring_indices.append(None) # No match in the `substrings` list for this segment 30 | 31 | # Append the substring to the results 32 | results.append(substring) 33 | substring_indices.append(i) # Append the index from the `substrings` list 34 | start_index += index + len(substring) 35 | 36 | # If there's any remaining part of the sentence after all substrings, append it to the results 37 | if start_index < len(sentence): 38 | results.append(sentence[start_index:]) 39 | substring_indices.append(None) # No match in the `substrings` list for this segment 40 | 41 | return results, substring_indices 42 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import pickle 5 | import subprocess 6 | 7 | from mpi4py import MPI 8 | import torch.distributed as dist 9 | 10 | 11 | 
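# Usage sketch (illustrative, assumes the repo root is on PYTHONPATH so the
# utils package imports cleanly): what split_by_ordered_substrings in
# utils/dataset.py above returns. The sentence is cut into alternating filler
# and matched segments, and the parallel index list maps each segment back to
# its position in the substrings argument (None for filler text).
from utils.dataset import split_by_ordered_substrings

segments, indices = split_by_ordered_substrings('a dog chasing a red ball', ['dog', 'ball'])
assert segments == ['a ', 'dog', ' chasing a red ', 'ball']
assert indices == [None, 0, None, 1]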
def apply_distributed(opt): 12 | if opt['rank'] == 0: 13 | hostname_cmd = ["hostname -I"] 14 | result = subprocess.check_output(hostname_cmd, shell=True) 15 | master_address = result.decode('utf-8').split()[0] 16 | master_port = opt['PORT'] 17 | else: 18 | master_address = None 19 | master_port = None 20 | 21 | master_address = MPI.COMM_WORLD.bcast(master_address, root=0) 22 | master_port = MPI.COMM_WORLD.bcast(master_port, root=0) 23 | 24 | if torch.distributed.is_available() and opt['world_size'] > 1: 25 | init_method_url = 'tcp://{}:{}'.format(master_address, master_port) 26 | backend = 'nccl' 27 | world_size = opt['world_size'] 28 | rank = opt['rank'] 29 | torch.distributed.init_process_group(backend=backend, 30 | init_method=init_method_url, 31 | world_size=world_size, 32 | rank=rank) 33 | 34 | def init_distributed(opt): 35 | opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available() 36 | if 'OMPI_COMM_WORLD_SIZE' not in os.environ: 37 | # application was started without MPI 38 | # default to single node with single process 39 | opt['env_info'] = 'no MPI' 40 | opt['world_size'] = 1 41 | opt['local_size'] = 1 42 | opt['rank'] = 0 43 | opt['local_rank'] = 0 44 | opt['master_address'] = '127.0.0.1' 45 | opt['master_port'] = '8673' 46 | else: 47 | # application was started with MPI 48 | # get MPI parameters 49 | opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE']) 50 | opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) 51 | opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK']) 52 | opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 53 | 54 | # set up device 55 | if not opt['CUDA']: 56 | assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend' 57 | opt['device'] = torch.device("cpu") 58 | else: 59 | torch.cuda.set_device(opt['local_rank']) 60 | opt['device'] = torch.device("cuda", opt['local_rank']) 61 | 62 | apply_distributed(opt) 63 | return opt 64 | 65 | def is_main_process(): 66 | rank = 0 67 | if 'OMPI_COMM_WORLD_SIZE' in os.environ: 68 | rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 69 | 70 | return rank == 0 71 | 72 | def get_world_size(): 73 | if not dist.is_available(): 74 | return 1 75 | if not dist.is_initialized(): 76 | return 1 77 | return dist.get_world_size() 78 | 79 | def get_rank(): 80 | if not dist.is_available(): 81 | return 0 82 | if not dist.is_initialized(): 83 | return 0 84 | return dist.get_rank() 85 | 86 | 87 | def synchronize(): 88 | """ 89 | Helper function to synchronize (barrier) among all processes when 90 | using distributed training 91 | """ 92 | if not dist.is_available(): 93 | return 94 | if not dist.is_initialized(): 95 | return 96 | world_size = dist.get_world_size() 97 | rank = dist.get_rank() 98 | if world_size == 1: 99 | return 100 | 101 | def _send_and_wait(r): 102 | if rank == r: 103 | tensor = torch.tensor(0, device="cuda") 104 | else: 105 | tensor = torch.tensor(1, device="cuda") 106 | dist.broadcast(tensor, r) 107 | while tensor.item() == 1: 108 | time.sleep(1) 109 | 110 | _send_and_wait(0) 111 | # now sync on the main process 112 | _send_and_wait(1) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see 
LICENSE for details] 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu) 6 | # -------------------------------------------------------- 7 | import math 8 | 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value.""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1, decay=0): 23 | self.val = val 24 | if decay: 25 | alpha = math.exp(-n / decay) # exponential decay over 100 updates 26 | self.sum = alpha * self.sum + (1 - alpha) * val * n 27 | self.count = alpha * self.count + (1 - alpha) * n 28 | else: 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import pickle 5 | import torch 6 | import torch.nn as nn 7 | 8 | from utils.distributed import is_main_process 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | NORM_MODULES = [ 14 | torch.nn.BatchNorm1d, 15 | torch.nn.BatchNorm2d, 16 | torch.nn.BatchNorm3d, 17 | torch.nn.SyncBatchNorm, 18 | # NaiveSyncBatchNorm inherits from BatchNorm2d 19 | torch.nn.GroupNorm, 20 | torch.nn.InstanceNorm1d, 21 | torch.nn.InstanceNorm2d, 22 | torch.nn.InstanceNorm3d, 23 | torch.nn.LayerNorm, 24 | torch.nn.LocalResponseNorm, 25 | ] 26 | 27 | def register_norm_module(cls): 28 | NORM_MODULES.append(cls) 29 | return cls 30 | 31 | def align_and_update_state_dicts(model_state_dict, ckpt_state_dict): 32 | model_keys = sorted(model_state_dict.keys()) 33 | ckpt_keys = sorted(ckpt_state_dict.keys()) 34 | result_dicts = {} 35 | matched_log = [] 36 | unmatched_log = [] 37 | unloaded_log = [] 38 | for model_key in model_keys: 39 | model_weight = model_state_dict[model_key] 40 | if model_key in ckpt_keys: 41 | ckpt_weight = ckpt_state_dict[model_key] 42 | if model_weight.shape == ckpt_weight.shape: 43 | result_dicts[model_key] = ckpt_weight 44 | ckpt_keys.pop(ckpt_keys.index(model_key)) 45 | matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 46 | else: 47 | unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) 48 | else: 49 | unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape)) 50 | 51 | if is_main_process(): 52 | for info in matched_log: 53 | logger.info(info) 54 | for info in unloaded_log: 55 | logger.warning(info) 56 | for key in ckpt_keys: 57 | logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape)) 58 | for info in unmatched_log: 59 | logger.warning(info) 60 | return result_dicts -------------------------------------------------------------------------------- /utils/prompt_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_prompt_templates(): 5 | prompt_templates = [ 6 | '{}.', 7 | 'a photo of a {}.', 8 | 'a bad photo of a {}.', 9 | 'a photo of many {}.', 10 | 'a sculpture of a {}.', 11 | 'a photo of the hard to see {}.', 12 | 'a low resolution photo of the {}.', 13 | 'a rendering of a {}.', 14 | 'graffiti of a {}.', 15 | 'a bad photo of the {}.', 16 | 'a cropped photo of the {}.', 17 | 'a tattoo of a {}.', 18 | 'the 
embroidered {}.', 19 | 'a photo of a hard to see {}.', 20 | 'a bright photo of a {}.', 21 | 'a photo of a clean {}.', 22 | 'a photo of a dirty {}.', 23 | 'a dark photo of the {}.', 24 | 'a drawing of a {}.', 25 | 'a photo of my {}.', 26 | 'the plastic {}.', 27 | 'a photo of the cool {}.', 28 | 'a close-up photo of a {}.', 29 | 'a black and white photo of the {}.', 30 | 'a painting of the {}.', 31 | 'a painting of a {}.', 32 | 'a pixelated photo of the {}.', 33 | 'a sculpture of the {}.', 34 | 'a bright photo of the {}.', 35 | 'a cropped photo of a {}.', 36 | 'a plastic {}.', 37 | 'a photo of the dirty {}.', 38 | 'a jpeg corrupted photo of a {}.', 39 | 'a blurry photo of the {}.', 40 | 'a photo of the {}.', 41 | 'a good photo of the {}.', 42 | 'a rendering of the {}.', 43 | 'a {} in a video game.', 44 | 'a photo of one {}.', 45 | 'a doodle of a {}.', 46 | 'a close-up photo of the {}.', 47 | 'the origami {}.', 48 | 'the {} in a video game.', 49 | 'a sketch of a {}.', 50 | 'a doodle of the {}.', 51 | 'a origami {}.', 52 | 'a low resolution photo of a {}.', 53 | 'the toy {}.', 54 | 'a rendition of the {}.', 55 | 'a photo of the clean {}.', 56 | 'a photo of a large {}.', 57 | 'a rendition of a {}.', 58 | 'a photo of a nice {}.', 59 | 'a photo of a weird {}.', 60 | 'a blurry photo of a {}.', 61 | 'a cartoon {}.', 62 | 'art of a {}.', 63 | 'a sketch of the {}.', 64 | 'a embroidered {}.', 65 | 'a pixelated photo of a {}.', 66 | 'itap of the {}.', 67 | 'a jpeg corrupted photo of the {}.', 68 | 'a good photo of a {}.', 69 | 'a plushie {}.', 70 | 'a photo of the nice {}.', 71 | 'a photo of the small {}.', 72 | 'a photo of the weird {}.', 73 | 'the cartoon {}.', 74 | 'art of the {}.', 75 | 'a drawing of the {}.', 76 | 'a photo of the large {}.', 77 | 'a black and white photo of a {}.', 78 | 'the plushie {}.', 79 | 'a dark photo of a {}.', 80 | 'itap of a {}.', 81 | 'graffiti of the {}.', 82 | 'a toy {}.', 83 | 'itap of my {}.', 84 | 'a photo of a cool {}.', 85 | 'a photo of a small {}.', 86 | 'a tattoo of the {}.', 87 | ] 88 | return prompt_templates 89 | 90 | def prompt_engineering(classnames, topk=1, suffix='.'): 91 | prompt_templates = get_prompt_templates() 92 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 93 | 94 | if isinstance(classnames, list): 95 | classname = np.random.choice(classnames) 96 | else: 97 | classname = classnames 98 | 99 | return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' ')) --------------------------------------------------------------------------------
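# Usage sketch (illustrative, assumes the repo root is on PYTHONPATH so the
# utils package imports cleanly): building an open-vocabulary text query from
# the templates above. With topk=1 the identity template '{}.' is always used,
# a larger topk samples uniformly among the first topk templates, and passing a
# list of synonyms instead of a single name picks one synonym at random.
from utils.prompt_engineering import get_prompt_templates, prompt_engineering

templates = get_prompt_templates()
assert templates[0] == '{}.'
print(prompt_engineering('traffic light', topk=1, suffix='.'))   # -> 'traffic light.'
print(prompt_engineering('zebra', topk=3, suffix=','))           # e.g. 'a bad photo of a zebra,'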