├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── __init__.py
├── assets
│   ├── readmes
│   │   ├── DATASET.md
│   │   ├── EVAL.md
│   │   ├── INFERENCE.md
│   │   ├── INSTALL.md
│   │   └── TRAIN.md
│   ├── requirements
│   │   ├── requirements.txt
│   │   └── requirements_custom.txt
│   └── scripts
│       └── run_demo.sh
├── configs
│   ├── seem
│   │   ├── davitd3_unicl_lang_v1.yaml
│   │   ├── davitd5_unicl_lang_v1.yaml
│   │   ├── focall_unicl_lang_demo.yaml
│   │   ├── focall_unicl_lang_v0.yaml
│   │   ├── focall_unicl_lang_v1.yaml
│   │   ├── focalt_unicl_lang_demo.yaml
│   │   ├── focalt_unicl_lang_v0.yaml
│   │   ├── focalt_unicl_lang_v1.yaml
│   │   ├── samvitb_unicl_lang_v1.yaml
│   │   └── samvitl_unicl_lang_v1.yaml
│   └── xdecoder
│       ├── davitd3_unicl_lang.yaml
│       ├── davitd5_unicl_lang.yaml
│       ├── focall_unicl_lang.yaml
│       ├── focalt_unicl_lang.yaml
│       ├── xdecoder_focall_lang.yaml
│       └── xdecoder_focalt_lang.yaml
├── datasets
│   ├── __init__.py
│   ├── build.py
│   ├── dataset_mappers
│   │   ├── __init__.py
│   │   ├── bdd_semseg_dataset_mapper.py
│   │   ├── coco_instance_new_baseline_dataset_mapper.py
│   │   ├── coco_panoptic_interactive_dataset_mapper.py
│   │   ├── coco_panoptic_new_baseline_dataset_mapper.py
│   │   ├── imagenet_dataset_mapper.py
│   │   ├── mask_former_instance_dataset_mapper.py
│   │   ├── mask_former_panoptic_dataset_mapper.py
│   │   ├── mask_former_semantic_dataset_mapper.py
│   │   ├── pascalvoc_dataset_mapper_ix.py
│   │   ├── refcoco_dataset_mapper.py
│   │   ├── scannet_dataset_mapper.py
│   │   ├── scannet_pano_dataset_mapper.py
│   │   ├── sunrgbd_dataset_mapper.py
│   │   └── vlp_dataset_mapper.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── captioning_evaluation.py
│   │   ├── classification_evaluation.py
│   │   ├── grounding_evaluation.py
│   │   ├── instance_evaluation.py
│   │   ├── interactive_evaluation.py
│   │   ├── panoptic_evaluation.py
│   │   ├── retrieval_evaluation.py
│   │   └── segmentation_evaluation.py
│   ├── refer.py
│   ├── registration
│   │   ├── __init__.py
│   │   ├── register_ade20k_full.py
│   │   ├── register_ade20k_instance.py
│   │   ├── register_ade20k_panoptic.py
│   │   ├── register_bdd100k_panoseg.py
│   │   ├── register_bdd100k_semseg.py
│   │   ├── register_coco_lvis_panoptic_annos_caption_grounding.py
│   │   ├── register_coco_panoptic_annos_caption.py
│   │   ├── register_coco_panoptic_annos_caption_grounding.py
│   │   ├── register_coco_panoptic_annos_semseg.py
│   │   ├── register_coco_stuff_10k.py
│   │   ├── register_imagenet_cls.py
│   │   ├── register_pascalvoc_eval.py
│   │   ├── register_refcoco_dataset.py
│   │   ├── register_scannet_panoptic.py
│   │   ├── register_scannet_semseg.py
│   │   ├── register_sunrgbd_semseg.py
│   │   └── register_vlp_datasets.py
│   ├── semseg_loader.py
│   ├── utils
│   │   ├── refcoco2json.py
│   │   └── refer.py
│   └── visual_sampler
│       ├── __init__.py
│       ├── circle.py
│       ├── mask_generators.py
│       ├── point.py
│       ├── polygon.py
│       ├── sampler.py
│       ├── scribble.py
│       └── simpleclick_sampler.py
├── demo
│   ├── __init__.py
│   └── seem
│       ├── __init__.py
│       ├── app.py
│       ├── examples
│       │   ├── corgi1.webp
│       │   ├── corgi2.jpg
│       │   ├── fries1.png
│       │   ├── fries2.png
│       │   ├── minecraft1.jpg
│       │   ├── placeholder.png
│       │   ├── ref_vase.JPG
│       │   ├── river1.png
│       │   ├── river1.wav
│       │   ├── river1_mask.png
│       │   ├── river2.png
│       │   ├── vasedeck.mp4
│       │   ├── zebras1.jpg
│       │   └── zebras2.jpg
│       └── tasks
│           ├── __init__.py
│           └── interactive.py
├── entry.py
├── inference
│   ├── __init__.py
│   ├── images
│   │   ├── animals.png
│   │   ├── apples.jpg
│   │   ├── coco
│   │   │   ├── 000.jpg
│   │   │   ├── 001.jpg
│   │   │   ├── 002.jpg
│   │   │   └── 003.jpg
│   │   ├── fruit.jpg
│   │   ├── landscape.jpg
│   │   ├── mountain.jpeg
│   │   ├── owls.jpeg
│   │   ├── penguin.jpeg
│   │   ├── region_retrieval.png
│   │   ├── rose.webp
│   │   ├── street.jpg
│   │   └── teaser_new.png
│   └── xdecoder
│       ├── infer_captioning.py
│       ├── infer_instseg.py
│       ├── infer_panoseg.py
│       ├── infer_refseg.py
│       ├── infer_region_retrieval.py
│       └── infer_semseg.py
├── modeling
│   ├── BaseModel.py
│   ├── __init__.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── registry.py
│   │   ├── seem_model_demo.py
│   │   ├── seem_model_v0.py
│   │   ├── seem_model_v1.py
│   │   └── xdecoder_model.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── build.py
│   │   ├── focal.py
│   │   ├── focal_dw.py
│   │   └── registry.py
│   ├── body
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── decoder
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   ├── modules.py
│   │   │   ├── registry.py
│   │   │   └── xdecoder.py
│   │   ├── encoder
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   ├── registry.py
│   │   │   └── transformer_encoder_fpn.py
│   │   ├── registry.py
│   │   ├── transformer_blocks.py
│   │   └── xdecoder_head.py
│   ├── interface
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── modules.py
│   │   ├── prototype
│   │   │   ├── __init__.py
│   │   │   ├── attention_data_struct_seemdemo.py
│   │   │   ├── attention_data_struct_seemv0.py
│   │   │   └── attention_data_struct_seemv1.py
│   │   ├── seem_demo.py
│   │   ├── seem_v0.py
│   │   ├── seem_v1.py
│   │   └── xdecoder.py
│   ├── language
│   │   ├── LangEncoder
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   ├── registry.py
│   │   │   └── transformer.py
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── loss.py
│   │   ├── misc.py
│   │   ├── registry.py
│   │   └── vlpencoder.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── criterion.py
│   │   ├── matcher.py
│   │   ├── point_features.py
│   │   ├── position_encoding.py
│   │   └── postprocessing.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── box_ops.py
│   │   ├── config.py
│   │   ├── interactive.py
│   │   └── misc.py
│   └── vision
│       ├── backbone
│       │   ├── __init__.py
│       │   ├── backbone.py
│       │   ├── build.py
│       │   ├── common.py
│       │   ├── davit.py
│       │   ├── focal.py
│       │   ├── focal_dw.py
│       │   └── vit.py
│       └── encoder
│           ├── __init__.py
│           ├── build.py
│           ├── ops
│           │   ├── functions
│           │   │   ├── __init__.py
│           │   │   └── ms_deform_attn_func.py
│           │   ├── make.sh
│           │   ├── modules
│           │   │   ├── __init__.py
│           │   │   └── ms_deform_attn.py
│           │   ├── setup.py
│           │   ├── src
│           │   │   ├── cpu
│           │   │   │   ├── ms_deform_attn_cpu.cpp
│           │   │   │   └── ms_deform_attn_cpu.h
│           │   │   ├── cuda
│           │   │   │   ├── ms_deform_attn_cuda.cu
│           │   │   │   ├── ms_deform_attn_cuda.h
│           │   │   │   └── ms_deform_im2col_cuda.cuh
│           │   │   ├── ms_deform_attn.h
│           │   │   └── vision.cpp
│           │   └── test.py
│           ├── transformer_blocks.py
│           ├── transformer_encoder_deform.py
│           └── transformer_encoder_fpn.py
├── pipeline
│   ├── XDecoderPipeline.py
│   ├── __init__.py
│   └── utils
│       └── misc.py
├── trainer
│   ├── __init__.py
│   ├── default_trainer.py
│   ├── distributed_trainer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── hook.py
│   │   ├── misc.py
│   │   ├── mpi_adapter.py
│   │   └── serialization.py
│   ├── utils_trainer.py
│   └── xdecoder_trainer.py
└── utils
    ├── Config.py
    ├── __init__.py
    ├── arguments.py
    ├── constants.py
    ├── dataset.py
    ├── distributed.py
    ├── misc.py
    ├── model.py
    ├── prompt_engineering.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # IntelliJ project files
2 | .idea
3 | *.iml
4 | out
5 | gen
6 |
7 | ### Vim template
8 | [._]*.s[a-w][a-z]
9 | [._]s[a-w][a-z]
10 | *.un~
11 | Session.vim
12 | .netrwhist
13 | *~
14 |
15 | ### IPythonNotebook template
16 | # Temporary data
17 | .ipynb_checkpoints/
18 |
19 | ### Python template
20 | # Byte-compiled / optimized / DLL files
21 | __pycache__/
22 | *.py[cod]
23 | *$py.class
24 |
25 | # C extensions
26 | *.so
27 |
28 | # Distribution / packaging
29 | .Python
30 | env/
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | #lib/
38 | #lib64/
39 | parts/
40 | sdist/
41 | var/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 |
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *,cover
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | *.ipynb
80 | *.params
81 | # *.json
82 | .vscode/
83 | *.code-workspace/
84 |
85 | lib/pycocotools/_mask.c
86 | lib/nms/cpu_nms.c
87 |
88 | OUTPUT
89 | OUTPUT/*
90 | models/*
91 | DATASET
92 | DATASET/*
93 | external/
94 | MODELS
95 | MODELS/*
96 |
97 | kill.sh
98 | train.sh
99 |
100 | draws/
101 | plot/
102 |
103 | *run.sh
104 | exps/*
105 | amlt/*
106 |
107 | *venv/*
108 | *.pt
109 | *.pth
110 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/__init__.py
--------------------------------------------------------------------------------
/assets/readmes/DATASET.md:
--------------------------------------------------------------------------------
1 | # Preparing Dataset
2 |
3 | :bangbang: Dataset preparation involves many details; community contributions to fix any bugs are welcome. Thanks!
4 |
5 | Our dataloader follows [Detectron2](https://github.com/facebookresearch/detectron2) and contains:
6 | (1) [A dataset registrar](datasets/registration)
7 | (2) [A dataset mapper](datasets/dataset_mappers)
8 | We modify the dataset registration and mappers to support custom datasets.
9 |
10 | ## Training Dataset
11 | We assume all the datasets are stored under:
12 | ```
13 | .xdecoder_data
14 | ```
15 |
16 | ### COCO (SEEM & X-Decoder)
17 |
18 | ```sh
19 | # Prepare panoptic_train2017 and panoptic_semseg_train2017 exactly as in [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/main/datasets)
20 |
21 | # (SEEM & X-Decoder) Download additional linguistic and custom annotation files to .xdecoder_data/coco/annotations
22 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/caption_class_similarity.pth
23 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/captions_train2017_filtrefgumdval_filtvlp.json
24 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/grounding_train2017_filtrefgumdval_filtvlp.json
25 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/panoptic_train2017_filtrefgumdval_filtvlp.json
26 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/refcocog_umd_val.json
27 | wget https://github.com/peteanderson80/coco-caption/blob/master/annotations/captions_val2014.json
28 |
29 | # (SEEM) Download LVIS annotations for mask preparation
30 | wget https://huggingface.co/xdecoder/SEEM/resolve/main/coco_train2017_filtrefgumdval_lvis.json
31 | ```
32 |
33 | After dataset preparation, the dataset structure would be:
34 | ```
35 | .xdecoder_data
36 | └── coco/
37 | ├── train2017/
38 | ├── val2017/
39 | ├── panoptic_train2017/
40 | ├── panoptic_semseg_train2017/
41 | ├── panoptic_val2017/
42 | ├── panoptic_semseg_val2017/
43 | └── annotations/
44 | ├── refcocog_umd_val.json
45 | ├── captions_val2014.json
46 | ├── panoptic_val2017.json
47 | ├── caption_class_similarity.pth
48 | ├── panoptic_train2017_filtrefgumdval_filtvlp.json
49 | └── grounding_train2017_filtrefgumdval_filtvlp.json
50 | └── lvis/
51 | └── coco_train2017_filtrefgumdval_lvis.json
52 | ```
53 |
54 | #### 4M Image Text Pairs (X-Decoder)
55 | We follow exactly the image-text pair data preparation of [ViLT](https://github.com/dandelin/ViLT/blob/master/DATA.md).
56 | ```
57 | # The pretraining arrow files are put under .xdecoder_data/pretrain_arrows_code224 with the following list of files.
58 | ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] + ["code224_vg.arrow"] + [f"code224_sbu_{i}.arrow" for i in range(9)] + [f"code224_conceptual_caption_train_{i}.arrow" for i in range(31)]
59 | # ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] are originated from ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] with deletion of coco val2017 overlapped images to avoid information leakage.
60 | ```
61 |
62 | To get started quickly:
63 | ```sh
64 | # Download the COCO Karpathy test set (for a quick start, the codebase hacks the training data to be coco_caption_karpathy_test.arrow only)
65 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/coco_caption_karpathy_test.arrow
66 | ```
67 |
68 | After dataset preparation, the dataset structure would be:
69 | ```
70 | .xdecoder_data
71 | └── pretrain_arrows_code224/
72 | ├── coco_caption_karpathy_test.arrow
73 | ├── *filtcoco2017val_caption_karpathy_train.arrow
74 | ├── ...
75 | ├── *code224_vg.arrow
76 | ├── *code224_sbu_0.arrow
77 | ├── ...
78 | ├── *code224_conceptual_caption_train_0.arrow
79 | └── ...
80 | * Files marked with * are optional when debugging the pipeline, but MUST be added back when training the model.
81 | ```
82 |
83 | ***NOTE:***
84 |
85 |
86 | There is overlap between the COCO2017, COCO-Karpathy, and RefCOCO datasets, and RefCOCO is entirely contained in the COCO2017 training data. We therefore exclude the refcocog-umd validation and COCO-Karpathy test splits during training.
87 |
88 | ## Evaluation Dataset
89 |
90 | ### RefCOCO (SEEM & X-Decoder)
91 | Please refer to the COCO preparation section at this [link](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/blob/v1.0/assets/readmes/DATASET.md#coco-seem--x-decoder).
92 |
93 | ### ADE20K, Cityscapes (X-Decoder)
94 | Please refer to [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/main/datasets).
95 |
96 | ### BDD100K (X-Decoder)
97 | Please download the 10k split of BDD100k at https://doc.bdd100k.com/download.html#id1
98 |
99 | ### PascalVOC and all other interactive evaluation datasets (SEEM)
100 | Please follow the instructions in [RITM](https://github.com/SamsungLabs/ritm_interactive_segmentation).
101 |
102 | After dataset preparation, the dataset structure would be:
103 | ```
104 | .xdecoder_data
105 | └── PascalVOC/
106 | ├── Annotations/
107 | ├── ImageSets
108 | ├── JPEGImages/
109 | ├── SegmentationClass/
110 | └── SegmentationObject/
111 | ```
112 |
113 |
--------------------------------------------------------------------------------
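
The directory layout above is easy to get subtly wrong. Below is a minimal, hypothetical sanity-check script (not part of the repository) that verifies the COCO portion of the `.xdecoder_data` layout described in DATASET.md; the path list only covers the entries shown above.

```python
# sanity_check_coco_layout.py -- hypothetical helper, not part of the repository.
# Checks that the COCO folders/files described in assets/readmes/DATASET.md exist
# under the dataset root (default: .xdecoder_data, or $DETECTRON2_DATASETS if set).
import os
import sys

def check_coco_layout(root: str = ".xdecoder_data") -> bool:
    expected = [
        "coco/train2017",
        "coco/val2017",
        "coco/panoptic_train2017",
        "coco/panoptic_semseg_train2017",
        "coco/panoptic_val2017",
        "coco/panoptic_semseg_val2017",
        "coco/annotations/refcocog_umd_val.json",
        "coco/annotations/captions_val2014.json",
        "coco/annotations/panoptic_val2017.json",
        "coco/annotations/caption_class_similarity.pth",
        "coco/annotations/panoptic_train2017_filtrefgumdval_filtvlp.json",
        "coco/annotations/grounding_train2017_filtrefgumdval_filtvlp.json",
        "lvis/coco_train2017_filtrefgumdval_lvis.json",
    ]
    missing = [p for p in expected if not os.path.exists(os.path.join(root, p))]
    for p in missing:
        print(f"[missing] {os.path.join(root, p)}")
    return not missing

if __name__ == "__main__":
    root = os.environ.get("DETECTRON2_DATASETS", ".xdecoder_data")
    sys.exit(0 if check_coco_layout(root) else 1)
```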
/assets/readmes/INFERENCE.md:
--------------------------------------------------------------------------------
1 | ### Demo
2 |
3 | OpenVocab Semantic Segmentation
4 | ```sh
5 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_semseg.py evaluate \
6 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \
7 | --overrides \
8 | RESUME_FROM /pth/to/xdecoder_focall_best_openseg.pt
9 | ```
10 |
11 | OpenVocab Instance Segmentation
12 | ```sh
13 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_instseg.py evaluate \
14 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \
15 | --overrides \
16 | RESUME_FROM /pth/to/xdecoder_focall_best_openseg.pt
17 | ```
18 |
19 | OpenVocab Panoptic Segmentation
20 | ```sh
21 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_panoseg.py evaluate \
22 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \
23 | --overrides \
24 | RESUME_FROM /pth/to/xdecoder_focall_best_openseg.pt
25 | ```
26 |
27 | OpenVocab Referring Segmentation
28 | ```sh
29 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_refseg.py evaluate \
30 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \
31 | --overrides \
32 | RESUME_FROM /pth/to/xdecoder_focall_last.pt
33 | ```
34 |
35 | Region Retrieval
36 | ```sh
37 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_region_retrieval.py evaluate \
38 | --conf_files configs/xdecoder/xdecoder_focall_lang.yaml \
39 | --overrides \
40 | RESUME_FROM /pth/to/xdecoder_focall_last.pt
41 | ```
42 |
43 | Image Captioning
44 | ```sh
45 | CUDA_VISIBLE_DEVICES=0 python inference/xdecoder/infer_captioning.py evaluate \
46 | --conf_files configs/xdecoder/xdecoder_focalt_lang.yaml \
47 | --overrides \
48 | RESUME_FROM /pth/to/xdecoder_focalt_last_novg.pt
49 | ```
--------------------------------------------------------------------------------
/assets/readmes/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation Guide
2 |
3 | **General Environment**
4 | * Linux System
5 | * CUDA enabled GPU with Memory > 8GB (Evaluation)
6 | * CUDA enabled GPU with Memory > 12GB (Training)
7 |
8 | **Installation**
9 |
10 | ```sh
11 | # Python Package Installation
12 | pip install -r assets/requirements/requirements.txt
13 | pip install -r assets/requirements/requirements_custom.txt
14 |
15 | # Custom Operator [only needed when training the deformable vision encoder]
16 | cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../
17 |
18 | # System Package [only needed for the SEEM demo]
19 | sudo apt update
20 | sudo apt install ffmpeg
21 | ```
22 |
23 | **Dataset Preparation**
24 |
25 | Please refer to [DATASET.md](assets/readmes/DATASET.md).
26 |
27 | **Evaluation Tool**
28 | ```sh
29 | # save coco_caption.zip to .xdecoder_data
30 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/coco_caption.zip
31 | unzip coco_caption.zip
32 | ```
33 |
34 | **Environment Variables**
35 | ```sh
36 | export DETECTRON2_DATASETS=/pth/to/xdecoder_data
37 | export DATASET=/pth/to/xdecoder_data
38 | export DATASET2=/pth/to/xdecoder_data
39 | export VLDATASET=/pth/to/xdecoder_data
40 | export PATH=$PATH:/pth/to/xdecoder_data/coco_caption/jre1.8.0_321/bin
41 | export PYTHONPATH=$PYTHONPATH:/pth/to/xdecoder_data/coco_caption
42 | ```
43 |
44 | **Pretrained Checkpoint**
45 |
46 | X-Decoder:
47 | ```sh
48 | # Focal-T UniCL
49 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focalt_in21k_yfcc_gcc_xdecoder_unicl.pt
50 |
51 | # Focal-L UniCL
52 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focall_vision_focalb_lang_unicl.pt
53 | ```
54 |
55 | SEEM:
56 | ```sh
57 | # Focal-T X-Decoder
58 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_last.pt
59 |
60 | # Focal-L X-Decoder
61 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focall_last_oq101.pt
62 |
63 | # Focal-B UniCL Language
64 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focalb_lang_unicl.pt
65 |
66 | # ViT-B SAM
67 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
68 |
69 | # ViT-L SAM
70 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
71 |
72 | ```
73 |
74 |
--------------------------------------------------------------------------------
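
Because several environment variables must all point at the dataset root, a small pre-flight check can catch a misconfigured shell before a long run fails. This is a hypothetical helper (not part of the repository) that mirrors the variable names listed in INSTALL.md.

```python
# check_env.py -- hypothetical pre-flight check, not part of the repository.
# Verifies that the environment variables listed in INSTALL.md are set and
# point to existing directories before running entry.py.
import os

REQUIRED_VARS = ["DETECTRON2_DATASETS", "DATASET", "DATASET2", "VLDATASET"]

def check_env() -> bool:
    ok = True
    for name in REQUIRED_VARS:
        value = os.environ.get(name)
        if not value:
            print(f"[unset]   {name}")
            ok = False
        elif not os.path.isdir(value):
            print(f"[invalid] {name}={value} (not a directory)")
            ok = False
        else:
            print(f"[ok]      {name}={value}")
    # The coco_caption evaluation tool also needs to be on PYTHONPATH (see INSTALL.md).
    if "coco_caption" not in os.environ.get("PYTHONPATH", ""):
        print("[warn]    PYTHONPATH does not mention coco_caption; captioning eval may fail")
    return ok

if __name__ == "__main__":
    check_env()
```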
/assets/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.0
2 | torchvision==0.16.0
3 | pillow==9.4.0
4 | opencv-python==4.8.1.78
5 | pyyaml==6.0.1
6 | json_tricks==3.17.3
7 | yacs==0.1.8
8 | scikit-learn==1.3.1
9 | pandas==2.0.3
10 | timm==0.4.12
11 | numpy==1.23.1
12 | einops==0.7.0
13 | fvcore==0.1.5.post20221221
14 | transformers==4.34.0
15 | sentencepiece==0.1.99
16 | ftfy==6.1.1
17 | regex==2023.10.3
18 | nltk==3.8.1
19 | mpi4py==3.1.5
20 | vision-datasets==0.2.2
21 | cython==3.0.2
22 | pycocotools==2.0.7
23 | diffdist==0.1
24 | pyarrow==13.0.0
25 | cityscapesscripts==2.2.2
26 | shapely==1.8.0
27 | scikit-image==0.21.0
28 | mup==1.0.0
29 | accelerate==0.23.0
30 | kornia==0.7.0
31 | deepspeed==0.10.3
32 | wandb==0.15.12
33 | infinibatch==0.1.1
34 | gradio==3.42.0
--------------------------------------------------------------------------------
/assets/requirements/requirements_custom.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/arogozhnikov/einops.git
2 | git+https://github.com/MaureenZOU/detectron2-xyz.git
3 | git+https://github.com/openai/whisper.git
--------------------------------------------------------------------------------
/assets/scripts/run_demo.sh:
--------------------------------------------------------------------------------
1 | sudo apt update
2 | sudo apt install ffmpeg
3 | pip install -r assets/requirements/requirements.txt
4 | pip install -r assets/requirements/requirements_custom.txt
5 | python demo/seem/app.py
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import registration
2 | from .build import build_train_dataloader, build_eval_dataloader, build_evaluator
--------------------------------------------------------------------------------
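
For orientation, the three builders exported here are the entry points into the data pipeline. The sketch below shows roughly how they fit together; the exact signatures live in datasets/build.py and entry.py, so the argument lists shown are assumptions.

```python
# Hypothetical wiring of the exported builders (illustrative only; the argument
# lists are assumptions -- see datasets/build.py and entry.py for the real ones).
from datasets import build_train_dataloader, build_eval_dataloader, build_evaluator

def build_data(opt):
    """opt: the config dict assembled by the trainer from a YAML file in configs/."""
    train_loader = build_train_dataloader(opt)   # loader(s) over the training datasets
    eval_loaders = build_eval_dataloader(opt)    # one loader per evaluation dataset
    # An evaluator is typically built per evaluation dataset, e.g.:
    # evaluator = build_evaluator(opt, dataset_name, output_folder)
    return train_loader, eval_loaders
```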
/datasets/dataset_mappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .coco_panoptic_interactive_dataset_mapper import COCOPanopticInteractiveDatasetMapper
2 | from .coco_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper
3 | from .coco_panoptic_new_baseline_dataset_mapper import COCOPanopticNewBaselineDatasetMapper
4 | from .mask_former_instance_dataset_mapper import MaskFormerInstanceDatasetMapper
5 | from .mask_former_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapper
6 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
7 | from .imagenet_dataset_mapper import ImageNetDatasetMapper
8 | from .vlp_dataset_mapper import VLPreDatasetMapper
9 | from .sunrgbd_dataset_mapper import SunRGBDSegDatasetMapper
10 | from .scannet_dataset_mapper import ScanNetSegDatasetMapper
11 | from .bdd_semseg_dataset_mapper import BDDSemDatasetMapper
12 | from .scannet_pano_dataset_mapper import ScanNetPanoDatasetMapper
13 | from .refcoco_dataset_mapper import RefCOCODatasetMapper
14 | from .pascalvoc_dataset_mapper_ix import PascalVOCSegDatasetMapperIX
--------------------------------------------------------------------------------
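
All of these mappers follow the same Detectron2-style pattern: a `from_config` classmethod pulls the needed keys out of the config, and the resulting object is called on one dataset dict at a time. A minimal, hypothetical sketch of that pattern using `BDDSemDatasetMapper` (defined below); the config values and file paths are placeholders.

```python
# Hypothetical sketch of the dataset-mapper call pattern (illustrative only;
# config values and paths below are placeholders).
from datasets.dataset_mappers import BDDSemDatasetMapper

# Minimal config dict with only the keys read by BDDSemDatasetMapper.from_config.
cfg = {
    "INPUT": {
        "MIN_SIZE_TEST": 512,
        "MAX_SIZE_TEST": 1024,
        "PIXEL_MEAN": [123.675, 116.280, 103.530],
        "PIXEL_STD": [58.395, 57.120, 57.375],
    }
}

# Build the mapper from the config, then apply it to one Detectron2-style dataset dict.
kwargs = BDDSemDatasetMapper.from_config(cfg, is_train=False)
mapper = BDDSemDatasetMapper(**kwargs)

example = {
    "file_name": "/path/to/image.jpg",            # placeholder
    "sem_seg_file_name": "/path/to/labels.png",   # placeholder
}
out = mapper(example)  # adds 'image' (C,H,W uint8 tensor) and 'semseg' (H,W int tensor)
```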
/datasets/dataset_mappers/bdd_semseg_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates.
8 | import copy
9 |
10 | import scipy.io
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 |
15 | from torchvision import transforms
16 | from modeling.utils import configurable
17 |
18 | __all__ = ["BDDSemDatasetMapper"]
19 |
20 |
21 | # This is specifically designed for the BDD100K semantic segmentation dataset.
22 | class BDDSemDatasetMapper:
23 | """
24 | A callable which takes a dataset dict in Detectron2 Dataset format,
25 | and map it into a format used by MaskFormer.
26 |
27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
28 |
29 | The callable currently does the following:
30 |
31 | 1. Read the image from "file_name"
32 | 2. Applies geometric transforms to the image and annotation
33 | 3. Find and applies suitable cropping to the image and annotation
34 | 4. Prepare image and annotation to Tensors
35 | """
36 |
37 | @configurable
38 | def __init__(
39 | self,
40 | is_train=True,
41 | min_size_test=None,
42 | max_size_test=None,
43 | mean=None,
44 | std=None,
45 | ):
46 | """
47 | NOTE: this interface is experimental.
48 | Args:
49 | is_train: for training or inference
50 | augmentations: a list of augmentations or deterministic transforms to apply
51 | tfm_gens: data augmentation
52 | image_format: an image format supported by :func:`detection_utils.read_image`.
53 | """
54 | self.is_train = is_train
55 | self.min_size_test = min_size_test
56 | self.max_size_test = max_size_test
57 | self.pixel_mean = torch.tensor(mean)[:,None,None]
58 | self.pixel_std = torch.tensor(std)[:,None,None]
59 |
60 | t = []
61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
62 | self.transform = transforms.Compose(t)
63 |
64 | @classmethod
65 | def from_config(cls, cfg, is_train=True):
66 | ret = {
67 | "is_train": is_train,
68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
70 | "mean": cfg['INPUT']['PIXEL_MEAN'],
71 | "std": cfg['INPUT']['PIXEL_STD'],
72 | }
73 | return ret
74 |
75 | def read_semseg(self, file_name):
76 | if '.png' in file_name:
77 | semseg = np.asarray(Image.open(file_name))
78 | elif '.mat' in file_name:
79 | semseg = scipy.io.loadmat(file_name)['LabelMap']
80 | return semseg
81 |
82 | def __call__(self, dataset_dict):
83 | """
84 | Args:
85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
86 |
87 | Returns:
88 | dict: a format that builtin models in detectron2 accept
89 | """
90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
91 | file_name = dataset_dict['file_name']
92 | semseg_name = dataset_dict['sem_seg_file_name']
93 | image = Image.open(file_name).convert('RGB')
94 |
95 | dataset_dict['width'] = image.size[0]
96 | dataset_dict['height'] = image.size[1]
97 |
98 | if self.is_train == False:
99 | image = self.transform(image)
100 | image = torch.from_numpy(np.asarray(image).copy())
101 | image = image.permute(2,0,1)
102 |
103 | semseg = self.read_semseg(semseg_name)
104 | semseg = torch.from_numpy(semseg.astype(np.int32))
105 | dataset_dict['image'] = image
106 | dataset_dict['semseg'] = semseg
107 | return dataset_dict
--------------------------------------------------------------------------------
/datasets/dataset_mappers/imagenet_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates.
8 | import copy
9 | from PIL import Image
10 | # import logging
11 |
12 | import cv2
13 | import numpy as np
14 |
15 | import torch
16 | from torchvision import transforms
17 |
18 | from modeling.utils import configurable
19 |
20 | __all__ = ["ImageNetDatasetMapper"]
21 |
22 |
23 | # This is specifically designed for the ImageNet classification dataset.
24 | class ImageNetDatasetMapper:
25 | """
26 | A callable which takes a dataset dict in Detectron2 Dataset format,
27 | and map it into a format used by MaskFormer.
28 |
29 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
30 |
31 | The callable currently does the following:
32 |
33 | 1. Read the image from "file_name"
34 | 2. Applies geometric transforms to the image and annotation
35 | 3. Find and applies suitable cropping to the image and annotation
36 | 4. Prepare image and annotation to Tensors
37 | """
38 |
39 | @configurable
40 | def __init__(
41 | self,
42 | is_train=True,
43 | size_train=None,
44 | size_test=None,
45 | size_crop=None,
46 | ):
47 | """
48 | NOTE: this interface is experimental.
49 | Args:
50 | is_train: for training or inference
51 | augmentations: a list of augmentations or deterministic transforms to apply
52 | tfm_gens: data augmentation
53 | image_format: an image format supported by :func:`detection_utils.read_image`.
54 | """
55 | self.is_train = is_train
56 | self.size_train = size_train
57 | self.size_test = size_test
58 | self.size_crop = size_crop
59 |
60 | t = []
61 | t.append(transforms.Resize(size_crop, interpolation=Image.BICUBIC))
62 | t.append(transforms.CenterCrop(size_test))
63 | self.transform = transforms.Compose(t)
64 |
65 | @classmethod
66 | def from_config(cls, cfg, is_train=True):
67 | ret = {
68 | "is_train": is_train,
69 | "size_train": cfg['INPUT']['SIZE_TRAIN'],
70 | "size_test": cfg['INPUT']['SIZE_TEST'],
71 | "size_crop": cfg['INPUT']['SIZE_CROP']
72 | }
73 | return ret
74 |
75 | def __call__(self, dataset_dict):
76 | """
77 | Args:
78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
79 |
80 | Returns:
81 | dict: a format that builtin models in detectron2 accept
82 | """
83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
84 | file_name = dataset_dict['file_name']
85 | image = Image.open(file_name).convert('RGB')
86 |
87 | if self.is_train == False:
88 | image = self.transform(image)
89 | image = torch.from_numpy(np.asarray(image).copy())
90 | image = image.permute(2,0,1)
91 |
92 | dataset_dict['image'] = image
93 | dataset_dict['height'] = image.shape[1]
94 | dataset_dict['width'] = image.shape[2]
95 | return dataset_dict
--------------------------------------------------------------------------------
/datasets/dataset_mappers/scannet_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates.
8 | import copy
9 |
10 | import scipy.io
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 |
15 | from torchvision import transforms
16 | from modeling.utils import configurable
17 |
18 | __all__ = ["ScanNetSegDatasetMapper"]
19 |
20 |
21 | # This is specifically designed for the ScanNet semantic segmentation dataset.
22 | class ScanNetSegDatasetMapper:
23 | """
24 | A callable which takes a dataset dict in Detectron2 Dataset format,
25 | and map it into a format used by MaskFormer.
26 |
27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
28 |
29 | The callable currently does the following:
30 |
31 | 1. Read the image from "file_name"
32 | 2. Applies geometric transforms to the image and annotation
33 | 3. Find and applies suitable cropping to the image and annotation
34 | 4. Prepare image and annotation to Tensors
35 | """
36 |
37 | @configurable
38 | def __init__(
39 | self,
40 | is_train=True,
41 | min_size_test=None,
42 | max_size_test=None,
43 | mean=None,
44 | std=None,
45 | ):
46 | """
47 | NOTE: this interface is experimental.
48 | Args:
49 | is_train: for training or inference
50 | augmentations: a list of augmentations or deterministic transforms to apply
51 | tfm_gens: data augmentation
52 | image_format: an image format supported by :func:`detection_utils.read_image`.
53 | """
54 | self.is_train = is_train
55 | self.min_size_test = min_size_test
56 | self.max_size_test = max_size_test
57 | self.pixel_mean = torch.tensor(mean)[:,None,None]
58 | self.pixel_std = torch.tensor(std)[:,None,None]
59 |
60 | t = []
61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
62 | self.transform = transforms.Compose(t)
63 |
64 | @classmethod
65 | def from_config(cls, cfg, is_train=True):
66 | ret = {
67 | "is_train": is_train,
68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
70 | "mean": cfg['INPUT']['PIXEL_MEAN'],
71 | "std": cfg['INPUT']['PIXEL_STD'],
72 | }
73 | return ret
74 |
75 | def read_semseg(self, file_name):
76 | if '.png' in file_name:
77 | semseg = np.asarray(Image.open(file_name))
78 | elif '.mat' in file_name:
79 | semseg = scipy.io.loadmat(file_name)['LabelMap']
80 | return semseg
81 |
82 | def __call__(self, dataset_dict):
83 | """
84 | Args:
85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
86 |
87 | Returns:
88 | dict: a format that builtin models in detectron2 accept
89 | """
90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
91 | file_name = dataset_dict['file_name']
92 | semseg_name = dataset_dict['sem_seg_file_name']
93 | image = Image.open(file_name).convert('RGB')
94 |
95 | dataset_dict['width'] = image.size[0]
96 | dataset_dict['height'] = image.size[1]
97 |
98 | if self.is_train == False:
99 | image = self.transform(image)
100 | image = torch.from_numpy(np.asarray(image).copy())
101 | image = image.permute(2,0,1)
102 |
103 | semseg = self.read_semseg(semseg_name)
104 | semseg = torch.from_numpy(semseg.astype(np.int32))
105 | dataset_dict['image'] = image
106 | dataset_dict['semseg'] = semseg
107 | return dataset_dict
--------------------------------------------------------------------------------
/datasets/dataset_mappers/scannet_pano_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates.
8 | import copy
9 |
10 | import scipy.io
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 |
15 | from torchvision import transforms
16 | from modeling.utils import configurable
17 |
18 | __all__ = ["ScanNetPanoDatasetMapper"]
19 |
20 |
21 | # This is specifically designed for the ScanNet panoptic dataset.
22 | class ScanNetPanoDatasetMapper:
23 | """
24 | A callable which takes a dataset dict in Detectron2 Dataset format,
25 | and map it into a format used by MaskFormer.
26 |
27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
28 |
29 | The callable currently does the following:
30 |
31 | 1. Read the image from "file_name"
32 | 2. Applies geometric transforms to the image and annotation
33 | 3. Find and applies suitable cropping to the image and annotation
34 | 4. Prepare image and annotation to Tensors
35 | """
36 |
37 | @configurable
38 | def __init__(
39 | self,
40 | is_train=True,
41 | min_size_test=None,
42 | max_size_test=None,
43 | mean=None,
44 | std=None,
45 | ):
46 | """
47 | NOTE: this interface is experimental.
48 | Args:
49 | is_train: for training or inference
50 | augmentations: a list of augmentations or deterministic transforms to apply
51 | tfm_gens: data augmentation
52 | image_format: an image format supported by :func:`detection_utils.read_image`.
53 | """
54 | self.is_train = is_train
55 | self.min_size_test = min_size_test
56 | self.max_size_test = max_size_test
57 | self.pixel_mean = torch.tensor(mean)[:,None,None]
58 | self.pixel_std = torch.tensor(std)[:,None,None]
59 |
60 | t = []
61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
62 | self.transform = transforms.Compose(t)
63 |
64 | @classmethod
65 | def from_config(cls, cfg, is_train=True):
66 | ret = {
67 | "is_train": is_train,
68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
70 | "mean": cfg['INPUT']['PIXEL_MEAN'],
71 | "std": cfg['INPUT']['PIXEL_STD'],
72 | }
73 | return ret
74 |
75 | def __call__(self, dataset_dict):
76 | """
77 | Args:
78 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
79 |
80 | Returns:
81 | dict: a format that builtin models in detectron2 accept
82 | """
83 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
84 | file_name = dataset_dict['file_name']
85 | image = Image.open(file_name).convert('RGB')
86 |
87 | dataset_dict['file_name'] = '_'.join(file_name.split('/')[-3:]) # HACK for /tmp file storage on predictions.
88 | dataset_dict['width'] = image.size[0]
89 | dataset_dict['height'] = image.size[1]
90 |
91 | image = self.transform(image)
92 | image = torch.from_numpy(np.asarray(image).copy())
93 | image = image.permute(2,0,1)
94 | dataset_dict['image'] = image
95 | return dataset_dict
--------------------------------------------------------------------------------
/datasets/dataset_mappers/sunrgbd_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates.
8 | import copy
9 |
10 | import scipy.io
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 |
15 | from torchvision import transforms
16 | from modeling.utils import configurable
17 |
18 | __all__ = ["SunRGBDSegDatasetMapper"]
19 |
20 |
21 | # This is specifically designed for the SUN RGB-D semantic segmentation dataset.
22 | class SunRGBDSegDatasetMapper:
23 | """
24 | A callable which takes a dataset dict in Detectron2 Dataset format,
25 | and map it into a format used by MaskFormer.
26 |
27 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
28 |
29 | The callable currently does the following:
30 |
31 | 1. Read the image from "file_name"
32 | 2. Applies geometric transforms to the image and annotation
33 | 3. Find and applies suitable cropping to the image and annotation
34 | 4. Prepare image and annotation to Tensors
35 | """
36 |
37 | @configurable
38 | def __init__(
39 | self,
40 | is_train=True,
41 | min_size_test=None,
42 | max_size_test=None,
43 | mean=None,
44 | std=None,
45 | ):
46 | """
47 | NOTE: this interface is experimental.
48 | Args:
49 | is_train: for training or inference
50 | augmentations: a list of augmentations or deterministic transforms to apply
51 | tfm_gens: data augmentation
52 | image_format: an image format supported by :func:`detection_utils.read_image`.
53 | """
54 | self.is_train = is_train
55 | self.min_size_test = min_size_test
56 | self.max_size_test = max_size_test
57 | self.pixel_mean = torch.tensor(mean)[:,None,None]
58 | self.pixel_std = torch.tensor(std)[:,None,None]
59 |
60 | t = []
61 | t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
62 | self.transform = transforms.Compose(t)
63 |
64 | @classmethod
65 | def from_config(cls, cfg, is_train=True):
66 | ret = {
67 | "is_train": is_train,
68 | "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
69 | "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
70 | "mean": cfg['INPUT']['PIXEL_MEAN'],
71 | "std": cfg['INPUT']['PIXEL_STD'],
72 | }
73 | return ret
74 |
75 | def read_semseg(self, file_name):
76 | if '.png' in file_name:
77 | semseg = np.asarray(Image.open(file_name))
78 | elif '.mat' in file_name:
79 | semseg = scipy.io.loadmat(file_name)['LabelMap']
80 | return semseg
81 |
82 | def __call__(self, dataset_dict):
83 | """
84 | Args:
85 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
86 |
87 | Returns:
88 | dict: a format that builtin models in detectron2 accept
89 | """
90 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
91 | file_name = dataset_dict['file_name']
92 | semseg_name = dataset_dict['sem_seg_file_name']
93 | image = Image.open(file_name).convert('RGB')
94 |
95 | dataset_dict['width'] = image.size[0]
96 | dataset_dict['height'] = image.size[1]
97 |
98 | if self.is_train == False:
99 | image = self.transform(image)
100 | image = torch.from_numpy(np.asarray(image).copy())
101 | image = image.permute(2,0,1)
102 |
103 | semseg = self.read_semseg(semseg_name)
104 | semseg = torch.from_numpy(semseg.astype(np.int32))
105 | dataset_dict['image'] = image
106 | dataset_dict['semseg'] = semseg
107 | return dataset_dict
--------------------------------------------------------------------------------
/datasets/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .instance_evaluation import *
2 | from .classification_evaluation import *
3 | from .segmentation_evaluation import *
4 | from .retrieval_evaluation import *
5 | from .captioning_evaluation import *
6 | from .panoptic_evaluation import *
7 | from .grounding_evaluation import *
8 | from .interactive_evaluation import *
--------------------------------------------------------------------------------
/datasets/evaluation/classification_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # --------------------------------------------------------
3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
4 | # Copyright (c) 2022 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
7 | # --------------------------------------------------------
8 |
9 | import torch
10 | import logging
11 |
12 | from detectron2.evaluation.evaluator import DatasetEvaluator
13 |
14 | from utils.misc import AverageMeter
15 | from utils.distributed import get_world_size
16 |
17 |
18 | @torch.no_grad()
19 | def accuracy(output, target, topk=(1,)):
20 | """Computes the precision@k for the specified values of k"""
21 | if isinstance(output, list):
22 | output = output[-1]
23 |
24 | n_classes = output.size()[1]
25 | maxk = min(max(topk), n_classes)
26 | batch_size = target.size(0)
27 | _, pred = output.topk(maxk, 1, True, True)
28 | pred = pred.t()
29 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
30 |
31 | res = []
32 | for k in topk:
33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 | res.append(correct_k.mul_(100.0 / batch_size).item())
35 | return res
36 |
37 | class ClassificationEvaluator(DatasetEvaluator):
38 | def __init__(self, *args):
39 | self.top1 = AverageMeter()
40 | self.top5 = AverageMeter()
41 | self._logger = logging.getLogger(__name__)
42 |
43 | def reset(self):
44 | self.top1.reset()
45 | self.top5.reset()
46 |
47 | def process(self, inputs, outputs):
48 | logits = torch.stack([o['pred_class'] for o in outputs])
49 | y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
50 | prec1, prec5 = accuracy(logits, y, (1, 5))
51 | self.top1.update(prec1, y.size(0))
52 | self.top5.update(prec5, y.size(0))
53 |
54 | def evaluate(self):
55 | if get_world_size() > 1:
56 | tmp_tensor = torch.tensor(
57 | [self.top1.sum, self.top5.sum, self.top1.count],
58 | device=torch.cuda.current_device()
59 | )
60 | torch.distributed.all_reduce(
61 | tmp_tensor, torch.distributed.ReduceOp.SUM
62 | )
63 | top1_sum, top5_sum, count = tmp_tensor.tolist()
64 | else:
65 | top1_sum = self.top1.sum
66 | top5_sum = self.top5.sum
67 | count = self.top1.count
68 |
69 | results = {}
70 | scores = {
71 | 'top1': top1_sum / count,
72 | "top5": top5_sum / count
73 | }
74 | results['class'] = scores
75 | self._logger.info(results)
76 | return results
77 |
--------------------------------------------------------------------------------
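
A tiny worked example of what the `accuracy` helper above returns (illustrative only; importing the module requires the repository environment since it pulls in detectron2 utilities):

```python
# Worked example for the top-k accuracy() helper above (illustrative only).
import torch
from datasets.evaluation.classification_evaluation import accuracy

# 3 samples, 4 classes: samples 0 and 1 are top-1 correct, sample 2 only top-2 correct.
logits = torch.tensor([[4.0, 1.0, 0.0, 0.0],
                       [0.0, 3.0, 1.0, 0.0],
                       [2.0, 0.0, 5.0, 1.0]])
labels = torch.tensor([0, 1, 0])

top1, top2 = accuracy(logits, labels, topk=(1, 2))
print(top1, top2)  # ~66.67 100.0
```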
/datasets/evaluation/grounding_evaluation.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | import logging
8 | import torch
9 | from torchvision.ops import box_iou
10 |
11 | from detectron2.structures import BoxMode
12 | from detectron2.data import MetadataCatalog
13 | from detectron2.utils.comm import all_gather, is_main_process, synchronize
14 | from detectron2.evaluation.evaluator import DatasetEvaluator
15 |
16 |
17 | class GroundingEvaluator(DatasetEvaluator):
18 | """
19 | Evaluate grounding segmentation metrics.
20 | """
21 |
22 | def __init__(
23 | self,
24 | dataset_name,
25 | compute_box=False,
26 | distributed=True,
27 | ):
28 | self._logger = logging.getLogger(__name__)
29 | self._dataset_name = dataset_name
30 | self._distributed = distributed
31 | self._cpu_device = torch.device("cpu")
32 | self._compute_box = compute_box
33 | meta = MetadataCatalog.get(dataset_name)
34 |
35 | def reset(self):
36 | self.cum_I = 0
37 | self.cum_U = 0
38 | self.mIoU = 0
39 | self.eval_seg_iou_list = [.5, .6, .7, .8, .9]
40 | self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
41 | self.seg_total = 0
42 | if self._compute_box:
43 | self.mIoU_box = 0
44 | self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
45 |
46 | @staticmethod
47 | def computeIoU(pred_seg, gd_seg):
48 | I = (pred_seg & gd_seg)
49 | U = (pred_seg | gd_seg)
50 | return I, U
51 |
52 | def process(self, inputs, outputs):
53 | for input, output in zip(inputs, outputs):
54 | pred = output['grounding_mask'].sigmoid() > 0.5
55 | gt = input['groundings']['masks'].bool()
56 | bsi = len(pred)
57 | I, U = self.computeIoU(pred, gt)
58 | self.cum_I += I.sum().cpu()
59 | self.cum_U += U.sum().cpu()
60 | IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6)
61 | self.mIoU += IoU.sum().cpu()
62 |
63 | if self._compute_box:
64 | pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
65 | gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu()
66 | IoU_box = box_iou(pred_box, gt_box).diagonal()
67 | self.mIoU_box += IoU_box.sum()
68 |
69 | for idx in range(len(self.eval_seg_iou_list)):
70 | eval_seg_iou = self.eval_seg_iou_list[idx]
71 | self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu()
72 | if self._compute_box:
73 | self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu()
74 | self.seg_total += bsi
75 |
76 | def evaluate(self):
77 | if self._distributed:
78 | synchronize()
79 | self.cum_I = torch.stack(all_gather(self.cum_I)).sum()
80 | self.cum_U = torch.stack(all_gather(self.cum_U)).sum()
81 | self.mIoU = torch.stack(all_gather(self.mIoU)).sum()
82 | self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0)
83 | self.seg_total = sum(all_gather(self.seg_total))
84 |
85 | if self._compute_box:
86 | self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum()
87 | self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0)
88 | if not is_main_process():
89 | return
90 |
91 | results = {}
92 | for idx in range(len(self.eval_seg_iou_list)):
93 | result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx])
94 | results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item()
95 | results['cIoU'] = (self.cum_I*100./self.cum_U).item()
96 | results['mIoU'] = (self.mIoU*100./self.seg_total).item()
97 |
98 | if self._compute_box:
99 | for idx in range(len(self.eval_seg_iou_list)):
100 | result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx])
101 | results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item()
102 | results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item()
103 |
104 | self._logger.info(results)
105 | return {'grounding': results}
--------------------------------------------------------------------------------
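
The evaluator above reports two different IoU aggregations: cIoU pools intersections and unions over the whole dataset, while mIoU averages per-sample IoU. A small self-contained illustration of the difference, mirroring `computeIoU`/`process` on two toy masks:

```python
# Tiny illustration of the two IoU aggregations used by GroundingEvaluator
# (cIoU pools intersections/unions over the dataset, mIoU averages per-sample IoU).
import torch

pred = torch.tensor([[[1, 1], [0, 0]],                          # sample 1: predicted mask
                     [[1, 0], [0, 0]]], dtype=torch.bool)        # sample 2
gt   = torch.tensor([[[1, 0], [0, 0]],                          # sample 1: ground-truth mask
                     [[1, 1], [1, 1]]], dtype=torch.bool)        # sample 2

I = (pred & gt).reshape(2, -1).sum(-1).float()   # per-sample intersection: [1, 1]
U = (pred | gt).reshape(2, -1).sum(-1).float()   # per-sample union:        [2, 4]

cIoU = 100.0 * I.sum() / U.sum()   # 100 * 2/6  = 33.3
mIoU = 100.0 * (I / U).mean()      # 100 * (0.5 + 0.25)/2 = 37.5
print(cIoU.item(), mIoU.item())
```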
/datasets/evaluation/instance_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import contextlib
3 | import copy
4 | import io
5 | import itertools
6 | import json
7 | import logging
8 | import numpy as np
9 | import os
10 | import pickle
11 | from collections import OrderedDict
12 | import pycocotools.mask as mask_util
13 | import torch
14 | from pycocotools.coco import COCO
15 | from pycocotools.cocoeval import COCOeval
16 | from tabulate import tabulate
17 |
18 | import detectron2.utils.comm as comm
19 | from detectron2.config import CfgNode
20 | from detectron2.data import MetadataCatalog
21 | from detectron2.data.datasets.coco import convert_to_coco_json
22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
23 | from detectron2.evaluation.fast_eval_api import COCOeval_opt
24 | from detectron2.structures import Boxes, BoxMode, pairwise_iou
25 | from detectron2.utils.file_io import PathManager
26 | from detectron2.utils.logger import create_small_table
27 |
28 |
29 | # modified from COCOEvaluator for instance segmentation
30 | class InstanceSegEvaluator(COCOEvaluator):
31 | """
32 | Evaluate AR for object proposals, AP for instance detection/segmentation, AP
33 | for keypoint detection outputs using COCO's metrics.
34 | See http://cocodataset.org/#detection-eval and
35 | http://cocodataset.org/#keypoints-eval to understand its metrics.
36 | The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
37 | the metric cannot be computed (e.g. due to no predictions made).
38 |
39 | In addition to COCO, this evaluator is able to support any bounding box detection,
40 | instance segmentation, or keypoint detection dataset.
41 | """
42 |
43 | def _eval_predictions(self, predictions, img_ids=None):
44 | """
45 | Evaluate predictions. Fill self._results with the metrics of the tasks.
46 | """
47 | self._logger.info("Preparing results for COCO format ...")
48 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
49 | tasks = self._tasks or self._tasks_from_predictions(coco_results)
50 |
51 | # unmap the category ids for COCO
52 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
53 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
54 | # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
55 | # num_classes = len(all_contiguous_ids)
56 | # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
57 |
58 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
59 | for result in coco_results:
60 | category_id = result["category_id"]
61 | # assert category_id < num_classes, (
62 | # f"A prediction has class={category_id}, "
63 | # f"but the dataset only has {num_classes} classes and "
64 | # f"predicted class id should be in [0, {num_classes - 1}]."
65 | # )
66 | assert category_id in reverse_id_mapping, (
67 | f"A prediction has class={category_id}, "
68 | f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
69 | )
70 | result["category_id"] = reverse_id_mapping[category_id]
71 |
72 | if self._output_dir:
73 | file_path = os.path.join(self._output_dir, "coco_instances_results.json")
74 | self._logger.info("Saving results to {}".format(file_path))
75 | with PathManager.open(file_path, "w") as f:
76 | f.write(json.dumps(coco_results))
77 | f.flush()
78 |
79 | if not self._do_evaluation:
80 | self._logger.info("Annotations are not available for evaluation.")
81 | return
82 |
83 | self._logger.info(
84 | "Evaluating predictions with {} COCO API...".format(
85 | "unofficial" if self._use_fast_impl else "official"
86 | )
87 | )
88 | for task in sorted(tasks):
89 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
90 | coco_eval = (
91 | _evaluate_predictions_on_coco(
92 | self._coco_api,
93 | coco_results,
94 | task,
95 | kpt_oks_sigmas=self._kpt_oks_sigmas,
96 | use_fast_impl=self._use_fast_impl,
97 | img_ids=img_ids,
98 | max_dets_per_image=self._max_dets_per_image,
99 | )
100 | if len(coco_results) > 0
101 | else None # cocoapi does not handle empty results very well
102 | )
103 |
104 | res = self._derive_coco_results(
105 | coco_eval, task, class_names=self._metadata.get("thing_classes")
106 | )
107 | self._results[task] = res
108 |
--------------------------------------------------------------------------------
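
The key step in `_eval_predictions` above is unmapping contiguous category ids back to the original dataset ids before handing results to the COCO API. A tiny self-contained illustration (the mapping values are made up; the real one comes from the dataset metadata):

```python
# Tiny illustration of the category-id unmapping in _eval_predictions (illustrative only).
# Models predict contiguous ids [0..K-1]; COCO-format results need the original dataset ids.
dataset_id_to_contiguous_id = {1: 0, 3: 1, 18: 2}   # hypothetical metadata mapping
reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}

coco_results = [{"category_id": 2, "score": 0.9}, {"category_id": 0, "score": 0.7}]
for result in coco_results:
    result["category_id"] = reverse_id_mapping[result["category_id"]]

print(coco_results)  # [{'category_id': 18, ...}, {'category_id': 1, ...}]
```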
/datasets/evaluation/interactive_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | from torchvision.ops import box_iou
8 |
9 | from detectron2.structures import BoxMode
10 | from detectron2.data import MetadataCatalog
11 | from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
12 | from detectron2.evaluation.evaluator import DatasetEvaluator
13 |
14 |
15 | class InteractiveEvaluator(DatasetEvaluator):
16 | """
17 | Evaluate point interactive IoU metrics.
18 | """
19 |
20 | def __init__(
21 | self,
22 | dataset_name,
23 | output_dir,
24 | max_clicks=20,
25 | iou_iter=1,
26 | compute_box=False,
27 | distributed=True,
28 | ):
29 | self._logger = logging.getLogger(__name__)
30 | self._dataset_name = dataset_name
31 | self._distributed = distributed
32 | self._cpu_device = torch.device("cpu")
33 | self._output_dir = output_dir
34 |
35 | self.max_clicks = max_clicks
36 | self.iou_iter = iou_iter
37 | meta = MetadataCatalog.get(dataset_name)
38 |
39 | def reset(self):
40 | self.iou_list = []
41 | self.num_samples = 0
42 | self.all_ious = [0.5, 0.8, 0.85, 0.9]
43 |
44 | def process(self, inputs, outputs):
45 | self.iou_list += [o['mask_iou'] for o in outputs]
46 | self.num_samples += len(outputs)
47 |
48 | def compute_noc(self):
49 | def _get_noc(iou_arr, iou_thr):
50 | vals = iou_arr >= iou_thr
51 | return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks
52 |
53 | noc_list = {}
54 | for iou_thr in self.all_ious:
55 | scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list]
56 | noc_list[str(iou_thr)] = scores_arr
57 |
58 | iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1]
59 | noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()}
 60 |         num_samples, iou_before_max_gather = self.num_samples, [iou_before_max_iter.sum().cpu()]  # defaults for the non-distributed case
61 | if self._distributed:
62 | num_samples = sum(all_gather(self.num_samples))
63 | noc_list_sum_gather = all_gather(noc_list_sum)
64 | iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu())
65 |
66 | noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]}
67 | for nlg in noc_list_sum_gather:
68 | for key, value in nlg.items():
69 | noc_list_sum[key] += value
70 |
71 | pred_noc = {}
72 | if self._distributed and (not is_main_process()):
73 | return pred_noc
74 |
75 | for key, value in noc_list_sum.items():
76 | pred_noc[key] = value / num_samples
77 |
78 | pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples
79 | return pred_noc
80 |
81 | def evaluate(self):
82 | pred_noc = self.compute_noc()
83 |
84 | if self._distributed and (not is_main_process()):
85 | return
86 |
87 | def draw_iou_curve(iou_list, save_dir):
88 | iou_list = torch.stack(iou_list, dim=0)
89 | iou_list = iou_list.mean(dim=0).cpu().numpy()
90 | # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib
91 | import matplotlib.pyplot as plt
92 | plt.figure()
93 | plt.plot(range(1, self.max_clicks+1), iou_list)
94 | plt.xlabel('Number of clicks')
95 | plt.ylabel('IoU')
96 |
97 |
98 | # create directory if not exist
99 | import os
100 | output_dir = os.path.join(save_dir, 'iou_by_clicks')
101 | if not os.path.exists(output_dir):
102 | os.makedirs(output_dir)
103 |
104 | # get current time and format in 10 digits
105 | import time
106 | current_time = time.time()
107 | current_time = int(current_time)
108 | current_time = str(current_time)
109 |
110 | # save iou curve
111 | plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time)))
112 |
113 | draw_iou_curve(self.iou_list, self._output_dir)
114 | results = {}
115 | for idx in range(len(self.all_ious)):
116 | result_str = 'noc@{}'.format(self.all_ious[idx])
117 | results[result_str] = pred_noc[str(self.all_ious[idx])]
118 |
119 | results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']
120 |
121 | self._logger.info(results)
122 | return {'interactive': results}
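
Per sample, the NoC ("number of clicks") metric above reduces to: the first click index at which the IoU curve crosses a threshold, or `max_clicks` if it never does. A self-contained sketch of that reduction on a toy IoU-vs-clicks curve (values are made up):

    import torch

    max_clicks = 20
    # IoU after click 1, 2, ..., 20 for one toy sample
    iou_curve = torch.tensor([0.30, 0.55, 0.70, 0.82, 0.88] + [0.90] * 15)

    def noc(iou_curve, thr, max_clicks=20):
        hits = (iou_curve >= thr).nonzero(as_tuple=False)
        return hits[0].item() + 1 if len(hits) else max_clicks

    print(noc(iou_curve, 0.85))   # -> 5 (IoU first reaches 0.85 on the 5th click)
    print(noc(iou_curve, 0.95))   # -> 20 (threshold never reached)
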
--------------------------------------------------------------------------------
/datasets/registration/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | register_refcoco_dataset,
3 | register_ade20k_full,
4 | register_ade20k_panoptic,
5 | register_coco_stuff_10k,
6 | register_coco_panoptic_annos_semseg,
7 | register_coco_panoptic_annos_caption,
8 | register_coco_panoptic_annos_caption_grounding,
9 | register_coco_lvis_panoptic_annos_caption_grounding,
10 | register_ade20k_instance,
11 | register_vlp_datasets,
12 | register_sunrgbd_semseg,
13 | register_scannet_semseg,
14 | register_bdd100k_semseg,
15 | register_scannet_panoptic,
16 | register_bdd100k_panoseg,
17 | register_pascalvoc_eval,
18 | )
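
Importing this package is what actually registers every dataset: each `register_*` module populates detectron2's DatasetCatalog/MetadataCatalog at import time, reading its root from an environment variable. A small sketch (assuming the repo root is on `sys.path`):

    import os

    # roots read by the registration modules at import time
    os.environ.setdefault("DATASET", "datasets")              # used by most modules in this folder
    os.environ.setdefault("DETECTRON2_DATASETS", "datasets")  # used by register_ade20k_instance.py

    import datasets.registration  # noqa: F401  (side effect: datasets become queryable)

    from detectron2.data import DatasetCatalog
    print("refcocog_val_umd" in DatasetCatalog.list())        # True
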
--------------------------------------------------------------------------------
/datasets/registration/register_ade20k_instance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import json
3 | import logging
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 |
8 | from detectron2.data import DatasetCatalog, MetadataCatalog
9 | from detectron2.data.datasets.coco import load_coco_json, register_coco_instances
10 | from detectron2.utils.file_io import PathManager
11 |
12 | ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}]
13 |
14 |
15 | _PREDEFINED_SPLITS = {
16 | # point annotations without masks
17 | "ade20k_instance_train": (
18 | "ADEChallengeData2016/images/training",
19 | "ADEChallengeData2016/ade20k_instance_train.json",
20 | ),
21 | "ade20k_instance_val": (
22 | "ADEChallengeData2016/images/validation",
23 | "ADEChallengeData2016/ade20k_instance_val.json",
24 | ),
25 | }
26 |
27 |
28 | def _get_ade_instances_meta():
29 | thing_ids = [k["id"] for k in ADE_CATEGORIES]
30 | assert len(thing_ids) == 100, len(thing_ids)
 31 |     # Mapping from the non-contiguous ADE category id to an id in [0, 99]
32 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
33 | thing_classes = [k["name"] for k in ADE_CATEGORIES]
34 | ret = {
35 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
36 | "thing_classes": thing_classes,
37 | }
38 | return ret
39 |
40 |
41 | def register_all_ade20k_instance(root):
42 | for key, (image_root, json_file) in _PREDEFINED_SPLITS.items():
43 | # Assume pre-defined datasets live in `./datasets`.
44 | register_coco_instances(
45 | key,
46 | _get_ade_instances_meta(),
47 | os.path.join(root, json_file) if "://" not in json_file else json_file,
48 | os.path.join(root, image_root),
49 | )
50 |
51 |
52 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
53 | register_all_ade20k_instance(_root)
54 |
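Once registered, the split can be pulled back through the detectron2 catalogs; a small sketch (loading the records assumes the ADE20K instance JSON actually exists under the dataset root):

    from detectron2.data import DatasetCatalog, MetadataCatalog

    meta = MetadataCatalog.get("ade20k_instance_val")
    print(len(meta.thing_classes))                      # 100 categories
    print(meta.thing_dataset_id_to_contiguous_id[7])    # ADE id 7 ('bed') -> contiguous id 0

    records = DatasetCatalog.get("ade20k_instance_val")  # parses ade20k_instance_val.json
    print(records[0]["file_name"])
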
--------------------------------------------------------------------------------
/datasets/registration/register_bdd100k_semseg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 | import os
10 | import glob
11 | from typing import List, Tuple, Union
12 |
13 | from detectron2.data import DatasetCatalog, MetadataCatalog
14 | from detectron2.utils.file_io import PathManager
15 |
16 | from utils.constants import BDD_SEM
17 |
 18 | __all__ = ["load_bdd_instances", "register_bdd_context"]
19 |
20 |
21 | def load_bdd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
22 | """
23 | Load BDD annotations to Detectron2 format.
24 |
25 | Args:
26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
27 | split (str): one of "train", "test", "val", "trainval"
28 | class_names: list or tuple of class names
29 | """
30 | img_folder = os.path.join(dirname, 'images', '10k', split)
31 | img_pths = sorted(glob.glob(os.path.join(img_folder, '*.jpg')))
32 |
33 | sem_folder = os.path.join(dirname, 'labels', 'sem_seg', 'masks', split)
34 | sem_pths = sorted(glob.glob(os.path.join(sem_folder, '*.png')))
35 |
36 | assert len(img_pths) == len(sem_pths)
37 |
38 | dicts = []
39 | for img_pth, sem_pth in zip(img_pths, sem_pths):
40 | r = {
41 | "file_name": img_pth,
42 | "sem_seg_file_name": sem_pth,
43 | "image_id": img_pth.split('/')[-1].split('.')[0],
44 | }
45 | dicts.append(r)
46 | return dicts
47 |
48 |
49 | def register_bdd_context(name, dirname, split, class_names=BDD_SEM):
50 | DatasetCatalog.register(name, lambda: load_bdd_instances(name, dirname, split, class_names))
51 | MetadataCatalog.get(name).set(
52 | stuff_classes=class_names,
53 | dirname=dirname,
54 | split=split,
55 | ignore_label=[255],
56 | thing_dataset_id_to_contiguous_id={},
57 | class_offset=0,
58 | keep_sem_bgd=False
59 | )
60 |
61 |
 62 | def register_all_bdd_seg(root):
63 | SPLITS = [
64 | ("bdd10k_val_sem_seg", "bdd100k", "val"),
65 | ]
66 |
67 | for name, dirname, split in SPLITS:
68 | register_bdd_context(name, os.path.join(root, dirname), split)
69 | MetadataCatalog.get(name).evaluator_type = "sem_seg"
70 |
71 |
72 | _root = os.getenv("DATASET", "datasets")
 73 | register_all_bdd_seg(_root)
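
The loader only globs two parallel folders, so the on-disk layout is fixed by the patterns above; a sketch of calling it directly (paths are illustrative and assume BDD100K has been laid out accordingly):

    # expected layout, as implied by the glob patterns above:
    #   <root>/bdd100k/images/10k/val/<id>.jpg
    #   <root>/bdd100k/labels/sem_seg/masks/val/<id>.png
    records = load_bdd_instances("bdd10k_val_sem_seg", "datasets/bdd100k", "val", BDD_SEM)
    print(records[0])   # {'file_name': ..., 'sem_seg_file_name': ..., 'image_id': ...}
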
--------------------------------------------------------------------------------
/datasets/registration/register_imagenet_cls.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import glob
10 | from typing import List, Tuple, Union
11 |
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 | from detectron2.utils.file_io import PathManager
15 |
16 | from utils.constants import IMAGENET_CLASSES, IMAGENET_FOLDER_NAMES
17 |
18 | __all__ = ["load_imagenet_images", "register_imagenet"]
19 |
20 |
21 | def load_imagenet_images(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
22 | """
23 | Load ImageNet annotations to Detectron2 format.
24 |
25 | Args:
26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
27 | split (str): one of "train", "test", "val", "trainval"
28 | class_names: list or tuple of class names
29 | """
30 | image_folders = sorted(glob.glob(os.path.join(dirname, split, 'n*')))
31 |
32 | dicts = []
33 | for image_folder in image_folders:
34 | folder_name = image_folder.split('/')[-1]
35 | image_pths = sorted(glob.glob(os.path.join(image_folder, "*.JPEG")))
36 | for img_pth in image_pths:
37 | r = {
38 | "file_name": img_pth,
39 | "class_name": IMAGENET_CLASSES[IMAGENET_FOLDER_NAMES.index(folder_name)],
40 | "class_id": IMAGENET_FOLDER_NAMES.index(folder_name),
41 | }
42 | dicts.append(r)
43 | return dicts
44 |
45 |
46 | def register_imagenet(name, dirname, split, year, class_names=IMAGENET_CLASSES):
47 | DatasetCatalog.register(name, lambda: load_imagenet_images(dirname, split, class_names))
48 | MetadataCatalog.get(name).set(
49 | thing_classes=list(class_names), dirname=dirname, year=year, split=split
50 | )
51 |
52 |
53 | def register_all_imagenet(root):
54 | SPLITS = [
55 | ("imagenet_val", "imagenet", "val", "2012"),
56 | ]
57 | for name, dirname, split, year in SPLITS:
58 | register_imagenet(name, os.path.join(root, dirname), split, year)
59 | MetadataCatalog.get(name).evaluator_type = "classification"
60 |
61 |
62 | _root = os.getenv("DATASET", "datasets")
63 | register_all_imagenet(_root)
--------------------------------------------------------------------------------
/datasets/registration/register_pascalvoc_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | import os
4 | import glob
5 | from typing import List, Tuple, Union
6 | import xml.etree.ElementTree as ET
7 |
8 | import cv2
9 | import numpy as np
10 | from scipy.io import loadmat
11 |
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 | from detectron2.utils.file_io import PathManager
15 |
16 |
17 | __all__ = ["load_pascalvoc_instances", "register_pascalvoc_context"]
18 |
19 | def get_labels_with_sizes(x):
20 | obj_sizes = np.bincount(x.flatten())
21 | labels = np.nonzero(obj_sizes)[0].tolist()
22 | labels = [x for x in labels if x != 0]
23 | return labels, obj_sizes[labels].tolist()
24 |
25 | def load_pascalvoc_instances(name: str, dirname: str, mode: str, split: str):
26 | """
27 | Load Pascal VOC detection annotations to Detectron2 format.
28 |
29 | Args:
30 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
31 | split (str): one of "train", "test", "val", "trainval"
32 | class_names: list or tuple of class names
33 | """
34 | with PathManager.open(os.path.join(dirname, 'ImageSets', 'Segmentation', split + ".txt")) as f:
 35 |         fileids = np.loadtxt(f, dtype=str)
36 |
37 | dicts = []
38 | for field in fileids:
39 | anno_path = os.path.join(dirname, "Annotations", "{}.xml".format(field))
40 | image_path = os.path.join(dirname, "JPEGImages", "{}.jpg".format(field))
41 | inst_path = os.path.join(dirname, "SegmentationObject", "{}.png".format(field))
42 | semseg_path = os.path.join(dirname, "SegmentationClass", "{}.png".format(field))
43 |
44 | instances_mask = cv2.imread(inst_path)
45 | instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
46 |
47 | objects_ids = np.unique(instances_mask)
48 | objects_ids = [x for x in objects_ids if x != 0 and x != 220]
49 |
50 | slice_size = 5
51 | for i in range(0, len(objects_ids), slice_size):
52 | r = {
53 | "file_name": image_path,
54 | "inst_name": inst_path,
55 | "semseg_name": semseg_path,
56 | "objects_ids": objects_ids[i:i+slice_size],
57 | }
58 | dicts.append(r)
59 | return dicts
60 |
61 | def register_pascalvoc_context(name, dirname, mode, split):
62 | DatasetCatalog.register("{}_{}".format(name, mode), lambda: load_pascalvoc_instances(name, dirname, mode, split))
63 | MetadataCatalog.get("{}_{}".format(name, mode)).set(
64 | dirname=dirname,
65 | thing_dataset_id_to_contiguous_id={},
66 | )
67 |
68 | def register_all_sbd(root):
69 | SPLITS = [
70 | ("pascalvoc_val", "PascalVOC", "Point", "val"),
71 | ("pascalvoc_val", "PascalVOC", "Scribble", "val"),
72 | ("pascalvoc_val", "PascalVOC", "Polygon", "val"),
73 | ("pascalvoc_val", "PascalVOC", "Circle", "val"),
74 | ("pascalvoc_val", "PascalVOC", "Box", "val"),
75 | ]
76 |
77 | for name, dirname, mode, split in SPLITS:
78 | register_pascalvoc_context(name, os.path.join(root, dirname), mode, split)
79 | MetadataCatalog.get("{}_{}".format(name, mode)).evaluator_type = "interactive"
80 |
81 | _root = os.getenv("DATASET", "datasets")
82 | register_all_sbd(_root)
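
Note that `register_all_sbd` registers the same PascalVOC validation split five times, once per interaction mode, so the names visible to the evaluator are `pascalvoc_val_Point`, `pascalvoc_val_Scribble`, `pascalvoc_val_Polygon`, `pascalvoc_val_Circle`, and `pascalvoc_val_Box`. A small sketch (materialising the records assumes PascalVOC is present under the dataset root):

    from detectron2.data import DatasetCatalog, MetadataCatalog

    print(MetadataCatalog.get("pascalvoc_val_Point").evaluator_type)   # "interactive"

    records = DatasetCatalog.get("pascalvoc_val_Point")
    print(records[0]["objects_ids"])   # at most 5 object ids per record (slice_size above)
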
--------------------------------------------------------------------------------
/datasets/registration/register_refcoco_dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | import json
8 | import os
9 | import collections
10 |
11 | from detectron2.data import DatasetCatalog, MetadataCatalog
12 | from detectron2.data.datasets import load_sem_seg
13 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
14 | from detectron2.utils.file_io import PathManager
15 |
16 |
17 | _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = {
18 | # "refcocog_train_umd": (
19 | # "coco/train2017", # image_root
20 | # "coco/annotations/refcocog_umd_train.json", # annot_root
21 | # ),
22 | # "refcocog_val_google": (
23 | # "coco/train2017", # image_root
24 | # "coco/annotations/refcocog_google.json", # annot_root
25 | # ),
26 | # "refcocop_val_unc": (
27 | # "coco/train2017", # image_root
28 | # "coco/annotations/refcocop_unc.json", # annot_root
29 | # ),
30 | # "refcoco_val_unc": (
31 | # "coco/train2017", # image_root
32 | # "coco/annotations/refcoco_unc.json", # annot_root
33 | # ),
34 | "refcocog_val_umd": (
35 | "coco/train2017", # image_root
36 | "coco/annotations/refcocog_umd_val.json", # annot_root
37 | ),
38 | }
39 |
40 |
41 | def get_metadata():
42 | meta = {}
43 | return meta
44 |
45 |
46 | def load_refcoco_json(image_root, annot_json, metadata):
47 | """
48 | Args:
49 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
50 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
51 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
52 | Returns:
53 | list[dict]: a list of dicts in Detectron2 standard format. (See
 54 |         the "Using Custom Datasets" tutorial in the detectron2 documentation.)
55 | """
56 |
57 | with PathManager.open(annot_json) as f:
58 | json_info = json.load(f)
59 |
60 | # build dictionary for grounding
61 | grd_dict = collections.defaultdict(list)
62 | for grd_ann in json_info['annotations']:
63 | image_id = int(grd_ann["image_id"])
64 | grd_dict[image_id].append(grd_ann)
65 |
66 | ret = []
67 | for image in json_info["images"]:
68 | image_id = int(image["id"])
69 | image_file = os.path.join(image_root, image['file_name'])
70 | grounding_anno = grd_dict[image_id]
71 | ret.append(
72 | {
73 | "file_name": image_file,
74 | "image_id": image_id,
75 | "grounding_info": grounding_anno,
76 | }
77 | )
78 | assert len(ret), f"No images found in {image_root}!"
79 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
80 | return ret
81 |
82 |
83 | def register_refcoco(
84 | name, metadata, image_root, annot_json):
85 | DatasetCatalog.register(
86 | name,
87 | lambda: load_refcoco_json(image_root, annot_json, metadata),
88 | )
89 | MetadataCatalog.get(name).set(
90 | image_root=image_root,
91 | json_file=annot_json,
92 | evaluator_type="grounding_refcoco",
93 | ignore_label=255,
94 | label_divisor=1000,
95 | **metadata,
96 | )
97 |
98 |
99 | def register_all_refcoco(root):
100 | for (
101 | prefix,
102 | (image_root, annot_root),
103 | ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items():
104 | register_refcoco(
105 | prefix,
106 | get_metadata(),
107 | os.path.join(root, image_root),
108 | os.path.join(root, annot_root),
109 | )
110 |
111 |
112 | _root = os.getenv("DATASET", "datasets")
113 | register_all_refcoco(_root)
114 |
--------------------------------------------------------------------------------
/datasets/registration/register_scannet_semseg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import os
9 | import glob
10 | from typing import List, Tuple, Union
11 |
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 | from detectron2.utils.file_io import PathManager
15 |
16 | from utils.constants import SCAN_37, SCAN_40, SCAN_20
17 |
18 | __all__ = ["load_scannet_instances", "register_scannet_context"]
19 |
20 | name2folder = {"scannet_41_val_seg": "label41",
21 | "scannet_38_val_seg": "label38",
22 | "scannet_21_val_seg": "label21",}
23 |
24 | name2class = {"scannet_41_val_seg": SCAN_40,
25 | "scannet_38_val_seg": SCAN_37,
26 | "scannet_21_val_seg": SCAN_20}
27 |
28 |
29 | def load_scannet_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
30 | """
31 | Load ScanNet annotations to Detectron2 format.
32 |
33 | Args:
34 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
35 | split (str): one of "train", "test", "val", "trainval"
36 | class_names: list or tuple of class names
37 | """
38 | with PathManager.open(os.path.join(dirname, "meta", split + ".txt")) as f:
 39 |         fileids = np.loadtxt(f, dtype=str)
40 |
41 | dicts = []
42 | for field in fileids:
43 | image_dir = os.path.join(dirname, 'images', field[0])
44 | semseg_dir = image_dir.replace('color', name2folder[name]).replace('jpg', 'png')
45 | r = {
46 | "file_name": image_dir,
47 | "sem_seg_file_name": semseg_dir,
48 | "image_id": semseg_dir.split('/')[-3] + semseg_dir.split('/')[-1].split('.')[0],
49 | }
50 | dicts.append(r)
51 | return dicts
52 |
53 |
54 | def register_scannet_context(name, dirname, split, class_names=name2class):
55 | DatasetCatalog.register(name, lambda: load_scannet_instances(name, dirname, split, class_names))
56 | MetadataCatalog.get(name).set(
57 | stuff_classes=class_names[name],
58 | dirname=dirname,
59 | split=split,
60 | ignore_label=[0],
61 | thing_dataset_id_to_contiguous_id={},
62 | class_offset=1,
63 | keep_sem_bgd=False
64 | )
65 |
66 |
 67 | def register_all_scannet_seg(root):
68 | SPLITS = [
69 | ("scannet_41_val_seg", "scannet_frames_25k", "val"),
70 | ("scannet_38_val_seg", "scannet_frames_25k", "val"),
71 | ("scannet_21_val_seg", "scannet_frames_25k", "val"),
72 | ]
73 |
74 | for name, dirname, split in SPLITS:
75 | register_scannet_context(name, os.path.join(root, dirname), split)
76 | MetadataCatalog.get(name).evaluator_type = "sem_seg"
77 |
78 |
79 | _root = os.getenv("DATASET", "datasets")
 80 | register_all_scannet_seg(_root)
--------------------------------------------------------------------------------
/datasets/registration/register_sunrgbd_semseg.py:
--------------------------------------------------------------------------------
1 |
2 | # --------------------------------------------------------
3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
4 | # Copyright (c) 2022 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
7 | # --------------------------------------------------------
8 | import numpy as np
9 | import os
10 | import glob
11 | from typing import List, Tuple, Union
12 |
13 | from detectron2.data import DatasetCatalog, MetadataCatalog
14 | from detectron2.structures import BoxMode
15 | from detectron2.utils.file_io import PathManager
16 |
17 | from utils.constants import SUN_RGBD_37
18 |
19 | __all__ = ["load_sunrgbd_instances", "register_sunrgbd_context"]
20 |
21 | def load_sunrgbd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
22 | """
23 | Load SUN-RGBD detection annotations to Detectron2 format.
24 |
25 | Args:
26 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
27 | split (str): one of "train", "test", "val", "trainval"
28 | class_names: list or tuple of class names
29 | """
30 | if split == 'val':
31 | split = 'test'
32 |
 33 |     # Needs to read many small annotation files; best done from local storage.
34 | image_pths = sorted(glob.glob(os.path.join(dirname, 'image', split, '*.jpg')))
35 | semseg_pths = sorted(glob.glob(os.path.join(dirname, 'label37', split, '*.png')))
36 |
37 | assert len(image_pths) == len(semseg_pths)
38 |
39 | dicts = []
40 | for image_dir, semseg_dir in zip(image_pths, semseg_pths):
41 | r = {
42 | "file_name": image_dir,
43 | "sem_seg_file_name": semseg_dir,
44 | "image_id": semseg_dir.split('/')[-1].split('.')[0],
45 | }
46 | dicts.append(r)
47 | return dicts
48 |
49 |
50 | def register_sun_context(name, dirname, split, class_names=SUN_RGBD_37):
51 | DatasetCatalog.register(name, lambda: load_sunrgbd_instances(name, dirname, split, class_names))
52 | MetadataCatalog.get(name).set(
53 | stuff_classes=class_names,
54 | dirname=dirname,
55 | split=split,
56 | ignore_label=[0],
57 | thing_dataset_id_to_contiguous_id={},
58 | class_offset=1,
59 | keep_sem_bgd=False
60 | )
61 |
62 |
63 | def register_all_sunrgbd_seg(root):
64 | SPLITS = [
65 | ("sunrgbd_37_val_seg", "sun_rgbd", "val"),
66 | ]
67 |
68 | for name, dirname, split in SPLITS:
69 | register_sun_context(name, os.path.join(root, dirname), split)
70 | MetadataCatalog.get(name).evaluator_type = "sem_seg"
71 |
72 |
73 | _root = os.getenv("DATASET", "datasets")
74 | register_all_sunrgbd_seg(_root)
--------------------------------------------------------------------------------
/datasets/semseg_loader.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import scipy.io
3 | import numpy as np
4 |
5 | def load_semseg(filename, loader_type):
6 | if loader_type == 'PIL':
  7 |         semseg = np.array(Image.open(filename), dtype=int)
8 | elif loader_type == 'MAT':
9 | semseg = scipy.io.loadmat(filename)['LabelMap']
10 | return semseg
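
A minimal usage sketch; the file paths are placeholders:

    semseg = load_semseg("path/to/label.png", "PIL")   # e.g. SUN-RGBD / ScanNet / BDD label maps
    semseg = load_semseg("path/to/label.mat", "MAT")   # .mat files with a 'LabelMap' field
    print(semseg.shape, semseg.dtype)
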
--------------------------------------------------------------------------------
/datasets/utils/refcoco2json.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from refer import REFER
4 |
5 | coco_root = '/pth/to/coco'
6 | ref_root = '/pth/to/refcocoseg'
7 |
8 | coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json')))
9 | coco_train_id = []
10 | image_annot = {}
11 | for i in range(len(coco_train_annot['images'])):
12 | coco_train_id.append(coco_train_annot['images'][i]['id'])
13 | image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i]
14 |
15 | refg = REFER(data_root=ref_root,
16 | dataset='refcocog', splitBy='umd')
17 | refg_val_ids = refg.getRefIds(split='val')
18 |
19 | full_anno = []
20 | for ref_id in refg_val_ids:
21 | ref = refg.loadRefs(ref_id)[0]
22 | anno = refg.refToAnn[ref_id]
23 | anno.update(ref)
24 | full_anno.append(anno)
25 |
26 | imageid_list = []
27 | final_anno = {}
28 | for anno in full_anno:
29 | imageid_list += [anno['image_id']]
30 | final_anno[anno['ann_id']] = anno
31 |
32 | annotations = [value for key, value in final_anno.items()]
33 |
 34 | images = []
 35 | for image_id in list(set(imageid_list)):
 36 |     images += [image_annot[image_id]]
 37 | 
 38 | outputs = {'images': images, 'annotations': annotations}
 39 | print(len(images))
40 | print(len(annotations))
 41 | json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_val.json'), 'w'))
42 |
--------------------------------------------------------------------------------
/datasets/visual_sampler/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import ShapeSampler
2 | from .simpleclick_sampler import SimpleClickSampler
3 |
4 |
5 | def build_shape_sampler(cfg, **kwargs):
6 | sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE']
7 | if sampler_name == 'random':
8 | return ShapeSampler(cfg, **kwargs)
9 | elif sampler_name in ['best', 'best_random']:
10 | return SimpleClickSampler(cfg, **kwargs)
11 | else:
12 | assert False, "not implemented"
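
`build_shape_sampler` only looks at the `STROKE_SAMPLER` sub-config. A minimal plain-dict config covering the keys read by `ShapeSampler` and the Point/Scribble/Circle samplers in this folder (values here are illustrative, not the training defaults, and the sketch assumes the repo's `configurable` wrapper routes a plain dict through `from_config`, as the call in sampler.py suggests):

    cfg = {
        'STROKE_SAMPLER': {
            'MAX_CANDIDATE': 1,
            'CANDIDATE_NAMES': ['Point', 'Scribble', 'Circle'],
            'CANDIDATE_PROBS': [0.4, 0.3, 0.3],
            'POINT':    {'NUM_POINTS': 20},
            'SCRIBBLE': {'NUM_STROKES': 5, 'STROKE_PRESET': ['rand_curve'], 'STROKE_PROB': [1.0]},
            'CIRCLE':   {'NUM_STROKES': 5, 'STROKE_PRESET': ['object_like'], 'STROKE_PROB': [1.0]},
            'EVAL':     {'MODE': 'random', 'MAX_ITER': 20},
        }
    }
    sampler = build_shape_sampler(cfg, is_train=True)   # -> ShapeSampler with the three candidates
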
--------------------------------------------------------------------------------
/datasets/visual_sampler/circle.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | from .mask_generators import get_mask_by_input_strokes
5 |
6 | class Circle:
7 | def __init__(self, cfg, is_train=True):
8 | self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES']
9 | self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET']
10 | self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB']
11 | self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
12 | self.is_train = is_train
13 |
14 | @staticmethod
15 | def get_stroke_preset(stroke_preset):
16 | if stroke_preset == 'object_like':
17 | return {
18 | "nVertexBound": [5, 30],
19 | "maxHeadSpeed": 15,
20 | "maxHeadAcceleration": (10, 1.5),
21 | "brushWidthBound": (20, 50),
22 | "nMovePointRatio": 0.5,
23 | "maxPiontMove": 10,
24 | "maxLineAcceleration": (5, 0.5),
25 | "boarderGap": None,
26 | "maxInitSpeed": 10,
27 | }
28 | elif stroke_preset == 'object_like_middle':
29 | return {
30 | "nVertexBound": [5, 15],
31 | "maxHeadSpeed": 8,
32 | "maxHeadAcceleration": (4, 1.5),
33 | "brushWidthBound": (20, 50),
34 | "nMovePointRatio": 0.5,
35 | "maxPiontMove": 5,
36 | "maxLineAcceleration": (5, 0.5),
37 | "boarderGap": None,
38 | "maxInitSpeed": 10,
39 | }
40 | elif stroke_preset == 'object_like_small':
41 | return {
42 | "nVertexBound": [5, 20],
43 | "maxHeadSpeed": 7,
44 | "maxHeadAcceleration": (3.5, 1.5),
45 | "brushWidthBound": (10, 30),
46 | "nMovePointRatio": 0.5,
47 | "maxPiontMove": 5,
48 | "maxLineAcceleration": (3, 0.5),
49 | "boarderGap": None,
50 | "maxInitSpeed": 4,
51 | }
52 | else:
53 | raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
54 |
55 | def get_random_points_from_mask(self, mask, n=5):
56 | h,w = mask.shape
57 | view_mask = mask.reshape(h*w)
58 | non_zero_idx = view_mask.nonzero()[:,0]
59 | selected_idx = torch.randperm(len(non_zero_idx))[:n]
60 | non_zero_idx = non_zero_idx[selected_idx]
61 | y = (non_zero_idx // w)*1.0
62 | x = (non_zero_idx % w)*1.0
63 | return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
64 |
65 | def draw(self, mask=None, box=None):
66 | if mask.sum() < 10: # if mask is nearly empty
67 | return torch.zeros(mask.shape).bool()
68 | if not self.is_train:
69 | return self.draw_eval(mask=mask, box=box)
70 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
71 | preset = Circle.get_stroke_preset(stroke_preset_name)
72 | nStroke = min(random.randint(1, self.num_stroke), mask.sum().item())
73 | h,w = mask.shape
74 | points = self.get_random_points_from_mask(mask, n=nStroke)
75 | rand_mask = get_mask_by_input_strokes(
76 | init_points=points,
77 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
78 | rand_mask = (~torch.from_numpy(rand_mask)) * mask
79 | return rand_mask
80 |
81 | def draw_eval(self, mask=None, box=None):
82 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
83 | preset = Circle.get_stroke_preset(stroke_preset_name)
84 | nStroke = min(self.max_eval, mask.sum().item())
85 | h,w = mask.shape
86 | points = self.get_random_points_from_mask(mask, n=nStroke)
87 | rand_masks = []
88 | for i in range(len(points)):
89 | rand_mask = get_mask_by_input_strokes(
90 | init_points=points[:i+1],
91 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset)
92 | rand_masks += [(~torch.from_numpy(rand_mask)) * mask]
93 | return torch.stack(rand_masks)
94 |
95 | @staticmethod
96 | def draw_by_points(points, mask, h, w):
97 | stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use
98 | preset = Circle.get_stroke_preset(stroke_preset_name)
99 | rand_mask = get_mask_by_input_strokes(
100 | init_points=points,
101 | imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
102 | rand_masks = (~torch.from_numpy(rand_mask)) * mask
103 | return rand_masks
104 |
105 | def __repr__(self,):
106 | return 'circle'
--------------------------------------------------------------------------------
/datasets/visual_sampler/point.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from scipy import ndimage
6 |
7 |
8 | class Point:
9 | def __init__(self, cfg, is_train=True):
10 | self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
11 | self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
12 | self.is_train = is_train
13 |
14 | def draw(self, mask=None, box=None):
15 | if mask.sum() < 10:
16 | return torch.zeros(mask.shape).bool() # if mask is empty
17 | if not self.is_train:
18 | return self.draw_eval(mask=mask, box=box)
19 | max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number
20 | num_points = random.randint(1, max_points) # get a random number of points
21 | h,w = mask.shape
22 | view_mask = mask.view(-1)
23 | non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask
24 | selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id
25 | non_zero_idx = non_zero_idx[selected_idx] # select non-zero index
26 | rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask
27 | rand_mask[non_zero_idx] = True # get non zero place to zero
28 | # dilate
29 | # struct = ndimage.generate_binary_structure(2, 2)
30 | # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
31 | # return rand_mask
32 | return rand_mask.reshape(h, w)
33 |
34 | def draw_eval(self, mask=None, box=None):
35 | background = ~mask
36 | neg_num = min(self.max_eval // 2, background.sum().item())
37 | pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1
38 |
39 | h,w = mask.shape
40 | view_mask = mask.view(-1)
41 | non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask
42 | selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id
43 | non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index
44 | pos_idx = torch.ones(non_zero_idx_pos.shape)
45 |
46 | view_background = background.view(-1)
47 | non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask
48 | selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id
49 | non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index
50 | neg_idx = torch.ones(non_zero_idx_neg.shape) * -1
51 |
52 | non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg])
53 | idx = torch.cat([pos_idx, neg_idx])
54 | rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long()
55 | non_zero_idx = non_zero_idx[rand_idx]
56 | idx = idx[rand_idx]
57 |
58 | rand_masks = []
59 | for i in range(0, len(non_zero_idx)):
60 | rand_mask = torch.zeros(view_mask.shape) # init rand mask
61 | rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero
62 | # struct = ndimage.generate_binary_structure(2, 2)
63 | # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
64 | rand_masks += [rand_mask.reshape(h, w)]
65 |
66 | # kernel_size = 3
67 | rand_masks = torch.stack(rand_masks)
68 | # rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0]
69 | # rand_masks[rand_masks>0] = 1
70 | # rand_masks[rand_masks<0] = -1
71 | return rand_masks
72 |
73 | def __repr__(self,):
74 | return 'point'
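
A toy example of `Point.draw` on a synthetic mask; the two config keys are the ones read in `__init__`, and the values are arbitrary:

    import torch

    cfg = {'STROKE_SAMPLER': {'POINT': {'NUM_POINTS': 10}, 'EVAL': {'MAX_ITER': 20}}}
    point_sampler = Point(cfg, is_train=True)

    mask = torch.zeros(64, 64, dtype=torch.bool)
    mask[20:40, 20:40] = True                        # a 20x20 square "object"

    clicks = point_sampler.draw(mask=mask)           # bool (64, 64) map of sampled click pixels
    print(int(clicks.sum()), "clicks, all inside the object:", bool((clicks & ~mask).sum() == 0))
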
--------------------------------------------------------------------------------
/datasets/visual_sampler/sampler.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import random
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .point import Point
8 | from .polygon import Polygon
9 | from .scribble import Scribble
10 | from .circle import Circle
11 |
12 | from modeling.utils import configurable
13 |
14 |
15 | class ShapeSampler(nn.Module):
16 | @configurable
17 | def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
18 | super().__init__()
19 | self.max_candidate = max_candidate
20 | self.shape_prob = shape_prob
21 | self.shape_candidate = shape_candidate
22 | self.is_train = is_train
23 |
24 | @classmethod
25 | def from_config(cls, cfg, is_train=True, mode=None):
26 | max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
27 | candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
28 | candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']
29 |
30 | if mode == 'hack_train':
31 | candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]
32 | else:
 33 |             # overwrite candidate_probs
34 | if not is_train:
35 | candidate_probs = [0.0 for x in range(len(candidate_names))]
36 | candidate_probs[candidate_names.index(mode)] = 1.0
37 | candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]
38 |
39 | # Build augmentation
40 | return {
41 | "max_candidate": max_candidate,
42 | "shape_prob": candidate_probs,
43 | "shape_candidate": candidate_classes,
44 | "is_train": is_train,
45 | }
46 |
47 | def forward(self, instances):
48 | masks = instances.gt_masks.tensor
49 | boxes = instances.gt_boxes.tensor
50 |
51 | if len(masks) == 0:
52 | gt_masks = torch.zeros(masks.shape[-2:]).bool()
53 | rand_masks = torch.zeros(masks.shape[-2:]).bool()
54 | return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
55 | indices = [x for x in range(len(masks))]
56 |
57 | if self.is_train:
58 | random.shuffle(indices)
59 | candidate_mask = masks[indices[:self.max_candidate]]
60 | candidate_box = boxes[indices[:self.max_candidate]]
61 | else:
62 | candidate_mask = masks
63 | candidate_box = boxes
64 |
65 | draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
66 | rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)]
67 | types = [repr(x) for x in draw_funcs]
68 | for i in range(0, len(rand_shapes)):
69 | if rand_shapes[i].sum() == 0:
70 | candidate_mask[i] = candidate_mask[i] * 0
71 | types[i] = 'none'
72 |
73 | # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
74 | return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
75 |
76 | def build_shape_sampler(cfg, **kwargs):
77 | return ShapeSampler(cfg, **kwargs)
--------------------------------------------------------------------------------
/datasets/visual_sampler/scribble.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 |
5 | from .mask_generators import get_mask_by_input_strokes
6 |
7 | class Scribble:
8 | def __init__(self, cfg, is_train):
9 | self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES']
10 | self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET']
11 | self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB']
12 | self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
13 | self.is_train = is_train
14 |
15 | @staticmethod
16 | def get_stroke_preset(stroke_preset):
17 | if stroke_preset == 'rand_curve':
18 | return {
19 | "nVertexBound": [10, 30],
20 | "maxHeadSpeed": 20,
21 | "maxHeadAcceleration": (15, 0.5),
22 | "brushWidthBound": (3, 10),
23 | "nMovePointRatio": 0.5,
24 | "maxPiontMove": 3,
25 | "maxLineAcceleration": (5, 0.5),
26 | "boarderGap": None,
27 | "maxInitSpeed": 6
28 | }
29 | elif stroke_preset == 'rand_curve_small':
30 | return {
31 | "nVertexBound": [6, 22],
32 | "maxHeadSpeed": 12,
33 | "maxHeadAcceleration": (8, 0.5),
34 | "brushWidthBound": (2.5, 5),
35 | "nMovePointRatio": 0.5,
36 | "maxPiontMove": 1.5,
37 | "maxLineAcceleration": (3, 0.5),
38 | "boarderGap": None,
39 | "maxInitSpeed": 3
40 | }
41 | else:
42 | raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
43 |
44 | def get_random_points_from_mask(self, mask, n=5):
45 | h,w = mask.shape
46 | view_mask = mask.reshape(h*w)
47 | non_zero_idx = view_mask.nonzero()[:,0]
48 | selected_idx = torch.randperm(len(non_zero_idx))[:n]
49 | non_zero_idx = non_zero_idx[selected_idx]
50 | y = (non_zero_idx // w)*1.0
51 | x = (non_zero_idx % w)*1.0
52 | return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
53 |
54 | def draw(self, mask=None, box=None):
55 | if mask.sum() < 10:
56 | return torch.zeros(mask.shape).bool() # if mask is empty
57 | if not self.is_train:
58 | return self.draw_eval(mask=mask, box=box)
59 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
60 | preset = Scribble.get_stroke_preset(stroke_preset_name)
61 | nStroke = random.randint(1, min(self.num_stroke, mask.sum().item()))
62 | h,w = mask.shape
63 | points = self.get_random_points_from_mask(mask, n=nStroke)
64 | rand_mask = get_mask_by_input_strokes(
65 | init_points=points,
66 | imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
67 | rand_mask = (~torch.from_numpy(rand_mask)) * mask
68 | return rand_mask
69 |
70 | def draw_eval(self, mask=None, box=None):
71 | stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
72 | preset = Scribble.get_stroke_preset(stroke_preset_name)
73 | nStroke = min(self.eval_stroke, mask.sum().item())
74 | h,w = mask.shape
75 | points = self.get_random_points_from_mask(mask, n=nStroke)
76 | rand_masks = []
77 | for i in range(len(points)):
78 | rand_mask = get_mask_by_input_strokes(
79 | init_points=points[:i+1],
80 | imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset)
81 | rand_mask = (~torch.from_numpy(rand_mask)) * mask
82 | rand_masks += [rand_mask]
83 | return torch.stack(rand_masks)
84 |
85 | @staticmethod
86 | def draw_by_points(points, mask, h, w):
87 | stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0]
88 | preset = Scribble.get_stroke_preset(stroke_preset_name)
89 | rand_mask = get_mask_by_input_strokes(
90 | init_points=points,
91 | imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
92 | rand_masks = (~torch.from_numpy(rand_mask)) * mask
93 | return rand_masks
94 |
95 | def __repr__(self,):
96 | return 'scribble'
--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/__init__.py
--------------------------------------------------------------------------------
/demo/seem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/__init__.py
--------------------------------------------------------------------------------
/demo/seem/examples/corgi1.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/corgi1.webp
--------------------------------------------------------------------------------
/demo/seem/examples/corgi2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/corgi2.jpg
--------------------------------------------------------------------------------
/demo/seem/examples/fries1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/fries1.png
--------------------------------------------------------------------------------
/demo/seem/examples/fries2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/fries2.png
--------------------------------------------------------------------------------
/demo/seem/examples/minecraft1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/minecraft1.jpg
--------------------------------------------------------------------------------
/demo/seem/examples/placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/placeholder.png
--------------------------------------------------------------------------------
/demo/seem/examples/ref_vase.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/ref_vase.JPG
--------------------------------------------------------------------------------
/demo/seem/examples/river1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river1.png
--------------------------------------------------------------------------------
/demo/seem/examples/river1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river1.wav
--------------------------------------------------------------------------------
/demo/seem/examples/river1_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river1_mask.png
--------------------------------------------------------------------------------
/demo/seem/examples/river2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/river2.png
--------------------------------------------------------------------------------
/demo/seem/examples/vasedeck.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/vasedeck.mp4
--------------------------------------------------------------------------------
/demo/seem/examples/zebras1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/zebras1.jpg
--------------------------------------------------------------------------------
/demo/seem/examples/zebras2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/demo/seem/examples/zebras2.jpg
--------------------------------------------------------------------------------
/demo/seem/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .interactive import interactive_infer_video, interactive_infer_image
--------------------------------------------------------------------------------
/entry.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import torch
11 | import logging
12 | import wandb
13 |
14 | from utils.arguments import load_opt_command
15 |
16 | logging.basicConfig(level=logging.INFO)
17 | logger = logging.getLogger(__name__)
18 |
19 | def init_wandb(args, job_dir, entity='xueyanz', project='xdecoder', job_name='tmp'):
20 | wandb_dir = os.path.join(job_dir, 'wandb')
21 | os.makedirs(wandb_dir, exist_ok=True)
22 | runid = None
23 | if os.path.exists(f"{wandb_dir}/runid.txt"):
24 | runid = open(f"{wandb_dir}/runid.txt").read()
25 |
26 | wandb.init(project=project,
27 | name=job_name,
28 | dir=wandb_dir,
29 | entity=entity,
30 | resume="allow",
31 | id=runid,
32 | config={"hierarchical": True},)
33 |
34 | open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id)
35 | wandb.config.update({k: args[k] for k in args if k not in wandb.config})
36 |
37 | def main(args=None):
38 | '''
39 | [Main function for the entry point]
40 | 1. Set environment variables for distributed training.
41 | 2. Load the config file and set up the trainer.
42 | '''
43 |
44 | opt, cmdline_args = load_opt_command(args)
45 | command = cmdline_args.command
46 |
47 | if cmdline_args.user_dir:
48 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
49 | opt['base_path'] = absolute_user_dir
50 |
51 | # update_opt(opt, command)
52 | world_size = 1
53 | if 'OMPI_COMM_WORLD_SIZE' in os.environ:
54 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
55 |
56 | if opt['TRAINER'] == 'xdecoder':
57 | from trainer import XDecoder_Trainer as Trainer
58 | else:
59 | assert False, "The trainer type: {} is not defined!".format(opt['TRAINER'])
60 |
61 | trainer = Trainer(opt)
62 | os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL'
63 |
64 | if command == "train":
65 | if opt['rank'] == 0 and opt['WANDB']:
66 | wandb.login(key=os.environ['WANDB_KEY'])
67 | init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder)
68 | trainer.train()
69 | elif command == "evaluate":
70 | trainer.eval()
71 | else:
72 | raise ValueError(f"Unknown command: {command}")
73 |
74 | if __name__ == "__main__":
75 | main()
76 | sys.exit(0)
77 |
--------------------------------------------------------------------------------
/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/__init__.py
--------------------------------------------------------------------------------
/inference/images/animals.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/animals.png
--------------------------------------------------------------------------------
/inference/images/apples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/apples.jpg
--------------------------------------------------------------------------------
/inference/images/coco/000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/000.jpg
--------------------------------------------------------------------------------
/inference/images/coco/001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/001.jpg
--------------------------------------------------------------------------------
/inference/images/coco/002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/002.jpg
--------------------------------------------------------------------------------
/inference/images/coco/003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/coco/003.jpg
--------------------------------------------------------------------------------
/inference/images/fruit.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/fruit.jpg
--------------------------------------------------------------------------------
/inference/images/landscape.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/landscape.jpg
--------------------------------------------------------------------------------
/inference/images/mountain.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/mountain.jpeg
--------------------------------------------------------------------------------
/inference/images/owls.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/owls.jpeg
--------------------------------------------------------------------------------
/inference/images/penguin.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/penguin.jpeg
--------------------------------------------------------------------------------
/inference/images/region_retrieval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/region_retrieval.png
--------------------------------------------------------------------------------
/inference/images/rose.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/rose.webp
--------------------------------------------------------------------------------
/inference/images/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/street.jpg
--------------------------------------------------------------------------------
/inference/images/teaser_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/inference/images/teaser_new.png
--------------------------------------------------------------------------------
/inference/xdecoder/infer_captioning.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 |
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 |
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(0)
18 | import cv2
19 |
20 | import torch
21 | from torchvision import transforms
22 |
23 | from utils.arguments import load_opt_command
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.structures import BitMasks
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from detectron2.utils.colormap import random_color
29 | from utils.visualizer import Visualizer
30 | from utils.distributed import init_distributed
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | def main(args=None):
36 | '''
37 | Main execution point for PyLearn.
38 | '''
39 |
40 | opt, cmdline_args = load_opt_command(args)
41 | if cmdline_args.user_dir:
42 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
43 | opt['base_path'] = absolute_user_dir
44 | opt = init_distributed(opt)
45 |
46 | # META DATA
47 | pretrained_pth = os.path.join(opt['RESUME_FROM'])
48 | if 'novg' not in pretrained_pth:
49 |         assert False, "Please use a checkpoint trained without Visual Genome data (path containing 'novg'); it produces much better captions."
50 | output_root = './output'
51 | image_pth = 'inference/images/mountain.jpeg'
52 |
53 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
54 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background"], is_eval=False)
55 |
56 | t = []
57 | t.append(transforms.Resize(224, interpolation=Image.BICUBIC))
58 | transform = transforms.Compose(t)
59 |
60 | with torch.no_grad():
61 | image_ori = Image.open(image_pth).convert("RGB")
62 | width = image_ori.size[0]
63 | height = image_ori.size[1]
64 | image = transform(image_ori)
65 | image = np.asarray(image)
66 | image_ori = np.asarray(image_ori)
67 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
68 |
69 | batch_inputs = [{'image': images, 'height': height, 'width': width, 'image_id': 0}]
70 | outputs = model.model.evaluate_captioning(batch_inputs)
71 | text = outputs[-1]['captioning_text']
72 |
73 | image_ori = image_ori[:,:,::-1].copy()
74 | cv2.rectangle(image_ori, (0, 0), (width, 60), (0,0,0), -1)
75 | font = cv2.FONT_HERSHEY_DUPLEX
76 | fontScale = 1.2
77 | thickness = 2
78 | lineType = 2
79 | bottomLeftCornerOfText = (10, 40)
80 | fontColor = [255,255,255]
81 | cv2.putText(image_ori, text,
82 | bottomLeftCornerOfText,
83 | font,
84 | fontScale,
85 | fontColor,
86 | thickness,
87 | lineType)
88 |
89 | if not os.path.exists(output_root):
90 | os.makedirs(output_root)
91 | cv2.imwrite(os.path.join(output_root, 'captioning.png'), image_ori)
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 | sys.exit(0)
--------------------------------------------------------------------------------
/inference/xdecoder/infer_instseg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 |
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 |
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(2)
18 |
19 | import torch
20 | from torchvision import transforms
21 |
22 | from utils.arguments import load_opt_command
23 |
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.structures import BitMasks
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from detectron2.utils.colormap import random_color
29 | from utils.visualizer import Visualizer
30 | from utils.distributed import init_distributed
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | def main(args=None):
36 | '''
37 | Main execution point for PyLearn.
38 | '''
39 | opt, cmdline_args = load_opt_command(args)
40 | if cmdline_args.user_dir:
41 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
42 | opt['base_path'] = absolute_user_dir
43 |
44 | opt = init_distributed(opt)
45 |
46 | # META DATA
47 | pretrained_pth = os.path.join(opt['RESUME_FROM'])
48 | output_root = './output'
49 | image_pth = 'inference/images/owls.jpeg'
50 |
51 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
52 |
53 | t = []
54 | t.append(transforms.Resize(800, interpolation=Image.BICUBIC))
55 | transform = transforms.Compose(t)
56 |
57 | thing_classes = ["owl"]
58 |     thing_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(thing_classes))]
59 | thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))}
60 |
61 | MetadataCatalog.get("demo").set(
62 | thing_colors=thing_colors,
63 | thing_classes=thing_classes,
64 | thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
65 | )
66 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=False)
67 | metadata = MetadataCatalog.get('demo')
68 | model.model.metadata = metadata
69 | model.model.sem_seg_head.num_classes = len(thing_classes)
70 |
71 | with torch.no_grad():
72 | image_ori = Image.open(image_pth).convert('RGB')
73 | width = image_ori.size[0]
74 | height = image_ori.size[1]
75 | image = transform(image_ori)
76 | image = np.asarray(image)
77 | image_ori = np.asarray(image_ori)
78 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
79 |
80 | batch_inputs = [{'image': images, 'height': height, 'width': width}]
81 | outputs = model.forward(batch_inputs)
82 | visual = Visualizer(image_ori, metadata=metadata)
83 |
84 | inst_seg = outputs[-1]['instances']
85 | inst_seg.pred_masks = inst_seg.pred_masks.cpu()
86 | inst_seg.pred_boxes = BitMasks(inst_seg.pred_masks > 0).get_bounding_boxes()
87 | demo = visual.draw_instance_predictions(inst_seg) # rgb Image
88 |
89 | if not os.path.exists(output_root):
90 | os.makedirs(output_root)
91 | demo.save(os.path.join(output_root, 'inst.png'))
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 | sys.exit(0)
--------------------------------------------------------------------------------
/inference/xdecoder/infer_panoseg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 |
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 |
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(1)
18 |
19 | import torch
20 | from torchvision import transforms
21 |
22 | from utils.arguments import load_opt_command
23 |
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.utils.colormap import random_color
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from utils.visualizer import Visualizer
29 | from utils.distributed import init_distributed
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | def main(args=None):
35 | '''
36 | Main execution point for PyLearn.
37 | '''
38 | opt, cmdline_args = load_opt_command(args)
39 | if cmdline_args.user_dir:
40 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
41 | opt['base_path'] = absolute_user_dir
42 |
43 | opt = init_distributed(opt)
44 |
45 | # META DATA
46 | pretrained_pth = os.path.join(opt['RESUME_FROM'])
47 | output_root = './output'
48 | image_pth = 'inference/images/street.jpg'
49 |
50 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
51 |
52 | t = []
53 | t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
54 | transform = transforms.Compose(t)
55 |
56 | thing_classes = ['car','person','traffic light', 'truck', 'motorcycle']
57 | stuff_classes = ['building','sky','street','tree','rock','sidewalk']
58 |     thing_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(thing_classes))]
59 |     stuff_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(stuff_classes))]
60 | thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))}
61 | stuff_dataset_id_to_contiguous_id = {x+len(thing_classes):x for x in range(len(stuff_classes))}
62 |
63 | MetadataCatalog.get("demo").set(
64 | thing_colors=thing_colors,
65 | thing_classes=thing_classes,
66 | thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
67 | stuff_colors=stuff_colors,
68 | stuff_classes=stuff_classes,
69 | stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id,
70 | )
71 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + stuff_classes + ["background"], is_eval=False)
72 | metadata = MetadataCatalog.get('demo')
73 | model.model.metadata = metadata
74 | model.model.sem_seg_head.num_classes = len(thing_classes + stuff_classes)
75 |
76 | with torch.no_grad():
77 | image_ori = Image.open(image_pth).convert("RGB")
78 | width = image_ori.size[0]
79 | height = image_ori.size[1]
80 | image = transform(image_ori)
81 | image = np.asarray(image)
82 | image_ori = np.asarray(image_ori)
83 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
84 |
85 | batch_inputs = [{'image': images, 'height': height, 'width': width}]
86 | outputs = model.forward(batch_inputs)
87 | visual = Visualizer(image_ori, metadata=metadata)
88 |
89 | pano_seg = outputs[-1]['panoptic_seg'][0]
90 | pano_seg_info = outputs[-1]['panoptic_seg'][1]
91 |
92 | for i in range(len(pano_seg_info)):
93 | if pano_seg_info[i]['category_id'] in metadata.thing_dataset_id_to_contiguous_id.keys():
94 | pano_seg_info[i]['category_id'] = metadata.thing_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']]
95 | else:
96 | pano_seg_info[i]['isthing'] = False
97 | pano_seg_info[i]['category_id'] = metadata.stuff_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']]
98 |
99 | demo = visual.draw_panoptic_seg(pano_seg.cpu(), pano_seg_info) # rgb Image
100 |
101 | if not os.path.exists(output_root):
102 | os.makedirs(output_root)
103 | demo.save(os.path.join(output_root, 'pano.png'))
104 |
105 |
106 | if __name__ == "__main__":
107 | main()
108 | sys.exit(0)
--------------------------------------------------------------------------------
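Note: because the vocabulary is plain text fed to the language encoder, retargeting infer_panoseg.py to another image only requires changing the image path and the two class lists; the metadata registration and the get_text_embeddings call stay as written. The values below are illustrative only, reusing names that already appear in these scripts.

# illustrative open-vocabulary swap for infer_panoseg.py
image_pth = 'inference/images/animals.png'
thing_classes = ['zebra', 'giraffe', 'ostrich', 'antelope']
stuff_classes = ['sky', 'grass', 'sand', 'tree', 'water']
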
/inference/xdecoder/infer_refseg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import json
11 | import logging
12 |
13 | pth = '/'.join(sys.path[0].split('/')[:-1])
14 | sys.path.insert(0, pth)
15 |
16 | from PIL import Image
17 | import numpy as np
18 | np.random.seed(27)
19 |
20 | import torch
21 | from torchvision import transforms
22 |
23 | from utils.arguments import load_opt_command
24 |
25 | from detectron2.data import MetadataCatalog
26 | from detectron2.utils.colormap import random_color
27 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
28 | from modeling.BaseModel import BaseModel
29 | from modeling import build_model
30 | from utils.visualizer import Visualizer
31 | from utils.distributed import init_distributed
32 |
33 | # logging.basicConfig(level = logging.INFO)
34 | logger = logging.getLogger(__name__)
35 |
36 |
37 | def main(args=None):
38 | '''
39 | Main execution point for PyLearn.
40 | '''
41 | opt, cmdline_args = load_opt_command(args)
42 | if cmdline_args.user_dir:
43 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
44 | opt['base_path'] = absolute_user_dir
45 |
46 | opt = init_distributed(opt)
47 |
48 | # META DATA
49 | pretrained_pth = os.path.join(opt['RESUME_FROM'])
50 | output_root = './output'
51 | image_pth = 'inference/images/fruit.jpg'
52 |
53 | text = [['The larger watermelon.'], ['The front white flower.'], ['White tea pot.'], ['Flower bunch.'], ['white vase.'], ['The left peach.'], ['The brown knife.']]
54 |
55 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
56 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=False)
57 |
58 | t = []
59 | t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
60 | transform = transforms.Compose(t)
61 |
62 | metadata = MetadataCatalog.get('ade20k_panoptic_train')
63 | model.model.metadata = metadata
64 |
65 | with torch.no_grad():
66 | image_ori = Image.open(image_pth)
67 | width = image_ori.size[0]
68 | height = image_ori.size[1]
69 | image = transform(image_ori)
70 | image = np.asarray(image)
71 | image_ori = np.asarray(image_ori)
72 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
73 |
74 | batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': text}}]
75 | outputs = model.model.evaluate_grounding(batch_inputs, None)
76 | visual = Visualizer(image_ori, metadata=metadata)
77 |
78 | grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
79 | for idx, mask in enumerate(grd_mask):
80 |             demo = visual.draw_binary_mask(mask, color=random_color(rgb=True, maximum=1).astype(int).tolist(), text=text[idx], alpha=0.3)
81 |
82 |     output_folder = os.path.join(output_root)
83 | if not os.path.exists(output_folder):
84 | os.makedirs(output_folder)
85 | demo.save(os.path.join(output_folder, 'refseg.png'))
86 |
87 |
88 | if __name__ == "__main__":
89 | main()
90 | sys.exit(0)
--------------------------------------------------------------------------------
/inference/xdecoder/infer_semseg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 |
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 |
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(1)
18 |
19 | import torch
20 | from torchvision import transforms
21 |
22 | from utils.arguments import load_opt_command
23 |
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.utils.colormap import random_color
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from utils.visualizer import Visualizer
29 | from utils.distributed import init_distributed
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | def main(args=None):
35 | '''
36 | Main execution point for PyLearn.
37 | '''
38 | opt, cmdline_args = load_opt_command(args)
39 | if cmdline_args.user_dir:
40 | absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
41 | opt['base_path'] = absolute_user_dir
42 | opt = init_distributed(opt)
43 |
44 | # META DATA
45 | pretrained_pth = os.path.join(opt['RESUME_FROM'])
46 | output_root = './output'
47 | image_pth = 'inference/images/animals.png'
48 |
49 | model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
50 |
51 | t = []
52 | t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
53 | transform = transforms.Compose(t)
54 |
55 | stuff_classes = ['zebra','antelope','giraffe','ostrich','sky','water','grass','sand','tree']
56 |     stuff_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(stuff_classes))]
57 | stuff_dataset_id_to_contiguous_id = {x:x for x in range(len(stuff_classes))}
58 |
59 | MetadataCatalog.get("demo").set(
60 | stuff_colors=stuff_colors,
61 | stuff_classes=stuff_classes,
62 | stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id,
63 | )
64 | model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(stuff_classes + ["background"], is_eval=True)
65 | metadata = MetadataCatalog.get('demo')
66 | model.model.metadata = metadata
67 | model.model.sem_seg_head.num_classes = len(stuff_classes)
68 |
69 | with torch.no_grad():
70 | image_ori = Image.open(image_pth).convert("RGB")
71 | width = image_ori.size[0]
72 | height = image_ori.size[1]
73 | image = transform(image_ori)
74 | image = np.asarray(image)
75 | image_ori = np.asarray(image_ori)
76 | images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
77 |
78 | batch_inputs = [{'image': images, 'height': height, 'width': width}]
79 | outputs = model.forward(batch_inputs)
80 | visual = Visualizer(image_ori, metadata=metadata)
81 |
82 | sem_seg = outputs[-1]['sem_seg'].max(0)[1]
83 | demo = visual.draw_sem_seg(sem_seg.cpu(), alpha=0.5) # rgb Image
84 |
85 | if not os.path.exists(output_root):
86 | os.makedirs(output_root)
87 | demo.save(os.path.join(output_root, 'sem.png'))
88 |
89 |
90 | if __name__ == "__main__":
91 | main()
92 | sys.exit(0)
--------------------------------------------------------------------------------
/modeling/BaseModel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from utils.model import align_and_update_state_dicts
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class BaseModel(nn.Module):
13 | def __init__(self, opt, module: nn.Module):
14 | super(BaseModel, self).__init__()
15 | self.opt = opt
16 | self.model = module
17 |
18 | def forward(self, *inputs, **kwargs):
19 | outputs = self.model(*inputs, **kwargs)
20 | return outputs
21 |
22 | def save_pretrained(self, save_dir):
23 | torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))
24 |
25 | def from_pretrained(self, load_dir):
26 | state_dict = torch.load(load_dir, map_location=self.opt['device'])
27 | state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
28 | self.model.load_state_dict(state_dict, strict=False)
29 | return self
--------------------------------------------------------------------------------
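Note: a hedged usage sketch for BaseModel. The helpers load_opt_command and init_distributed come from utils/, which is not included in this section, and the assumption that init_distributed populates opt['device'] is inferred from how the inference scripts above call from_pretrained. Also note the asymmetry: save_pretrained takes a directory, while from_pretrained expects the checkpoint file path itself.

from utils.arguments import load_opt_command
from utils.distributed import init_distributed
from modeling import build_model
from modeling.BaseModel import BaseModel

opt, _ = load_opt_command(None)        # parses the same CLI the inference scripts above use
opt = init_distributed(opt)            # presumably sets opt['device'], used by from_pretrained
model = BaseModel(opt, build_model(opt)).eval().cuda()
model.save_pretrained('./output')                                  # writes ./output/model_state_dict.pt
model = model.from_pretrained('./output/model_state_dict.pt')      # file path, not a directory
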
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .architectures import build_model
--------------------------------------------------------------------------------
/modeling/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | from .xdecoder_model import *
2 | from .seem_model_v0 import *
3 | from .seem_model_v1 import *
4 | from .seem_model_demo import *
5 | from .build import build_model
--------------------------------------------------------------------------------
/modeling/architectures/build.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def build_model(config, **kwargs):
5 | model_name = config['MODEL']['NAME']
6 |
7 | if not is_model(model_name):
8 |         raise ValueError(f'Unknown model: {model_name}')
9 |
10 | return model_entrypoints(model_name)(config, **kwargs)
11 |
12 | def register_model(fn):
13 | module_name_split = fn.__module__.split('.')
14 | model_name = module_name_split[-1]
15 | _model_entrypoints[model_name] = fn
16 | return fn
17 |
18 | def model_entrypoints(model_name):
19 | return _model_entrypoints[model_name]
20 |
21 | def is_model(model_name):
22 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
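Note: build.py and registry.py together implement a small plugin registry. register_model keys each entrypoint by the filename of the module that defines it, so the config's MODEL.NAME has to match that filename (presumably 'xdecoder_model', 'seem_model_v1', etc., given the imports in architectures/__init__.py). A minimal self-contained sketch of the mechanism follows; the toy entrypoint name is hypothetical, and in a single-file script the registry key collapses to '__main__'.

_model_entrypoints = {}

def register_model(fn):
    # same keying rule as above: the last component of the defining module's name
    _model_entrypoints[fn.__module__.split('.')[-1]] = fn
    return fn

def build_model(config, **kwargs):
    name = config['MODEL']['NAME']
    if name not in _model_entrypoints:
        raise ValueError(f'Unknown model: {name}')
    return _model_entrypoints[name](config, **kwargs)

@register_model
def get_toy_model(config, **kwargs):   # hypothetical entrypoint
    return f"built {config['MODEL']['NAME']}"

# in a single-file sketch the defining module is __main__, so that is the lookup key
print(build_model({'MODEL': {'NAME': '__main__'}}))   # -> built __main__
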
/modeling/architectures/registry.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 | def register_model(fn):
4 | module_name_split = fn.__module__.split('.')
5 | model_name = module_name_split[-1]
6 | _model_entrypoints[model_name] = fn
7 | return fn
8 |
9 | def model_entrypoints(model_name):
10 | return _model_entrypoints[model_name]
11 |
12 | def is_model(model_name):
13 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_backbone
2 |
3 | from .focal import *
4 | from .focal_dw import *
5 | from .backbone import *
--------------------------------------------------------------------------------
/modeling/backbone/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch.nn as nn
3 |
4 | from detectron2.modeling import ShapeSpec
5 |
6 | # from ..layers import ShapeSpec
7 |
8 | __all__ = ["Backbone"]
9 |
10 |
11 | class Backbone(nn.Module):
12 | """
13 | Abstract base class for network backbones.
14 | """
15 |
16 | def __init__(self):
17 | """
18 | The `__init__` method of any subclass can specify its own set of arguments.
19 | """
20 | super().__init__()
21 |
22 | def forward(self):
23 | """
24 | Subclasses must override this method, but adhere to the same return type.
25 |
26 | Returns:
27 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
28 | """
29 | pass
30 |
31 | @property
32 | def size_divisibility(self) -> int:
33 | """
34 | Some backbones require the input height and width to be divisible by a
35 | specific integer. This is typically true for encoder / decoder type networks
36 | with lateral connection (e.g., FPN) for which feature maps need to match
37 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
38 | input size divisibility is required.
39 | """
40 | return 0
41 |
42 | def output_shape(self):
43 | """
44 | Returns:
45 | dict[str->ShapeSpec]
46 | """
47 | # this is a backward-compatible default
48 | return {
49 | name: ShapeSpec(
50 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
51 | )
52 | for name in self._out_features
53 | }
54 |
--------------------------------------------------------------------------------
/modeling/backbone/build.py:
--------------------------------------------------------------------------------
1 | from .registry import model_entrypoints
2 | from .registry import is_model
3 |
4 | from .backbone import *
5 |
6 | def build_backbone(config, **kwargs):
7 | model_name = config['MODEL']['BACKBONE']['NAME']
8 | if not is_model(model_name):
9 |         raise ValueError(f'Unknown model: {model_name}')
10 |
11 | return model_entrypoints(model_name)(config, **kwargs)
--------------------------------------------------------------------------------
/modeling/backbone/registry.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def register_backbone(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 | _model_entrypoints[model_name] = fn
8 | return fn
9 |
10 | def model_entrypoints(model_name):
11 | return _model_entrypoints[model_name]
12 |
13 | def is_model(model_name):
14 | return model_name in _model_entrypoints
15 |
--------------------------------------------------------------------------------
/modeling/body/__init__.py:
--------------------------------------------------------------------------------
1 | from .xdecoder_head import *
2 | from .build import *
3 |
4 | def build_xdecoder_head(config, *args, **kwargs):
5 | model_name = config['MODEL']['HEAD']
6 | if not is_model(model_name):
7 |         raise ValueError(f'Unknown model: {model_name}')
8 |
9 | body = model_entrypoints(model_name)(config, *args, **kwargs)
10 | return body
--------------------------------------------------------------------------------
/modeling/body/build.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 | def register_body(fn):
4 | module_name_split = fn.__module__.split('.')
5 | model_name = module_name_split[-1]
6 | _model_entrypoints[model_name] = fn
7 | return fn
8 |
9 | def model_entrypoints(model_name):
10 | return _model_entrypoints[model_name]
11 |
12 | def is_model(model_name):
13 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/body/decoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_decoder
2 | from .xdecoder import *
--------------------------------------------------------------------------------
/modeling/body/decoder/build.py:
--------------------------------------------------------------------------------
1 | from .registry import model_entrypoints
2 | from .registry import is_model
3 |
4 |
5 | def build_decoder(config, *args, **kwargs):
6 | model_name = config['MODEL']['DECODER']['NAME']
7 |
8 | if not is_model(model_name):
9 |         raise ValueError(f'Unknown model: {model_name}')
10 |
11 | return model_entrypoints(model_name)(config, *args, **kwargs)
--------------------------------------------------------------------------------
/modeling/body/decoder/registry.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 | def register_decoder(fn):
4 | module_name_split = fn.__module__.split('.')
5 | model_name = module_name_split[-1]
6 | _model_entrypoints[model_name] = fn
7 | return fn
8 |
9 | def model_entrypoints(model_name):
10 | return _model_entrypoints[model_name]
11 |
12 | def is_model(model_name):
13 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/body/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_encoder
--------------------------------------------------------------------------------
/modeling/body/encoder/build.py:
--------------------------------------------------------------------------------
1 | from .registry import model_entrypoints
2 | from .registry import is_model
3 |
4 | from .transformer_encoder_fpn import *
5 |
6 | def build_encoder(config, *args, **kwargs):
7 | model_name = config['MODEL']['ENCODER']['NAME']
8 |
9 | if not is_model(model_name):
10 |         raise ValueError(f'Unknown model: {model_name}')
11 |
12 | return model_entrypoints(model_name)(config, *args, **kwargs)
--------------------------------------------------------------------------------
/modeling/body/encoder/registry.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 | def register_encoder(fn):
4 | module_name_split = fn.__module__.split('.')
5 | model_name = module_name_split[-1]
6 | _model_entrypoints[model_name] = fn
7 | return fn
8 |
9 | def model_entrypoints(model_name):
10 | return _model_entrypoints[model_name]
11 |
12 | def is_model(model_name):
13 | return model_name in _model_entrypoints
14 |
--------------------------------------------------------------------------------
/modeling/body/registry.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def register_body(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 | _model_entrypoints[model_name] = fn
8 | return fn
9 |
10 | def model_entrypoints(model_name):
11 | return _model_entrypoints[model_name]
12 |
13 | def is_model(model_name):
14 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/body/xdecoder_head.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates.
8 | from typing import Dict
9 |
10 | from torch import nn
11 |
12 | from detectron2.layers import ShapeSpec
13 |
14 | from .build import register_body
15 | from ..vision.encoder import build_encoder
16 | from ..interface import build_decoder
17 | from ..utils import configurable
18 |
19 |
20 | class XdecoderHead(nn.Module):
21 |
22 | @configurable
23 | def __init__(
24 | self,
25 | input_shape: Dict[str, ShapeSpec],
26 | *,
27 | num_classes: int,
28 | pixel_decoder: nn.Module,
29 | loss_weight: float = 1.0,
30 | ignore_value: int = -1,
31 | # extra parameters
32 | transformer_predictor: nn.Module,
33 | transformer_in_feature: str,
34 | ):
35 | """
36 | NOTE: this interface is experimental.
37 | Args:
38 | input_shape: shapes (channels and stride) of the input features
39 | num_classes: number of classes to predict
40 | pixel_decoder: the pixel decoder module
41 | loss_weight: loss weight
42 | ignore_value: category id to be ignored during training.
43 | transformer_predictor: the transformer decoder that makes prediction
44 | transformer_in_feature: input feature name to the transformer_predictor
45 | """
46 | super().__init__()
47 |
48 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
49 | self.in_features = [k for k, v in input_shape]
50 | feature_strides = [v.stride for k, v in input_shape]
51 | feature_channels = [v.channels for k, v in input_shape]
52 |
53 | self.ignore_value = ignore_value
54 | self.common_stride = 4
55 | self.loss_weight = loss_weight
56 |
57 | self.pixel_decoder = pixel_decoder
58 | self.predictor = transformer_predictor
59 | self.transformer_in_feature = transformer_in_feature
60 |
61 | self.num_classes = num_classes
62 |
63 | @classmethod
64 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
65 |
66 | in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
67 | enc_cfg = cfg['MODEL']['ENCODER']
68 | dec_cfg = cfg['MODEL']['DECODER']
69 |
70 | # figure out in_channels to transformer predictor
71 | if in_features_type == "transformer_encoder":
72 | transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
73 | elif in_features_type == "pixel_embedding":
74 | transformer_predictor_in_channels = enc_cfg['MASK_DIM']
75 | elif in_features_type == "multi_scale_pixel_decoder":
76 | transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
77 | else:
78 | transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels
79 |
80 | return {
81 | "input_shape": {
82 | k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
83 | },
84 | "ignore_value": enc_cfg['IGNORE_VALUE'],
85 | "num_classes": enc_cfg.get('NUM_CLASSES', None),
86 | "pixel_decoder": build_encoder(cfg, input_shape),
87 | "loss_weight": enc_cfg['LOSS_WEIGHT'],
88 | "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
89 | "transformer_predictor": build_decoder(
90 | cfg,
91 | transformer_predictor_in_channels,
92 | lang_encoder,
93 | mask_classification=True,
94 | extra=extra,
95 | ),
96 | }
97 |
98 | def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
99 | return self.layers(features, mask, target_queries, target_vlp, task, extra)
100 |
101 | def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
102 | mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
103 |
104 | if self.transformer_in_feature == "multi_scale_pixel_decoder":
105 | predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
106 | else:
107 | if self.transformer_in_feature == "transformer_encoder":
108 | assert (
109 | transformer_encoder_features is not None
110 | ), "Please use the TransformerEncoderPixelDecoder."
111 | predictions = self.predictor(transformer_encoder_features, mask_features, mask)
112 | elif self.transformer_in_feature == "pixel_embedding":
113 | predictions = self.predictor(mask_features, mask_features, mask)
114 | else:
115 | predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
116 | return predictions
117 |
118 |
119 | @register_body
120 | def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
121 | return XdecoderHead(cfg, input_shape, lang_encoder, extra)
--------------------------------------------------------------------------------
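Note: for orientation, the fragment below lists the config keys that XdecoderHead.from_config and its two sub-builders read. The values are illustrative assumptions, not copies of the YAMLs under configs/, and the NAME fields must match module filenames because of the module-name registries in the build.py/registry.py files above.

# illustrative fragment of the MODEL section consumed by the head (values are assumptions)
cfg_fragment = {
    'MODEL': {
        'HEAD': 'xdecoder_head',
        'ENCODER': {
            'NAME': 'transformer_encoder_fpn',
            'IN_FEATURES': ['res2', 'res3', 'res4', 'res5'],
            'IGNORE_VALUE': 255,
            'NUM_CLASSES': 133,
            'CONVS_DIM': 512,
            'MASK_DIM': 512,
            'LOSS_WEIGHT': 1.0,
        },
        'DECODER': {
            'NAME': 'xdecoder',
            'TRANSFORMER_IN_FEATURE': 'multi_scale_pixel_decoder',
        },
    },
}
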
/modeling/interface/__init__.py:
--------------------------------------------------------------------------------
1 | from .xdecoder import *
2 | from .seem_v0 import *
3 | from .seem_v1 import *
4 | from .seem_demo import *
5 | from .build import *
6 |
7 | def build_decoder(config, *args, **kwargs):
8 | model_name = config['MODEL']['DECODER']['NAME']
9 |
10 | if not is_model(model_name):
11 |         raise ValueError(f'Unknown model: {model_name}')
12 |
13 | return model_entrypoints(model_name)(config, *args, **kwargs)
--------------------------------------------------------------------------------
/modeling/interface/build.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def register_decoder(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 | _model_entrypoints[model_name] = fn
8 | return fn
9 |
10 | def model_entrypoints(model_name):
11 | return _model_entrypoints[model_name]
12 |
13 | def is_model(model_name):
14 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/interface/prototype/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/modeling/interface/prototype/__init__.py
--------------------------------------------------------------------------------
/modeling/language/LangEncoder/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import CLIPTokenizer, CLIPTokenizerFast
2 | from transformers import AutoTokenizer
3 |
4 | from .transformer import *
5 | from .build import *
6 |
7 |
8 | def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
9 | model_name = config_encoder['NAME']
10 |
11 | if not is_lang_encoder(model_name):
12 |         raise ValueError(f'Unknown model: {model_name}')
13 |
14 | return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
15 |
16 | def build_tokenizer(config_encoder):
17 | tokenizer = None
18 | os.environ['TOKENIZERS_PARALLELISM'] = 'true'
19 | if config_encoder['TOKENIZER'] == 'clip':
20 | pretrained_tokenizer = config_encoder.get(
21 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
22 | )
23 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
24 | tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
25 | elif config_encoder['TOKENIZER'] == 'clip-fast':
26 | pretrained_tokenizer = config_encoder.get(
27 | 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
28 | )
29 | tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
30 | else:
31 | tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
32 |
33 | return tokenizer
--------------------------------------------------------------------------------
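Note: for the default TOKENIZER: 'clip' branch, build_tokenizer reduces to the standard Hugging Face calls sketched below. The 77-token padding mirrors CLIP's usual context length and is an assumption here, since the calling code (vlpencoder.py) is not shown in this section.

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})   # as in build_tokenizer above

tokens = tokenizer(['a photo of a zebra.'], padding='max_length',
                   truncation=True, max_length=77, return_tensors='pt')
print(tokens['input_ids'].shape)   # torch.Size([1, 77])
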
/modeling/language/LangEncoder/build.py:
--------------------------------------------------------------------------------
1 | _lang_encoders = {}
2 |
3 |
4 | def register_lang_encoder(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 |
8 | _lang_encoders[model_name] = fn
9 |
10 | return fn
11 |
12 | def lang_encoders(model_name):
13 | return _lang_encoders[model_name]
14 |
15 | def is_lang_encoder(model_name):
16 | return model_name in _lang_encoders
17 |
--------------------------------------------------------------------------------
/modeling/language/LangEncoder/registry.py:
--------------------------------------------------------------------------------
1 | _lang_encoders = {}
2 |
3 |
4 | def register_lang_encoder(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 |
8 | _lang_encoders[model_name] = fn
9 |
10 | return fn
11 |
12 |
13 | def lang_encoders(model_name):
14 | return _lang_encoders[model_name]
15 |
16 |
17 | def is_lang_encoder(model_name):
18 | return model_name in _lang_encoders
19 |
--------------------------------------------------------------------------------
/modeling/language/__init__.py:
--------------------------------------------------------------------------------
1 | from .vlpencoder import *
2 | from .build import *
3 |
4 | def build_language_encoder(config, **kwargs):
5 | model_name = config['MODEL']['TEXT']['ARCH']
6 |
7 | if not is_model(model_name):
8 |         raise ValueError(f'Unknown model: {model_name}')
9 |
10 | return model_entrypoints(model_name)(config, **kwargs)
--------------------------------------------------------------------------------
/modeling/language/build.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def register_model(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 | _model_entrypoints[model_name] = fn
8 | return fn
9 |
10 | def model_entrypoints(model_name):
11 | return _model_entrypoints[model_name]
12 |
13 | def is_model(model_name):
14 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/language/misc.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import nltk
5 | nltk.data.path.append('/mnt/data/nltk_data')
6 | import numpy as np
7 |
8 | from utils.constants import IMAGENET_DEFAULT_TEMPLATES
9 |
10 |
11 | def get_tag(tokenized, tags):
12 | if not isinstance(tags, (list, tuple)):
13 | tags = [tags]
14 | ret = []
15 | for (word, pos) in nltk.pos_tag(tokenized):
16 | for tag in tags:
17 | if pos == tag:
18 | ret.append(word)
19 | return ret
20 |
21 | def get_noun_phrase(tokenized):
22 | # Taken from Su Nam Kim Paper...
23 | grammar = r"""
24 | NBAR:
25 |             {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
26 | 
27 |         NP:
28 |             {<NBAR>}
29 |             {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
30 | """
31 | chunker = nltk.RegexpParser(grammar)
32 |
33 | chunked = chunker.parse(nltk.pos_tag(tokenized))
34 | continuous_chunk = []
35 | current_chunk = []
36 |
37 | for subtree in chunked:
38 | if isinstance(subtree, nltk.Tree):
39 | current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
40 | elif current_chunk:
41 | named_entity = ' '.join(current_chunk)
42 | if named_entity not in continuous_chunk:
43 | continuous_chunk.append(named_entity)
44 | current_chunk = []
45 | else:
46 | continue
47 |
48 | return continuous_chunk
49 |
50 | def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
51 | tokenized = nltk.word_tokenize(text)
52 |
53 | if random.random() >= phrase_prob:
54 | nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
55 | else:
56 | nouns = get_noun_phrase(tokenized)
57 |
58 |
59 | prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
60 |
61 | if append_text:
62 | prompt_texts += [text]
63 | nouns += [text]
64 |
65 | return prompt_texts, nouns
--------------------------------------------------------------------------------
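Note: the noun extraction above depends on NLTK's tokenizer and POS tagger. A minimal equivalent of the phrase_prob == 0 branch of text_noun_with_prompt_all, assuming the 'punkt' and 'averaged_perceptron_tagger' data packages are installed:

import nltk
# nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')

text = 'a zebra grazing next to a tall giraffe'
tokens = nltk.word_tokenize(text)
nouns = [word for word, pos in nltk.pos_tag(tokens) if pos in ('NN', 'NNS', 'NNP')]
print(nouns)   # e.g. ['zebra', 'giraffe']
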
/modeling/language/registry.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 | def register_model(fn):
4 | module_name_split = fn.__module__.split('.')
5 | model_name = module_name_split[-1]
6 | _model_entrypoints[model_name] = fn
7 | return fn
8 |
9 | def model_entrypoints(model_name):
10 | return _model_entrypoints[model_name]
11 |
12 | def is_model(model_name):
13 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .point_features import *
2 | from .position_encoding import *
3 | from .postprocessing import *
4 | from .attention import *
5 | from .criterion import *
6 | from .matcher import *
--------------------------------------------------------------------------------
/modeling/modules/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
3 | """
4 | Various positional encodings for the transformer.
5 | """
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 |
11 |
12 | class PositionEmbeddingSine(nn.Module):
13 | """
14 | This is a more standard version of the position embedding, very similar to the one
15 | used by the Attention is all you need paper, generalized to work on images.
16 | """
17 |
18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19 | super().__init__()
20 | self.num_pos_feats = num_pos_feats
21 | self.temperature = temperature
22 | self.normalize = normalize
23 | if scale is not None and normalize is False:
24 | raise ValueError("normalize should be True if scale is passed")
25 | if scale is None:
26 | scale = 2 * math.pi
27 | self.scale = scale
28 |
29 | def forward(self, x, mask=None):
30 | if mask is None:
31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
32 | not_mask = ~mask
33 | y_embed = not_mask.cumsum(1, dtype=x.dtype)
34 | x_embed = not_mask.cumsum(2, dtype=x.dtype)
35 | if self.normalize:
36 | eps = 1e-6
37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39 |
40 | dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42 |
43 | pos_x = x_embed[:, :, :, None] / dim_t
44 | pos_y = y_embed[:, :, :, None] / dim_t
45 | pos_x = torch.stack(
46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
47 | ).flatten(3)
48 | pos_y = torch.stack(
49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
50 | ).flatten(3)
51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52 | return pos
53 |
54 | def __repr__(self, _repr_indent=4):
55 | head = "Positional encoding " + self.__class__.__name__
56 | body = [
57 | "num_pos_feats: {}".format(self.num_pos_feats),
58 | "temperature: {}".format(self.temperature),
59 | "normalize: {}".format(self.normalize),
60 | "scale: {}".format(self.scale),
61 | ]
62 | # _repr_indent = 4
63 | lines = [head] + [" " * _repr_indent + line for line in body]
64 | return "\n".join(lines)
65 |
--------------------------------------------------------------------------------
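Note: a quick shape check for PositionEmbeddingSine. The output has 2 * num_pos_feats channels (sine/cosine of the y and x coordinates) and the same spatial size as the input; the import path follows this repository's layout.

import torch
from modeling.modules.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feats = torch.zeros(2, 256, 32, 32)      # N x C x H x W feature map
pos = pe(feats)
print(pos.shape)                         # torch.Size([2, 256, 32, 32])
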
/modeling/modules/postprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | from torch.nn import functional as F
4 |
5 | from detectron2.structures import Instances, ROIMasks
6 |
7 |
8 | # perhaps should rename to "resize_instance"
9 | def detector_postprocess(
10 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
11 | ):
12 | """
13 | Resize the output instances.
14 | The input images are often resized when entering an object detector.
15 | As a result, we often need the outputs of the detector in a different
16 | resolution from its inputs.
17 |
18 | This function will resize the raw outputs of an R-CNN detector
19 | to produce outputs according to the desired output resolution.
20 |
21 | Args:
22 | results (Instances): the raw outputs from the detector.
23 | `results.image_size` contains the input image resolution the detector sees.
24 | This object might be modified in-place.
25 | output_height, output_width: the desired output resolution.
26 |
27 | Returns:
28 | Instances: the resized output from the model, based on the output resolution
29 | """
30 | if isinstance(output_width, torch.Tensor):
31 | # This shape might (but not necessarily) be tensors during tracing.
32 | # Converts integer tensors to float temporaries to ensure true
33 | # division is performed when computing scale_x and scale_y.
34 | output_width_tmp = output_width.float()
35 | output_height_tmp = output_height.float()
36 | new_size = torch.stack([output_height, output_width])
37 | else:
38 | new_size = (output_height, output_width)
39 | output_width_tmp = output_width
40 | output_height_tmp = output_height
41 |
42 | scale_x, scale_y = (
43 | output_width_tmp / results.image_size[1],
44 | output_height_tmp / results.image_size[0],
45 | )
46 | results = Instances(new_size, **results.get_fields())
47 |
48 | if results.has("pred_boxes"):
49 | output_boxes = results.pred_boxes
50 | elif results.has("proposal_boxes"):
51 | output_boxes = results.proposal_boxes
52 | else:
53 | output_boxes = None
54 | assert output_boxes is not None, "Predictions must contain boxes!"
55 |
56 | output_boxes.scale(scale_x, scale_y)
57 | output_boxes.clip(results.image_size)
58 |
59 | results = results[output_boxes.nonempty()]
60 |
61 | if results.has("pred_masks"):
62 | if isinstance(results.pred_masks, ROIMasks):
63 | roi_masks = results.pred_masks
64 | else:
65 | # pred_masks is a tensor of shape (N, 1, M, M)
66 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])
67 | results.pred_masks = roi_masks.to_bitmasks(
68 | results.pred_boxes, output_height, output_width, mask_threshold
69 | ).tensor # TODO return ROIMasks/BitMask object in the future
70 |
71 | if results.has("pred_keypoints"):
72 | results.pred_keypoints[:, :, 0] *= scale_x
73 | results.pred_keypoints[:, :, 1] *= scale_y
74 |
75 | return results
76 |
77 | def bbox_postprocess(result, input_size, img_size, output_height, output_width):
78 | """
79 | result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h]
80 | """
81 | if result is None:
82 | return None
83 |
84 | scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device)
85 | result = result.sigmoid() * scale
86 | x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2
87 | h,w = img_size
88 |
89 | x1 = x1.clamp(min=0, max=w)
90 | y1 = y1.clamp(min=0, max=h)
91 | x2 = x2.clamp(min=0, max=w)
92 | y2 = y2.clamp(min=0, max=h)
93 |
94 | box = torch.stack([x1,y1,x2,y2]).permute(1,0)
95 | scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device)
96 | box = box*scale
97 | return box
98 |
99 | def sem_seg_postprocess(result, img_size, output_height, output_width):
100 | """
101 | Return semantic segmentation predictions in the original resolution.
102 |
103 |     The input images are often resized when entering the semantic segmentor. Moreover, in some
104 |     cases, they are also padded inside the segmentor to be divisible by the maximum network stride.
105 | As a result, we often need the predictions of the segmentor in a different
106 | resolution from its inputs.
107 |
108 | Args:
109 | result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),
110 | where C is the number of classes, and H, W are the height and width of the prediction.
111 | img_size (tuple): image size that segmentor is taking as input.
112 | output_height, output_width: the desired output resolution.
113 |
114 | Returns:
115 | semantic segmentation prediction (Tensor): A tensor of the shape
116 | (C, output_height, output_width) that contains per-pixel soft predictions.
117 | """
118 | result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)
119 | result = F.interpolate(
120 | result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True
121 | )[0]
122 | return result
123 |
--------------------------------------------------------------------------------
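Note: sem_seg_postprocess crops away the padded region (img_size) and then bicubically resizes the per-class logits to the requested output resolution. A shape-only sketch with random logits:

import torch
from modeling.modules.postprocessing import sem_seg_postprocess

logits = torch.randn(10, 512, 512)        # C x H_pad x W_pad, as produced by the segmentation head
out = sem_seg_postprocess(logits, img_size=(512, 384), output_height=1024, output_width=768)
print(out.shape)                          # torch.Size([10, 1024, 768])
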
/modeling/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import *
2 | from .misc import *
3 | from .interactive import *
4 | from .attention import *
--------------------------------------------------------------------------------
/modeling/utils/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 | def box_xywh_to_xyxy(x):
23 | x0, y0, x1, y1 = x.unbind(-1)
24 | b = [x0, y0, (x0 + x1), (y0 + y1)]
25 | return torch.stack(b, dim=-1)
26 |
27 |
28 | # modified from torchvision to also return the union
29 | def box_iou(boxes1, boxes2):
30 | area1 = box_area(boxes1)
31 | area2 = box_area(boxes2)
32 |
33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
35 |
36 | wh = (rb - lt).clamp(min=0) # [N,M,2]
37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
38 |
39 | union = area1[:, None] + area2 - inter
40 |
41 | iou = inter / (union+1e-6)
42 | return iou, union
43 |
44 |
45 | def generalized_box_iou(boxes1, boxes2):
46 | """
47 | Generalized IoU from https://giou.stanford.edu/
48 |
49 | The boxes should be in [x0, y0, x1, y1] format
50 |
51 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
52 | and M = len(boxes2)
53 | """
54 | # degenerate boxes gives inf / nan results
55 | # so do an early check
56 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
57 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
58 | iou, union = box_iou(boxes1, boxes2)
59 |
60 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
61 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
62 |
63 | wh = (rb - lt).clamp(min=0) # [N,M,2]
64 | area = wh[:, :, 0] * wh[:, :, 1]
65 |
66 | return iou - (area - union) / (area+1e-6)
67 |
68 |
69 | def masks_to_boxes(masks):
70 | """Compute the bounding boxes around the provided masks
71 |
72 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
73 |
74 | Returns a [N, 4] tensors, with the boxes in xyxy format
75 | """
76 | if masks.numel() == 0:
77 | return torch.zeros((0, 4), device=masks.device)
78 |
79 | h, w = masks.shape[-2:]
80 |
81 | y = torch.arange(0, h, dtype=torch.float)
82 | x = torch.arange(0, w, dtype=torch.float)
83 | y, x = torch.meshgrid(y, x)
84 |
85 | x_mask = (masks * x.unsqueeze(0))
86 | x_max = x_mask.flatten(1).max(-1)[0]
87 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
88 |
89 | y_mask = (masks * y.unsqueeze(0))
90 | y_max = y_mask.flatten(1).max(-1)[0]
91 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
92 |
93 | return torch.stack([x_min, y_min, x_max, y_max], 1)
--------------------------------------------------------------------------------
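Note: a small numeric check for the box utilities above. Boxes are converted from (cx, cy, w, h) to (x1, y1, x2, y2) before computing IoU, and a disjoint pair yields a negative generalized IoU; the values in the comment are approximate.

import torch
from modeling.utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

b1 = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 1.0, 1.0]]))      # -> [[0., 0., 1., 1.]]
b2 = torch.tensor([[0.0, 0.0, 0.5, 0.5],                           # overlaps b1
                   [2.0, 2.0, 3.0, 3.0]])                          # disjoint from b1
print(generalized_box_iou(b1, b2))   # approx [[0.25, -0.78]]
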
/modeling/utils/interactive.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import math
4 |
5 | import torch
6 | from torch import nn, Tensor
7 | import torch.nn.functional as F
8 |
9 |
10 | def rand_sample(x, divisor, max_len):
11 | # non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
12 | if len(x.nonzero()) == 0:
13 | return x.nonzero().t()
14 |
15 | non_zero_point_index = (x.nonzero()/divisor).t()
16 | mask_ids = non_zero_point_index[0].unique().long()
17 |
18 |     # compute a sampling probability for each point so that every mask id contributes equally
19 | probs = torch.zeros_like(non_zero_point_index[0])
20 | for idx in mask_ids:
21 | prob = 1./(len(mask_ids)*((non_zero_point_index[0:1]==idx).sum()))
22 | probs[non_zero_point_index[0]==idx] = prob
23 |
24 | indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0]
25 | non_zero_point_index = non_zero_point_index[:,indices]
26 | return non_zero_point_index # [n, 512]
27 |
28 | def rand_sample_plain(x, max_len):
29 | if x.shape[1] <= max_len:
30 | return x
31 | else:
32 | rand_idx = torch.randperm(x.shape[1])[:max_len]
33 | return x[:,rand_idx]
34 |
35 | def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
36 | src = []
37 | pos = []
38 | size_list = []
39 |
40 | # disable mask, it does not affect performance
41 | for i in range(num_feature_levels):
42 | size_list.append(x[i].shape[-2:])
43 | pos.append(pe_layer(x[i], None).flatten(2))
44 | src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None])
45 |
46 | # flatten NxCxHxW to HWxNxC
47 | pos[-1] = pos[-1].permute(2, 0, 1)
48 | src[-1] = src[-1].permute(2, 0, 1)
49 | return src, pos, size_list
--------------------------------------------------------------------------------
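Note: rand_sample draws up to max_len foreground points from a stack of binary masks, weighting every mask equally regardless of its area. A toy sketch, assuming the repository and its dependencies are importable:

import torch
from modeling.utils.interactive import rand_sample

masks = torch.zeros(2, 8, 8)
masks[0, :4, :4] = 1          # mask 0: 16 foreground points
masks[1, 2, 2:6] = 1          # mask 1: 4 foreground points
points = rand_sample(masks, divisor=1.0, max_len=10)
print(points.shape)           # torch.Size([3, 10]) -> rows are (mask_id, y, x)
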
/modeling/vision/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .focal import *
2 | from .focal_dw import *
3 | from .davit import *
4 | from .vit import *
5 | from .backbone import *
6 | from .build import *
7 |
8 |
9 | def build_backbone(config, **kwargs):
10 | model_name = config['MODEL']['BACKBONE']['NAME']
11 | if not is_model(model_name):
12 | raise ValueError(f'Unknown model: {model_name}')
13 |
14 | return model_entrypoints(model_name)(config, **kwargs)
--------------------------------------------------------------------------------
/modeling/vision/backbone/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch.nn as nn
3 |
4 | from detectron2.modeling import ShapeSpec
5 |
6 | # from ..layers import ShapeSpec
7 |
8 | __all__ = ["Backbone"]
9 |
10 |
11 | class Backbone(nn.Module):
12 | """
13 | Abstract base class for network backbones.
14 | """
15 |
16 | def __init__(self):
17 | """
18 | The `__init__` method of any subclass can specify its own set of arguments.
19 | """
20 | super().__init__()
21 |
22 | def forward(self):
23 | """
24 | Subclasses must override this method, but adhere to the same return type.
25 |
26 | Returns:
27 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
28 | """
29 | pass
30 |
31 | @property
32 | def size_divisibility(self) -> int:
33 | """
34 | Some backbones require the input height and width to be divisible by a
35 | specific integer. This is typically true for encoder / decoder type networks
36 | with lateral connection (e.g., FPN) for which feature maps need to match
37 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
38 | input size divisibility is required.
39 | """
40 | return 0
41 |
42 | def output_shape(self):
43 | """
44 | Returns:
45 | dict[str->ShapeSpec]
46 | """
47 | # this is a backward-compatible default
48 | return {
49 | name: ShapeSpec(
50 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
51 | )
52 | for name in self._out_features
53 | }
54 |
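A toy subclass, sketching the contract that output_shape() relies on; the _out_feature_* attributes follow the detectron2 backbone convention and are not defined by this abstract class itself.

    import torch
    import torch.nn as nn

    class ToyBackbone(Backbone):
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(3, 16, kernel_size=3, stride=4, padding=1)
            self._out_features = ["res2"]
            self._out_feature_channels = {"res2": 16}
            self._out_feature_strides = {"res2": 4}

        def forward(self, x):
            return {"res2": self.stem(x)}

    bb = ToyBackbone()
    print(bb.output_shape())                              # {'res2': ShapeSpec(channels=16, ..., stride=4)}
    print(bb(torch.randn(1, 3, 64, 64))["res2"].shape)    # torch.Size([1, 16, 16, 16])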
--------------------------------------------------------------------------------
/modeling/vision/backbone/build.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def register_backbone(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 | _model_entrypoints[model_name] = fn
8 | return fn
9 |
10 | def model_entrypoints(model_name):
11 | return _model_entrypoints[model_name]
12 |
13 | def is_model(model_name):
14 | return model_name in _model_entrypoints
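The decorator keys the registry by the defining module's filename, which is why each backbone module registers itself on import. A self-contained sketch with a dummy entrypoint (toy_backbone is illustrative, not a real backbone):

    from types import SimpleNamespace

    @register_backbone
    def toy_backbone(config, **kwargs):
        return SimpleNamespace(kind="toy", config=config)

    key = toy_backbone.__module__.split('.')[-1]   # the module name becomes the registry key
    assert is_model(key)
    backbone = model_entrypoints(key)(config={})
    print(backbone.kind)                           # toy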
--------------------------------------------------------------------------------
/modeling/vision/backbone/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
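A quick sketch of LayerNorm2d on an NCHW tensor; with the default weight/bias initialization the per-position channel mean comes out near zero, since normalization runs over dim 1.

    import torch

    ln = LayerNorm2d(num_channels=8)
    x = torch.randn(2, 8, 4, 4)
    y = ln(x)
    print(y.shape)                    # torch.Size([2, 8, 4, 4])
    print(y.mean(dim=1).abs().max())  # ~0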
--------------------------------------------------------------------------------
/modeling/vision/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer_encoder_fpn import *
2 | try:
3 | from .transformer_encoder_deform import *
4 | except Exception:
5 | print('Deformable Transformer Encoder is not available.')
6 | from .build import *
7 |
8 |
9 | def build_encoder(config, *args, **kwargs):
10 | model_name = config['MODEL']['ENCODER']['NAME']
11 |
12 | if not is_model(model_name):
13 | raise ValueError(f'Unknown model: {model_name}')
14 |
15 | return model_entrypoints(model_name)(config, *args, **kwargs)
--------------------------------------------------------------------------------
/modeling/vision/encoder/build.py:
--------------------------------------------------------------------------------
1 | _model_entrypoints = {}
2 |
3 |
4 | def register_encoder(fn):
5 | module_name_split = fn.__module__.split('.')
6 | model_name = module_name_split[-1]
7 | _model_entrypoints[model_name] = fn
8 | return fn
9 |
10 | def model_entrypoints(model_name):
11 | return _model_entrypoints[model_name]
12 |
13 | def is_model(model_name):
14 | return model_name in _model_entrypoints
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from .ms_deform_attn_func import MSDeformAttnFunction
13 |
14 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from torch.autograd import Function
19 | from torch.autograd.function import once_differentiable
20 |
21 | try:
22 | import MultiScaleDeformableAttention as MSDA
23 | except ModuleNotFoundError as e:
24 | info_string = (
25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26 | "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
27 | "\t`sh make.sh`\n"
28 | )
29 | raise ModuleNotFoundError(info_string) from e
30 |
31 |
32 | class MSDeformAttnFunction(Function):
33 | @staticmethod
34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35 | ctx.im2col_step = im2col_step
36 | output = MSDA.ms_deform_attn_forward(
37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39 | return output
40 |
41 | @staticmethod
42 | @once_differentiable
43 | def backward(ctx, grad_output):
44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45 | grad_value, grad_sampling_loc, grad_attn_weight = \
46 | MSDA.ms_deform_attn_backward(
47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48 |
49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50 |
51 |
52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53 | # for debug and test only,
54 | # need to use cuda version instead
55 | N_, S_, M_, D_ = value.shape
56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58 | sampling_grids = 2 * sampling_locations - 1
59 | sampling_value_list = []
60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65 | # N_*M_, D_, Lq_, P_
66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67 | mode='bilinear', padding_mode='zeros', align_corners=False)
68 | sampling_value_list.append(sampling_value_l_)
69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72 | return output.transpose(1, 2).contiguous()
73 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | # Copyright (c) Facebook, Inc. and its affiliates.
11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12 |
13 | python setup.py build install --user
14 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from .ms_deform_attn import MSDeformAttn
13 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | import os
13 | import glob
14 |
15 | import torch
16 |
17 | from torch.utils.cpp_extension import CUDA_HOME
18 | from torch.utils.cpp_extension import CppExtension
19 | from torch.utils.cpp_extension import CUDAExtension
20 |
21 | from setuptools import find_packages
22 | from setuptools import setup
23 |
24 | requirements = ["torch", "torchvision"]
25 |
26 | def get_extensions():
27 | this_dir = os.path.dirname(os.path.abspath(__file__))
28 | extensions_dir = os.path.join(this_dir, "src")
29 |
30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33 |
34 | sources = main_file + source_cpu
35 | extension = CppExtension
36 | extra_compile_args = {"cxx": []}
37 | define_macros = []
38 |
39 | # Force a CUDA build when FORCE_CUDA is set, even if torch.cuda.is_available() is False (e.g., building on a CPU-only node).
40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41 | extension = CUDAExtension
42 | sources += source_cuda
43 | define_macros += [("WITH_CUDA", None)]
44 | extra_compile_args["nvcc"] = [
45 | "-DCUDA_HAS_FP16=1",
46 | "-D__CUDA_NO_HALF_OPERATORS__",
47 | "-D__CUDA_NO_HALF_CONVERSIONS__",
48 | "-D__CUDA_NO_HALF2_OPERATORS__",
49 | ]
50 | else:
51 | if CUDA_HOME is None:
52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
53 | else:
54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
55 |
56 | sources = [os.path.join(extensions_dir, s) for s in sources]
57 | include_dirs = [extensions_dir]
58 | ext_modules = [
59 | extension(
60 | "MultiScaleDeformableAttention",
61 | sources,
62 | include_dirs=include_dirs,
63 | define_macros=define_macros,
64 | extra_compile_args=extra_compile_args,
65 | )
66 | ]
67 | return ext_modules
68 |
69 | setup(
70 | name="MultiScaleDeformableAttention",
71 | version="1.0",
72 | author="Weijie Su",
73 | url="https://github.com/fundamentalvision/Deformable-DETR",
74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75 | packages=find_packages(exclude=("configs", "tests",)),
76 | ext_modules=get_extensions(),
77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78 | )
79 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #include <vector>
17 |
18 | #include <ATen/ATen.h>
19 | #include <ATen/cuda/CUDAContext.h>
20 |
21 |
22 | at::Tensor
23 | ms_deform_attn_cpu_forward(
24 | const at::Tensor &value,
25 | const at::Tensor &spatial_shapes,
26 | const at::Tensor &level_start_index,
27 | const at::Tensor &sampling_loc,
28 | const at::Tensor &attn_weight,
29 | const int im2col_step)
30 | {
31 | AT_ERROR("Not implement on cpu");
32 | }
33 |
34 | std::vector<at::Tensor>
35 | ms_deform_attn_cpu_backward(
36 | const at::Tensor &value,
37 | const at::Tensor &spatial_shapes,
38 | const at::Tensor &level_start_index,
39 | const at::Tensor &sampling_loc,
40 | const at::Tensor &attn_weight,
41 | const at::Tensor &grad_output,
42 | const int im2col_step)
43 | {
44 | AT_ERROR("Not implement on cpu");
45 | }
46 |
47 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #pragma once
17 | #include <torch/extension.h>
18 |
19 | at::Tensor
20 | ms_deform_attn_cpu_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step);
27 |
28 | std::vector<at::Tensor>
29 | ms_deform_attn_cpu_backward(
30 | const at::Tensor &value,
31 | const at::Tensor &spatial_shapes,
32 | const at::Tensor &level_start_index,
33 | const at::Tensor &sampling_loc,
34 | const at::Tensor &attn_weight,
35 | const at::Tensor &grad_output,
36 | const int im2col_step);
37 |
38 |
39 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #pragma once
17 | #include <torch/extension.h>
18 |
19 | at::Tensor ms_deform_attn_cuda_forward(
20 | const at::Tensor &value,
21 | const at::Tensor &spatial_shapes,
22 | const at::Tensor &level_start_index,
23 | const at::Tensor &sampling_loc,
24 | const at::Tensor &attn_weight,
25 | const int im2col_step);
26 |
27 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28 | const at::Tensor &value,
29 | const at::Tensor &spatial_shapes,
30 | const at::Tensor &level_start_index,
31 | const at::Tensor &sampling_loc,
32 | const at::Tensor &attn_weight,
33 | const at::Tensor &grad_output,
34 | const int im2col_step);
35 |
36 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #pragma once
17 |
18 | #include "cpu/ms_deform_attn_cpu.h"
19 |
20 | #ifdef WITH_CUDA
21 | #include "cuda/ms_deform_attn_cuda.h"
22 | #endif
23 |
24 |
25 | at::Tensor
26 | ms_deform_attn_forward(
27 | const at::Tensor &value,
28 | const at::Tensor &spatial_shapes,
29 | const at::Tensor &level_start_index,
30 | const at::Tensor &sampling_loc,
31 | const at::Tensor &attn_weight,
32 | const int im2col_step)
33 | {
34 | if (value.type().is_cuda())
35 | {
36 | #ifdef WITH_CUDA
37 | return ms_deform_attn_cuda_forward(
38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39 | #else
40 | AT_ERROR("Not compiled with GPU support");
41 | #endif
42 | }
43 | AT_ERROR("Not implemented on the CPU");
44 | }
45 |
46 | std::vector<at::Tensor>
47 | ms_deform_attn_backward(
48 | const at::Tensor &value,
49 | const at::Tensor &spatial_shapes,
50 | const at::Tensor &level_start_index,
51 | const at::Tensor &sampling_loc,
52 | const at::Tensor &attn_weight,
53 | const at::Tensor &grad_output,
54 | const int im2col_step)
55 | {
56 | if (value.type().is_cuda())
57 | {
58 | #ifdef WITH_CUDA
59 | return ms_deform_attn_cuda_backward(
60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61 | #else
62 | AT_ERROR("Not compiled with GPU support");
63 | #endif
64 | }
65 | AT_ERROR("Not implemented on the CPU");
66 | }
67 |
68 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #include "ms_deform_attn.h"
17 |
18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21 | }
22 |
--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 |
16 | import time
17 | import torch
18 | import torch.nn as nn
19 | from torch.autograd import gradcheck
20 |
21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22 |
23 |
24 | N, M, D = 1, 2, 2
25 | Lq, L, P = 2, 2, 2
26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28 | S = sum([(H*W).item() for H, W in shapes])
29 |
30 |
31 | torch.manual_seed(3)
32 |
33 |
34 | @torch.no_grad()
35 | def check_forward_equal_with_pytorch_double():
36 | value = torch.rand(N, S, M, D).cuda() * 0.01
37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40 | im2col_step = 2
41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43 | fwdok = torch.allclose(output_cuda, output_pytorch)
44 | max_abs_err = (output_cuda - output_pytorch).abs().max()
45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46 |
47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48 |
49 |
50 | @torch.no_grad()
51 | def check_forward_equal_with_pytorch_float():
52 | value = torch.rand(N, S, M, D).cuda() * 0.01
53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56 | im2col_step = 2
57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60 | max_abs_err = (output_cuda - output_pytorch).abs().max()
61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62 |
63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64 |
65 |
66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67 |
68 | value = torch.rand(N, S, M, channels).cuda() * 0.01
69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72 | im2col_step = 2
73 | func = MSDeformAttnFunction.apply
74 |
75 | value.requires_grad = grad_value
76 | sampling_locations.requires_grad = grad_sampling_loc
77 | attention_weights.requires_grad = grad_attn_weight
78 |
79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80 |
81 | print(f'* {gradok} check_gradient_numerical(D={channels})')
82 |
83 |
84 | if __name__ == '__main__':
85 | check_forward_equal_with_pytorch_double()
86 | check_forward_equal_with_pytorch_float()
87 |
88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89 | check_gradient_numerical(channels, True, True, True)
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/pipeline/__init__.py
--------------------------------------------------------------------------------
/pipeline/utils/misc.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 | def hook_opt(opt):
7 |
8 | try:
9 | grounding_flag = opt['REF']['INPUT']['SPATIAL']
10 | except Exception:
11 | grounding_flag = False
12 |
13 | if grounding_flag:
14 | opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['grounding'] = ['queries_grounding', 'tokens_grounding', 'tokens_spatial']
15 |
16 | try:
17 | spatial_flag = opt['STROKE_SAMPLER']['EVAL']['GROUNDING']
18 | except Exception:
19 | spatial_flag = False
20 |
21 | if spatial_flag:
22 | opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['spatial'] = ['queries_spatial', 'tokens_spatial', 'memories_spatial', 'tokens_grounding']
23 |
24 | return opt
25 |
26 | # HACK for evaluation
27 | def hook_metadata(metadata, name):
28 | return metadata
29 |
30 | # HACK for evaluation
31 | def hook_switcher(model, name):
32 | mappings = {}
33 | if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'context_59_val_seg', 'context_459_val_seg', 'voc_2012_val_seg', 'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']:
34 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False}
35 | elif name in ['cityscapes_fine_instance_seg_val'] or 'seginw' in name:
36 | mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False}
37 | elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']:
38 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True}
39 | elif name in ['coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev']:
40 | mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True}
41 | else:
42 | if name not in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd", "pascalvoc_val_Point", "grounding_coco_entity_val", "vlp_coco_entity_val"]:
43 | assert False, "dataset switcher is not defined"
44 |
45 | for key, value in mappings.items():
46 | if key == 'SEMANTIC_ON':
47 | model.model.semantic_on = value
48 | if key == 'INSTANCE_ON':
49 | model.model.instance_on = value
50 | if key == 'PANOPTIC_ON':
51 | model.model.panoptic_on = value
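A sketch of hook_switcher on a dummy wrapper; the nested model.model attributes mirror how the trainer wraps the network, and the SimpleNamespace stand-in is only for illustration:

    from types import SimpleNamespace

    wrapper = SimpleNamespace(model=SimpleNamespace(semantic_on=None, instance_on=None, panoptic_on=None))
    hook_switcher(wrapper, 'ade20k_panoptic_val')
    print(wrapper.model.semantic_on, wrapper.model.instance_on, wrapper.model.panoptic_on)   # True True True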
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .xdecoder_trainer import *
--------------------------------------------------------------------------------
/trainer/distributed_trainer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import logging
10 | from mpi4py import MPI
11 |
12 | import torch
13 |
14 | from .utils.hook import add_hook
15 | from .utils.mpi_adapter import MPIAdapter
16 | from .utils.misc import save_opt_to_yaml
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class DistributedTrainer:
22 | def __init__(self, opt):
23 | self.opt = opt
24 |
25 | # parse environment information for distributed training
26 | adapter = MPIAdapter(self.opt['PORT'])
27 | self.opt['world_size'] = adapter.world_size
28 | self.opt['local_size'] = adapter.local_size
29 | self.opt['rank'] = adapter.rank
30 | self.opt['local_rank'] = adapter.local_rank
31 |
32 | self.set_opt_hook()
33 |
34 | # set up device
35 | if not self.opt['CUDA']:
36 | self.opt['device'] = torch.device("cpu")
37 | logger.info("Using CPU")
38 | else:
39 | torch.cuda.set_device(self.opt['local_rank'])
40 | self.opt['device'] = torch.device("cuda", self.opt['local_rank'])
41 | logger.info("Using CUDA")
42 |
43 | # init distributed training
44 | adapter.log_info()
45 | if torch.distributed.is_available() and self.opt['world_size'] > 1:
46 | adapter.init_process_group(backend='nccl')
47 |
48 | # save config file
49 | self.save_folder = self.opt['SAVE_DIR']
50 |
51 | if self.opt['world_size'] > 1:
52 | torch.distributed.barrier()
53 |
54 | if self.opt['rank'] == 0:
55 | os.makedirs(self.save_folder, exist_ok=True)
56 |
57 | logger.info(f"Save config file to {os.path.join(self.save_folder, 'conf_copy.yaml')}")
58 | save_opt_to_yaml(self.opt, os.path.join(self.save_folder, 'conf_copy.yaml'))
59 |
60 | # ddp: log stats and update learning rate
61 | self.grad_acc_steps = self.opt['GRADIENT_ACCUMULATE_STEP']
62 | logger.info(f"Base learning rate: {self.opt['SOLVER']['BASE_LR']}")
63 | logger.info(f"Number of GPUs: {self.opt['world_size']}")
64 | logger.info(f"Gradient accumulation steps: {self.grad_acc_steps}")
65 |
66 | if self.opt['world_size'] > 1:
67 | add_hook()
68 |
69 | # prepare metadata for save folder
70 | conf_file = self.opt['conf_files'][0]
71 | if 'BASENAME' not in self.opt:
72 | self.opt['BASENAME'] = os.path.basename(conf_file)
73 |
74 | self.init_save_folder()
75 |
76 | def set_opt_hook(self):
77 | # Fill in the default values for required keywords
78 | self.opt['CUDA'] = self.opt.get('CUDA', True) and torch.cuda.is_available()
79 | self.opt['FP16'] = self.opt.get('FP16', False) and self.opt['CUDA']
80 | self.opt['GRADIENT_ACCUMULATE_STEP'] = int(self.opt.get('GRADIENT_ACCUMULATE_STEP', 1))
81 | self.opt['EVAL_PER_UPDATE_NUM'] = int(self.opt.get('EVAL_PER_UPDATE_NUM', 0))
82 | self.opt['LR_SCHEDULER_PARAMS'] = self.opt.get('LR_SCHEDULER_PARAMS', {})
83 |
84 | if 'SAVE_DIR' not in self.opt:
85 | assert False, "Please initialize SAVE_DIR in your config file."
86 | self.opt['SAVE_DIR'] = os.path.normpath(self.opt['SAVE_DIR'])
87 | logger.info(f"Setting SAVE_DIR as {self.opt['SAVE_DIR']}")
88 |
89 | def init_save_folder(self):
90 | """
91 | Initialize the save folder for logs, model, checkpoint, and evaluation.
92 | """
93 | runid = 1
94 |
95 | if self.opt['world_size'] > 1:
96 | torch.distributed.barrier()
97 |
98 | if self.opt['rank'] == 0:
99 | while True:
100 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
101 | try:
102 | os.makedirs(save_folder, exist_ok=False)
103 | break
104 | except FileExistsError:
105 | runid = runid + 1
106 |
107 | if self.opt['world_size'] > 1:
108 | torch.distributed.barrier()
109 |
110 | if self.opt['world_size'] > 1:
111 | runid = 1
112 | while True:
113 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
114 | if not os.path.exists(save_folder):
115 | break
116 | else:
117 | runid += 1
118 |
119 | runid -= 1
120 | save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
121 | # this second os.makedirs() call on all ranks is to force sync the save_folder creation between blobFuse and local fs
122 | os.makedirs(save_folder, exist_ok=True)
123 |
124 | self.save_folder = save_folder
--------------------------------------------------------------------------------
/trainer/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/X-Decoder/165f8a6314ac84f5c36aaab7216f90dd97e38a43/trainer/utils/__init__.py
--------------------------------------------------------------------------------
/trainer/utils/hook.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 | _orig_except_hook = None
7 |
8 |
9 | def _global_except_hook(exctype, value, traceback):
10 | """Catches an unhandled exception and call MPI_Abort()."""
11 | try:
12 | if _orig_except_hook:
13 | _orig_except_hook(exctype, value, traceback)
14 | else:
15 | sys.__excepthook__(exctype, value, traceback)
16 |
17 | finally:
18 | import mpi4py.MPI
19 | rank = mpi4py.MPI.COMM_WORLD.Get_rank()
20 | logger.warning("******************************************")
21 | logger.warning("DefaultTrainer:")
22 | logger.warning(f" Uncaught exception on rank {rank}.")
23 | logger.warning(" Calling MPI_Abort() to shut down MPI...")
24 | logger.warning("******************************************")
25 | logging.shutdown()
26 |
27 | try:
28 | import mpi4py.MPI
29 | mpi4py.MPI.COMM_WORLD.Abort(1)
30 | except Exception as e:
31 | # Something is completely broken...
32 | # There's nothing we can do any more
33 | sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n")
34 | sys.stderr.flush()
35 | raise e
36 |
37 |
38 | def add_hook():
39 | """
40 | Add a global hook function that captures all unhandled exceptions.
41 | The function calls MPI_Abort() to force all processes abort.
42 |
43 | An MPI runtime is expected to kill all of its child processes
44 | if one of them exits abnormally or without calling `MPI_Finalize()`.
45 | However, when a Python program runs on `mpi4py`, the MPI runtime
46 | often fails to detect a process failure, and the rest of the processes
47 | hang infinitely.
48 |
49 | See https://github.com/chainer/chainermn/issues/236 and
50 | https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more
51 | information.
52 | """
53 | global _orig_except_hook
54 |
55 | if _orig_except_hook is not None:
56 | logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.")
57 | return
58 |
59 | logger.info("Adding global except hook for the distributed job to shutdown MPI if unhandled exception is raised on some of the ranks.")
60 | _orig_except_hook = sys.excepthook
61 | sys.excepthook = _global_except_hook
62 |
--------------------------------------------------------------------------------
/trainer/utils/serialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from typing import Dict
4 |
5 |
6 | class JSONEncoder(json.JSONEncoder):
7 | def default(self, obj):
8 | if isinstance(obj, np.integer):
9 | return int(obj)
10 | elif isinstance(obj, np.floating):
11 | return float(obj)
12 | elif isinstance(obj, np.ndarray):
13 | return obj.tolist()
14 | else:
15 | return super(JSONEncoder, self).default(obj)
16 |
17 |
18 | def is_jsonable(x, json_encoder=None):
19 | try:
20 | json.dumps(x, cls=json_encoder)
21 | return True
22 | except Exception:
23 | return False
24 |
25 |
26 | def filter_jsonable(data: Dict, json_encoder=None) -> Dict:
27 | return {k: v for k, v in data.items() if is_jsonable(k, json_encoder=json_encoder) and is_jsonable(v, json_encoder=json_encoder)}
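A sketch of filter_jsonable with the numpy-aware encoder above: values are kept as-is, and only entries the encoder cannot serialize are dropped.

    import numpy as np

    stats = {'epoch': np.int64(3), 'scores': np.array([0.1, 0.2]), 'handle': object()}
    clean = filter_jsonable(stats, json_encoder=JSONEncoder)
    print(sorted(clean))   # ['epoch', 'scores'] -- 'handle' is not JSON-serializable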
--------------------------------------------------------------------------------
/utils/Config.py:
--------------------------------------------------------------------------------
1 | from fvcore.common.config import CfgNode as _CfgNode
2 |
3 |
4 | class CfgNode(_CfgNode):
5 | """
6 | The same as `fvcore.common.config.CfgNode`, but different in:
7 |
8 | 1. Use unsafe yaml loading by default.
9 | Note that this may lead to arbitrary code execution: you must not
10 | load a config file from untrusted sources before manually inspecting
11 | the content of the file.
12 | 2. Support config versioning.
13 | When attempting to merge an old config, it will convert the old config automatically.
14 |
15 | .. automethod:: clone
16 | .. automethod:: freeze
17 | .. automethod:: defrost
18 | .. automethod:: is_frozen
19 | .. automethod:: load_yaml_with_base
20 | .. automethod:: merge_from_list
21 | .. automethod:: merge_from_other_cfg
22 | """
23 |
24 | def merge_from_dict(self, dict):
25 | pass
26 |
27 | node = CfgNode()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_engineering import *
2 | from .dataset import *
--------------------------------------------------------------------------------
/utils/arguments.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import json
3 | import argparse
4 | import logging
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | def load_config_dict_to_opt(opt, config_dict):
10 | """
11 | Load the key, value pairs from config_dict to opt, overriding existing values in opt
12 | if there is any.
13 | """
14 | if not isinstance(config_dict, dict):
15 | raise TypeError("Config must be a Python dictionary")
16 | for k, v in config_dict.items():
17 | k_parts = k.split('.')
18 | pointer = opt
19 | for k_part in k_parts[:-1]:
20 | if k_part not in pointer:
21 | pointer[k_part] = {}
22 | pointer = pointer[k_part]
23 | assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
24 | ori_value = pointer.get(k_parts[-1])
25 | pointer[k_parts[-1]] = v
26 | if ori_value:
27 | logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}")
28 |
29 |
30 | def load_opt_from_config_files(conf_files):
31 | """
32 | Load opt from the config files, settings in later files can override those in previous files.
33 |
34 | Args:
35 | conf_files (list): a list of config file paths
36 |
37 | Returns:
38 | dict: a dictionary of opt settings
39 | """
40 | opt = {}
41 | for conf_file in conf_files:
42 | with open(conf_file, encoding='utf-8') as f:
43 | config_dict = yaml.safe_load(f)
44 |
45 | load_config_dict_to_opt(opt, config_dict)
46 |
47 | return opt
48 |
49 |
50 | def load_opt_command(args):
51 | parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')
52 | parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
53 | parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).')
54 | parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
55 | parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM>": <VALUE>, "<PARAM>.<SUB_PARAM>": <VALUE>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
56 | parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER)
57 |
58 | cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
59 |
60 | opt = load_opt_from_config_files(cmdline_args.conf_files)
61 |
62 | if cmdline_args.config_overrides:
63 | config_overrides_string = ' '.join(cmdline_args.config_overrides)
64 | logger.warning(f"Command line config overrides: {config_overrides_string}")
65 | config_dict = json.loads(config_overrides_string)
66 | load_config_dict_to_opt(opt, config_dict)
67 |
68 | if cmdline_args.overrides:
69 | assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments is not paired, required: key value"
70 | keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]
71 | vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]
72 | vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals]
73 |
74 | types = []
75 | for key in keys:
76 | key = key.split('.')
77 | ele = opt.copy()
78 | while len(key) > 0:
79 | ele = ele[key.pop(0)]
80 | types.append(type(ele))
81 |
82 | config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)}
83 | load_config_dict_to_opt(opt, config_dict)
84 |
85 | # combine cmdline_args into opt dictionary
86 | for key, val in cmdline_args.__dict__.items():
87 | if val is not None:
88 | opt[key] = val
89 |
90 | return opt, cmdline_args
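A sketch of the dotted-key semantics of load_config_dict_to_opt, which both --config_overrides and --overrides funnel into:

    opt = {'SOLVER': {'BASE_LR': 1e-4}, 'NAME': 'xdecoder'}
    load_config_dict_to_opt(opt, {'SOLVER.BASE_LR': 5e-5, 'SOLVER.MAX_ITER': 90000})
    print(opt['SOLVER'])   # {'BASE_LR': 5e-05, 'MAX_ITER': 90000}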
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | class Entity(object):
3 | def __init__(self, _id, _text, _mask, _interactive, _type, _start_idx, _end_idx, _image=None):
4 | self.id = _id
5 | self.text = _text
6 | self.mask = _mask
7 | self.interactive = _interactive
8 | self.type = _type
9 | self.start_idx = _start_idx
10 | self.end_idx = _end_idx
11 |
12 | self.image = _image
13 |
14 | def split_by_ordered_substrings(sentence, substrings):
15 | results = []
16 | substring_indices = []
17 |
18 | start_index = 0
19 | for i, substring in enumerate(substrings):
20 | # Find the start of the substring in the remaining part of the sentence
21 | index = sentence[start_index:].find(substring)
22 |
23 | if index == -1:
24 | continue
25 |
26 | # Append any text before the substring to the results, including spaces
27 | if index > 0:
28 | results.append(sentence[start_index:start_index+index])
29 | substring_indices.append(None) # No match in the `substrings` list for this segment
30 |
31 | # Append the substring to the results
32 | results.append(substring)
33 | substring_indices.append(i) # Append the index from the `substrings` list
34 | start_index += index + len(substring)
35 |
36 | # If there's any remaining part of the sentence after all substrings, append it to the results
37 | if start_index < len(sentence):
38 | results.append(sentence[start_index:])
39 | substring_indices.append(None) # No match in the `substrings` list for this segment
40 |
41 | return results, substring_indices
42 |
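A small worked example of split_by_ordered_substrings; segments with no matching substring get a None index.

    segments, indices = split_by_ordered_substrings("the dog chases the cat.", ["dog", "cat"])
    print(segments)   # ['the ', 'dog', ' chases the ', 'cat', '.']
    print(indices)    # [None, 0, None, 1, None]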
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import pickle
5 | import subprocess
6 |
7 | from mpi4py import MPI
8 | import torch.distributed as dist
9 |
10 |
11 | def apply_distributed(opt):
12 | if opt['rank'] == 0:
13 | hostname_cmd = ["hostname -I"]
14 | result = subprocess.check_output(hostname_cmd, shell=True)
15 | master_address = result.decode('utf-8').split()[0]
16 | master_port = opt['PORT']
17 | else:
18 | master_address = None
19 | master_port = None
20 |
21 | master_address = MPI.COMM_WORLD.bcast(master_address, root=0)
22 | master_port = MPI.COMM_WORLD.bcast(master_port, root=0)
23 |
24 | if torch.distributed.is_available() and opt['world_size'] > 1:
25 | init_method_url = 'tcp://{}:{}'.format(master_address, master_port)
26 | backend = 'nccl'
27 | world_size = opt['world_size']
28 | rank = opt['rank']
29 | torch.distributed.init_process_group(backend=backend,
30 | init_method=init_method_url,
31 | world_size=world_size,
32 | rank=rank)
33 |
34 | def init_distributed(opt):
35 | opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()
36 | if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
37 | # application was started without MPI
38 | # default to single node with single process
39 | opt['env_info'] = 'no MPI'
40 | opt['world_size'] = 1
41 | opt['local_size'] = 1
42 | opt['rank'] = 0
43 | opt['local_rank'] = 0
44 | opt['master_address'] = '127.0.0.1'
45 | opt['master_port'] = '8673'
46 | else:
47 | # application was started with MPI
48 | # get MPI parameters
49 | opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])
50 | opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
51 | opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])
52 | opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
53 |
54 | # set up device
55 | if not opt['CUDA']:
56 | assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend'
57 | opt['device'] = torch.device("cpu")
58 | else:
59 | torch.cuda.set_device(opt['local_rank'])
60 | opt['device'] = torch.device("cuda", opt['local_rank'])
61 |
62 | apply_distributed(opt)
63 | return opt
64 |
65 | def is_main_process():
66 | rank = 0
67 | if 'OMPI_COMM_WORLD_SIZE' in os.environ:
68 | rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
69 |
70 | return rank == 0
71 |
72 | def get_world_size():
73 | if not dist.is_available():
74 | return 1
75 | if not dist.is_initialized():
76 | return 1
77 | return dist.get_world_size()
78 |
79 | def get_rank():
80 | if not dist.is_available():
81 | return 0
82 | if not dist.is_initialized():
83 | return 0
84 | return dist.get_rank()
85 |
86 |
87 | def synchronize():
88 | """
89 | Helper function to synchronize (barrier) among all processes when
90 | using distributed training
91 | """
92 | if not dist.is_available():
93 | return
94 | if not dist.is_initialized():
95 | return
96 | world_size = dist.get_world_size()
97 | rank = dist.get_rank()
98 | if world_size == 1:
99 | return
100 |
101 | def _send_and_wait(r):
102 | if rank == r:
103 | tensor = torch.tensor(0, device="cuda")
104 | else:
105 | tensor = torch.tensor(1, device="cuda")
106 | dist.broadcast(tensor, r)
107 | while tensor.item() == 1:
108 | time.sleep(1)
109 |
110 | _send_and_wait(0)
111 | # now sync on the main process
112 | _send_and_wait(1)
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6 | # --------------------------------------------------------
7 | import math
8 |
9 |
10 |
11 | class AverageMeter(object):
12 | """Computes and stores the average and current value."""
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1, decay=0):
23 | self.val = val
24 | if decay:
25 | alpha = math.exp(-n / decay)  # exponential moving average with time constant `decay`
26 | self.sum = alpha * self.sum + (1 - alpha) * val * n
27 | self.count = alpha * self.count + (1 - alpha) * n
28 | else:
29 | self.sum += val * n
30 | self.count += n
31 | self.avg = self.sum / self.count
32 |
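A minimal sketch of AverageMeter tracking a running loss:

    meter = AverageMeter()
    for loss in [0.9, 0.7, 0.5]:
        meter.update(loss)
    print(meter.val, meter.avg)   # 0.5 0.7 (up to float rounding)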
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | import pickle
5 | import torch
6 | import torch.nn as nn
7 |
8 | from utils.distributed import is_main_process
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | NORM_MODULES = [
14 | torch.nn.BatchNorm1d,
15 | torch.nn.BatchNorm2d,
16 | torch.nn.BatchNorm3d,
17 | torch.nn.SyncBatchNorm,
18 | # NaiveSyncBatchNorm inherits from BatchNorm2d
19 | torch.nn.GroupNorm,
20 | torch.nn.InstanceNorm1d,
21 | torch.nn.InstanceNorm2d,
22 | torch.nn.InstanceNorm3d,
23 | torch.nn.LayerNorm,
24 | torch.nn.LocalResponseNorm,
25 | ]
26 |
27 | def register_norm_module(cls):
28 | NORM_MODULES.append(cls)
29 | return cls
30 |
31 | def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
32 | model_keys = sorted(model_state_dict.keys())
33 | ckpt_keys = sorted(ckpt_state_dict.keys())
34 | result_dicts = {}
35 | matched_log = []
36 | unmatched_log = []
37 | unloaded_log = []
38 | for model_key in model_keys:
39 | model_weight = model_state_dict[model_key]
40 | if model_key in ckpt_keys:
41 | ckpt_weight = ckpt_state_dict[model_key]
42 | if model_weight.shape == ckpt_weight.shape:
43 | result_dicts[model_key] = ckpt_weight
44 | ckpt_keys.pop(ckpt_keys.index(model_key))
45 | matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
46 | else:
47 | unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
48 | else:
49 | unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
50 |
51 | if is_main_process():
52 | for info in matched_log:
53 | logger.info(info)
54 | for info in unloaded_log:
55 | logger.warning(info)
56 | for key in ckpt_keys:
57 | logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
58 | for info in unmatched_log:
59 | logger.warning(info)
60 | return result_dicts
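A sketch of align_and_update_state_dicts on toy state dicts: only keys whose name and shape both match are kept.

    import torch

    model_sd = {'head.weight': torch.zeros(4, 2), 'head.bias': torch.zeros(4)}
    ckpt_sd  = {'head.weight': torch.ones(4, 2), 'head.bias': torch.ones(3), 'extra': torch.ones(1)}
    loaded = align_and_update_state_dicts(model_sd, ckpt_sd)
    print(sorted(loaded))   # ['head.weight'] -- 'head.bias' mismatches in shape, 'extra' is unused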
--------------------------------------------------------------------------------
/utils/prompt_engineering.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def get_prompt_templates():
5 | prompt_templates = [
6 | '{}.',
7 | 'a photo of a {}.',
8 | 'a bad photo of a {}.',
9 | 'a photo of many {}.',
10 | 'a sculpture of a {}.',
11 | 'a photo of the hard to see {}.',
12 | 'a low resolution photo of the {}.',
13 | 'a rendering of a {}.',
14 | 'graffiti of a {}.',
15 | 'a bad photo of the {}.',
16 | 'a cropped photo of the {}.',
17 | 'a tattoo of a {}.',
18 | 'the embroidered {}.',
19 | 'a photo of a hard to see {}.',
20 | 'a bright photo of a {}.',
21 | 'a photo of a clean {}.',
22 | 'a photo of a dirty {}.',
23 | 'a dark photo of the {}.',
24 | 'a drawing of a {}.',
25 | 'a photo of my {}.',
26 | 'the plastic {}.',
27 | 'a photo of the cool {}.',
28 | 'a close-up photo of a {}.',
29 | 'a black and white photo of the {}.',
30 | 'a painting of the {}.',
31 | 'a painting of a {}.',
32 | 'a pixelated photo of the {}.',
33 | 'a sculpture of the {}.',
34 | 'a bright photo of the {}.',
35 | 'a cropped photo of a {}.',
36 | 'a plastic {}.',
37 | 'a photo of the dirty {}.',
38 | 'a jpeg corrupted photo of a {}.',
39 | 'a blurry photo of the {}.',
40 | 'a photo of the {}.',
41 | 'a good photo of the {}.',
42 | 'a rendering of the {}.',
43 | 'a {} in a video game.',
44 | 'a photo of one {}.',
45 | 'a doodle of a {}.',
46 | 'a close-up photo of the {}.',
47 | 'the origami {}.',
48 | 'the {} in a video game.',
49 | 'a sketch of a {}.',
50 | 'a doodle of the {}.',
51 | 'a origami {}.',
52 | 'a low resolution photo of a {}.',
53 | 'the toy {}.',
54 | 'a rendition of the {}.',
55 | 'a photo of the clean {}.',
56 | 'a photo of a large {}.',
57 | 'a rendition of a {}.',
58 | 'a photo of a nice {}.',
59 | 'a photo of a weird {}.',
60 | 'a blurry photo of a {}.',
61 | 'a cartoon {}.',
62 | 'art of a {}.',
63 | 'a sketch of the {}.',
64 | 'a embroidered {}.',
65 | 'a pixelated photo of a {}.',
66 | 'itap of the {}.',
67 | 'a jpeg corrupted photo of the {}.',
68 | 'a good photo of a {}.',
69 | 'a plushie {}.',
70 | 'a photo of the nice {}.',
71 | 'a photo of the small {}.',
72 | 'a photo of the weird {}.',
73 | 'the cartoon {}.',
74 | 'art of the {}.',
75 | 'a drawing of the {}.',
76 | 'a photo of the large {}.',
77 | 'a black and white photo of a {}.',
78 | 'the plushie {}.',
79 | 'a dark photo of a {}.',
80 | 'itap of a {}.',
81 | 'graffiti of the {}.',
82 | 'a toy {}.',
83 | 'itap of my {}.',
84 | 'a photo of a cool {}.',
85 | 'a photo of a small {}.',
86 | 'a tattoo of the {}.',
87 | ]
88 | return prompt_templates
89 |
90 | def prompt_engineering(classnames, topk=1, suffix='.'):
91 | prompt_templates = get_prompt_templates()
92 | temp_idx = np.random.randint(min(len(prompt_templates), topk))
93 |
94 | if isinstance(classnames, list):
95 | classname = np.random.choice(classnames)
96 | else:
97 | classname = classnames
98 |
99 | return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' '))
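A sketch of prompt_engineering: with topk=3 one of the first three templates is chosen at random for the class name.

    print(prompt_engineering('traffic light', topk=3, suffix='.'))
    # e.g. 'a photo of a traffic light.'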
--------------------------------------------------------------------------------