├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── assets
    ├── images
    │   ├── audio.png
    │   ├── click.png
    │   ├── click2mask.png
    │   ├── compare.jpg
    │   ├── compare_with_sam.jpg
    │   ├── emoj.png
    │   ├── emoj_scrib_draw.png
    │   ├── emoj_v1.jpg
    │   ├── emoj_v1_seg.png
    │   ├── fox.png
    │   ├── fox_v2.png
    │   ├── intro.png
    │   ├── method_xyz.png
    │   ├── minecraft.png
    │   ├── model.jpg
    │   ├── model.png
    │   ├── ref_seg.png
    │   ├── ref_seg_xyz.png
    │   ├── referring_video_visualize.png
    │   ├── spatial_relation.png
    │   ├── teaser.png
    │   ├── teaser_new.png
    │   ├── text.png
    │   ├── transformers_gh.png
    │   └── trees_text.png
    ├── readmes
    │   ├── DATASET.md
    │   ├── EVAL.md
    │   ├── INSTALL.md
    │   └── TRAIN.md
    ├── requirements
    │   ├── requirements.txt
    │   └── requirements_custom.txt
    └── scripts
    │   └── run_demo.sh
├── configs
    ├── seem
    │   ├── davitd3_unicl_lang_v1.yaml
    │   ├── davitd5_unicl_lang_v1.yaml
    │   ├── focall_unicl_lang_demo.yaml
    │   ├── focall_unicl_lang_v0.yaml
    │   ├── focall_unicl_lang_v1.yaml
    │   ├── focalt_unicl_lang_demo.yaml
    │   ├── focalt_unicl_lang_v0.yaml
    │   ├── focalt_unicl_lang_v1.yaml
    │   ├── samvitb_unicl_lang_v1.yaml
    │   └── samvitl_unicl_lang_v1.yaml
    └── xdecoder
    │   ├── davitd3_unicl_lang.yaml
    │   ├── davitd5_unicl_lang.yaml
    │   ├── focall_unicl_lang.yaml
    │   └── focalt_unicl_lang.yaml
├── datasets
    ├── __init__.py
    ├── build.py
    ├── dataset_mappers
    │   ├── __init__.py
    │   ├── bdd_semseg_dataset_mapper.py
    │   ├── coco_instance_new_baseline_dataset_mapper.py
    │   ├── coco_panoptic_interactive_dataset_mapper.py
    │   ├── coco_panoptic_new_baseline_dataset_mapper.py
    │   ├── imagenet_dataset_mapper.py
    │   ├── mask_former_instance_dataset_mapper.py
    │   ├── mask_former_panoptic_dataset_mapper.py
    │   ├── mask_former_semantic_dataset_mapper.py
    │   ├── pascalvoc_dataset_mapper_ix.py
    │   ├── refcoco_dataset_mapper.py
    │   ├── scannet_dataset_mapper.py
    │   ├── scannet_pano_dataset_mapper.py
    │   ├── sunrgbd_dataset_mapper.py
    │   └── vlp_dataset_mapper.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── captioning_evaluation.py
    │   ├── classification_evaluation.py
    │   ├── grounding_evaluation.py
    │   ├── instance_evaluation.py
    │   ├── interactive_evaluation.py
    │   ├── panoptic_evaluation.py
    │   ├── retrieval_evaluation.py
    │   └── segmentation_evaluation.py
    ├── refer.py
    ├── registration
    │   ├── __init__.py
    │   ├── register_ade20k_full.py
    │   ├── register_ade20k_instance.py
    │   ├── register_ade20k_panoptic.py
    │   ├── register_bdd100k_panoseg.py
    │   ├── register_bdd100k_semseg.py
    │   ├── register_coco_lvis_panoptic_annos_caption_grounding.py
    │   ├── register_coco_panoptic_annos_caption.py
    │   ├── register_coco_panoptic_annos_caption_grounding.py
    │   ├── register_coco_panoptic_annos_semseg.py
    │   ├── register_coco_stuff_10k.py
    │   ├── register_imagenet_cls.py
    │   ├── register_pascalvoc_eval.py
    │   ├── register_refcoco_dataset.py
    │   ├── register_scannet_panoptic.py
    │   ├── register_scannet_semseg.py
    │   ├── register_sunrgbd_semseg.py
    │   └── register_vlp_datasets.py
    ├── semseg_loader.py
    ├── utils
    │   ├── refcoco2json.py
    │   └── refer.py
    └── visual_sampler
    │   ├── __init__.py
    │   ├── circle.py
    │   ├── mask_generators.py
    │   ├── point.py
    │   ├── polygon.py
    │   ├── sampler.py
    │   ├── scribble.py
    │   └── simpleclick_sampler.py
├── demo
    ├── __init__.py
    └── seem
    │   ├── __init__.py
    │   ├── app.py
    │   ├── examples
    │       ├── corgi1.webp
    │       ├── corgi2.jpg
    │       ├── fries1.png
    │       ├── fries2.png
    │       ├── minecraft1.jpg
    │       ├── placeholder.png
    │       ├── ref_vase.JPG
    │       ├── river1.png
    │       ├── river1.wav
    │       ├── river1_mask.png
    │       ├── river2.png
    │       ├── vasedeck.mp4
    │       ├── zebras1.jpg
    │       └── zebras2.jpg
    │   └── tasks
    │       ├── __init__.py
    │       └── interactive.py
├── entry.py
├── inference
    ├── __init__.py
    ├── images
    │   ├── animals.png
    │   ├── apples.jpg
    │   ├── coco
    │   │   ├── 000.jpg
    │   │   ├── 001.jpg
    │   │   ├── 002.jpg
    │   │   └── 003.jpg
    │   ├── fruit.jpg
    │   ├── landscape.jpg
    │   ├── mountain.jpeg
    │   ├── owls.jpeg
    │   ├── penguin.jpeg
    │   ├── region_retrieval.png
    │   ├── rose.webp
    │   ├── street.jpg
    │   └── teaser_new.png
    └── xdecoder
    │   ├── infer_captioning.py
    │   ├── infer_instseg.py
    │   ├── infer_panoseg.py
    │   ├── infer_refseg.py
    │   ├── infer_region_retrieval.py
    │   └── infer_semseg.py
├── modeling
    ├── BaseModel.py
    ├── __init__.py
    ├── architectures
    │   ├── __init__.py
    │   ├── build.py
    │   ├── seem_model_demo.py
    │   ├── seem_model_v0.py
    │   ├── seem_model_v1.py
    │   └── xdecoder_model.py
    ├── body
    │   ├── __init__.py
    │   ├── build.py
    │   └── xdecoder_head.py
    ├── interface
    │   ├── __init__.py
    │   ├── build.py
    │   ├── modules.py
    │   ├── prototype
    │   │   ├── __init__.py
    │   │   ├── attention_data_struct_seemdemo.py
    │   │   ├── attention_data_struct_seemv0.py
    │   │   └── attention_data_struct_seemv1.py
    │   ├── seem_demo.py
    │   ├── seem_v0.py
    │   ├── seem_v1.py
    │   └── xdecoder.py
    ├── language
    │   ├── LangEncoder
    │   │   ├── __init__.py
    │   │   ├── build.py
    │   │   └── transformer.py
    │   ├── __init__.py
    │   ├── build.py
    │   ├── loss.py
    │   ├── misc.py
    │   └── vlpencoder.py
    ├── modules
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── criterion.py
    │   ├── matcher.py
    │   ├── point_features.py
    │   ├── position_encoding.py
    │   └── postprocessing.py
    ├── utils
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── box_ops.py
    │   ├── config.py
    │   ├── interactive.py
    │   └── misc.py
    └── vision
    │   ├── backbone
    │       ├── __init__.py
    │       ├── backbone.py
    │       ├── build.py
    │       ├── common.py
    │       ├── davit.py
    │       ├── focal.py
    │       ├── focal_dw.py
    │       └── vit.py
    │   └── encoder
    │       ├── __init__.py
    │       ├── build.py
    │       ├── ops
    │           ├── functions
    │           │   ├── __init__.py
    │           │   └── ms_deform_attn_func.py
    │           ├── make.sh
    │           ├── modules
    │           │   ├── __init__.py
    │           │   └── ms_deform_attn.py
    │           ├── setup.py
    │           ├── src
    │           │   ├── cpu
    │           │   │   ├── ms_deform_attn_cpu.cpp
    │           │   │   └── ms_deform_attn_cpu.h
    │           │   ├── cuda
    │           │   │   ├── ms_deform_attn_cuda.cu
    │           │   │   ├── ms_deform_attn_cuda.h
    │           │   │   └── ms_deform_im2col_cuda.cuh
    │           │   ├── ms_deform_attn.h
    │           │   └── vision.cpp
    │           └── test.py
    │       ├── transformer_blocks.py
    │       ├── transformer_encoder_deform.py
    │       └── transformer_encoder_fpn.py
├── pipeline
    ├── XDecoderPipeline.py
    ├── __init__.py
    └── utils
    │   └── misc.py
├── pyproject.toml
├── trainer
    ├── __init__.py
    ├── default_trainer.py
    ├── distributed_trainer.py
    ├── utils
    │   ├── __init__.py
    │   ├── hook.py
    │   ├── misc.py
    │   ├── mpi_adapter.py
    │   └── serialization.py
    ├── utils_trainer.py
    └── xdecoder_trainer.py
└── utils
    ├── Config.py
    ├── __init__.py
    ├── arguments.py
    ├── constants.py
    ├── dataset.py
    ├── distributed.py
    ├── misc.py
    ├── model.py
    ├── prompt_engineering.py
    └── visualizer.py


/.gitignore:
--------------------------------------------------------------------------------
  1 | # IntelliJ project files
  2 | .idea
  3 | *.iml
  4 | out
  5 | gen
  6 | 
  7 | ### Vim template
  8 | [._]*.s[a-w][a-z]
  9 | [._]s[a-w][a-z]
 10 | *.un~
 11 | Session.vim
 12 | .netrwhist
 13 | *~
 14 | 
 15 | ### IPythonNotebook template
 16 | # Temporary data
 17 | .ipynb_checkpoints/
 18 | 
 19 | ### Python template
 20 | # Byte-compiled / optimized / DLL files
 21 | __pycache__/
 22 | *.py[cod]
 23 | *$py.class
 24 | 
 25 | # C extensions
 26 | *.so
 27 | 
 28 | # Distribution / packaging
 29 | .Python
 30 | env/
 31 | build/
 32 | develop-eggs/
 33 | dist/
 34 | downloads/
 35 | eggs/
 36 | .eggs/
 37 | #lib/
 38 | #lib64/
 39 | parts/
 40 | sdist/
 41 | var/
 42 | *.egg-info/
 43 | .installed.cfg
 44 | *.egg
 45 | 
 46 | # PyInstaller
 47 | #  Usually these files are written by a python script from a template
 48 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 49 | *.manifest
 50 | *.spec
 51 | 
 52 | # Installer logs
 53 | pip-log.txt
 54 | pip-delete-this-directory.txt
 55 | 
 56 | # Unit test / coverage reports
 57 | htmlcov/
 58 | .tox/
 59 | .coverage
 60 | .coverage.*
 61 | .cache
 62 | nosetests.xml
 63 | coverage.xml
 64 | *,cover
 65 | 
 66 | # Translations
 67 | *.mo
 68 | *.pot
 69 | 
 70 | # Django stuff:
 71 | *.log
 72 | 
 73 | # Sphinx documentation
 74 | docs/_build/
 75 | 
 76 | # PyBuilder
 77 | target/
 78 | 
 79 | *.ipynb
 80 | *.params
 81 | # *.json
 82 | .vscode/
 83 | *.code-workspace/
 84 | 
 85 | lib/pycocotools/_mask.c
 86 | lib/nms/cpu_nms.c
 87 | 
 88 | OUTPUT
 89 | OUTPUT/*
 90 | models/*
 91 | DATASET
 92 | DATASET/*
 93 | external/
 94 | amlt/
 95 | amlt/*
 96 | MODELS
 97 | MODELS/*
 98 | 
 99 | eval_seem.sh
100 | train_seem.sh
101 | train_seem_v0.sh
102 | 
103 | draws/
104 | plot/
105 | 
106 | Config/
107 | Config/*
108 | 
109 | *venv/*
110 | ./demo_code/*.pt
111 | ./demo_code/*.pth
112 | *.pt
113 | *.pth
114 | 


--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/__init__.py


--------------------------------------------------------------------------------
/assets/images/audio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/audio.png


--------------------------------------------------------------------------------
/assets/images/click.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/click.png


--------------------------------------------------------------------------------
/assets/images/click2mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/click2mask.png


--------------------------------------------------------------------------------
/assets/images/compare.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/compare.jpg


--------------------------------------------------------------------------------
/assets/images/compare_with_sam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/compare_with_sam.jpg


--------------------------------------------------------------------------------
/assets/images/emoj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/emoj.png


--------------------------------------------------------------------------------
/assets/images/emoj_scrib_draw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/emoj_scrib_draw.png


--------------------------------------------------------------------------------
/assets/images/emoj_v1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/emoj_v1.jpg


--------------------------------------------------------------------------------
/assets/images/emoj_v1_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/emoj_v1_seg.png


--------------------------------------------------------------------------------
/assets/images/fox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/fox.png


--------------------------------------------------------------------------------
/assets/images/fox_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/fox_v2.png


--------------------------------------------------------------------------------
/assets/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/intro.png


--------------------------------------------------------------------------------
/assets/images/method_xyz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/method_xyz.png


--------------------------------------------------------------------------------
/assets/images/minecraft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/minecraft.png


--------------------------------------------------------------------------------
/assets/images/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/model.jpg


--------------------------------------------------------------------------------
/assets/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/model.png


--------------------------------------------------------------------------------
/assets/images/ref_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/ref_seg.png


--------------------------------------------------------------------------------
/assets/images/ref_seg_xyz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/ref_seg_xyz.png


--------------------------------------------------------------------------------
/assets/images/referring_video_visualize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/referring_video_visualize.png


--------------------------------------------------------------------------------
/assets/images/spatial_relation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/spatial_relation.png


--------------------------------------------------------------------------------
/assets/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/teaser.png


--------------------------------------------------------------------------------
/assets/images/teaser_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/teaser_new.png


--------------------------------------------------------------------------------
/assets/images/text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/text.png


--------------------------------------------------------------------------------
/assets/images/transformers_gh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/transformers_gh.png


--------------------------------------------------------------------------------
/assets/images/trees_text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/assets/images/trees_text.png


--------------------------------------------------------------------------------
/assets/readmes/DATASET.md:
--------------------------------------------------------------------------------
  1 | # Preparing Dataset
  2 | 
  3 | :bangbang: Dataset preparation involves many details; community contributions that fix any bugs are welcome. Thanks!
  4 | 
  5 | Our dataloader follows [Detectron2](https://github.com/facebookresearch/detectron2) and consists of: <br/>
  6 | (1) [A dataset registration module](datasets/registration) <br/>
  7 | (2) [A dataset mapper](datasets/dataset_mappers) <br/>
  8 | We modify the dataset registration and mappers to support custom datasets (a minimal registration sketch appears at the end of this file).
  9 | 
 10 | ## Training Dataset
 11 | We assume all the datasets are stored under:
 12 | ```
 13 | .xdecoder_data
 14 | ```
 15 | 
 16 | ### COCO (SEEM & X-Decoder)
 17 | 
 18 | ```sh
 19 | # Prepare panoptic_train2017 and panoptic_semseg_train2017 exactly as in [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/main/datasets)
 20 | 
 21 | # (SEEM & X-Decoder) Download the additional linguistic and custom annotation files to .xdecoder_data/coco/annotations
 22 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/caption_class_similarity.pth
 23 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/captions_train2017_filtrefgumdval_filtvlp.json
 24 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/grounding_train2017_filtrefgumdval_filtvlp.json
 25 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/panoptic_train2017_filtrefgumdval_filtvlp.json
 26 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/refcocog_umd_val.json
 27 | wget https://github.com/peteanderson80/coco-caption/blob/master/annotations/captions_val2014.json
 28 | 
 29 | # (SEEM) Download LVIS annotations for mask preparation
 30 | wget https://huggingface.co/xdecoder/SEEM/resolve/main/coco_train2017_filtrefgumdval_lvis.json
 31 | ```
 32 | 
 33 | After dataset preparation, the dataset structure would be:
 34 | ```
 35 | .xdecoder_data
 36 | └── coco/
 37 |     ├── train2017/
 38 |     ├── val2017/
 39 |     ├── panoptic_train2017/
 40 |     ├── panoptic_semseg_train2017/
 41 |     ├── panoptic_val2017/
 42 |     ├── panoptic_semseg_val2017/
 43 |     └── annotations/
 44 |         ├── refcocog_umd_val.json
 45 |         ├── captions_val2014.json
 46 |         ├── panoptic_val2017.json
 47 |         ├── caption_class_similarity.pth
 48 |         ├── panoptic_train2017_filtrefgumdval_filtvlp.json
 49 |         └── grounding_train2017_filtrefgumdval_filtvlp.json
 50 | └── lvis/
 51 |     └── coco_train2017_filtrefgumdval_lvis.json
 52 | ```
 53 | 
 54 | #### 4M Image Text Pairs (X-Decoder)
 55 | We follow exactly the image-text pair data preparation of [ViLT](https://github.com/dandelin/ViLT/blob/master/DATA.md).
 56 | ```
 57 | # The pretraining arrow files are placed under .xdecoder_data/pretrain_arrows_code224 with the following list of files.
 58 | ["filtcoco2017val_caption_karpathy_train.arrow", "filtcoco2017val_caption_karpathy_val.arrow", "filtcoco2017val_caption_karpathy_restval.arrow"] + ["code224_vg.arrow"] + [f"code224_sbu_{i}.arrow" for i in range(9)] + [f"code224_conceptual_caption_train_{i}.arrow" for i in range(31)]
 59 | # The "filtcoco2017val_*" arrow files are derived from the corresponding ViLT COCO-Karpathy arrow files by deleting images that overlap with coco val2017, to avoid information leakage.
 60 | ```
 61 | 
 62 | To get started quickly:
 63 | ```sh 
 64 | # Download the coco karpathy test set (in this codebase, we use coco_caption_karpathy_test.arrow as the training data only for a quick start)
 65 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/coco_caption_karpathy_test.arrow
 66 | ```
 67 | 
 68 | After dataset preparation, the dataset structure would be:
 69 | ```
 70 | .xdecoder_data
 71 | └── pretrain_arrows_code224/
 72 |     ├── coco_caption_karpathy_test.arrow
 73 |     ├── *filtcoco2017val_caption_karpathy_train.arrow
 74 |     ├── ...
 75 |     ├── *code224_vg.arrow
 76 |     ├── *code224_sbu_0.arrow
 77 |     ├── ...
 78 |     ├── *code224_conceptual_caption_train_0.arrow
 79 |     └── ...
 80 | * Files marked with * are optional when debugging the pipeline, but they MUST be added back when training the model.
 81 | ```
 82 | 
 83 | ***NOTE:***
 84 | 
 85 | <img src="https://user-images.githubusercontent.com/11957155/226159078-7f817452-76f8-44f4-af7a-9f13f3e02554.png" width="500">
 86 | There is overlap between the COCO2017, COCO-Karpathy, and RefCOCO datasets, and RefCOCO is entirely contained in the COCO2017 training data; we exclude the refcocog-umd validation and coco-karpathy test splits during training.
 87 | 
 88 | ## Evaluation Dataset
 89 | 
 90 | ### RefCOCO (SEEM & X-Decoder)
 91 | Please refer to the COCO preparation section [here](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/blob/v1.0/assets/readmes/DATASET.md#coco-seem--x-decoder).
 92 | 
 93 | ### ADE20K, Cityscapes (X-Decoder)
 94 | Please refer to [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/main/datasets).
 95 | 
 96 | ### BDD100K (X-Decoder)
 97 | Please download the 10k split of BDD100k at https://doc.bdd100k.com/download.html#id1
 98 | 
 99 | ### PascalVOC and all other interactive evaluation datasets (SEEM)
100 | Please follow the instructions in [RITM](https://github.com/SamsungLabs/ritm_interactive_segmentation).
101 | 
102 | After dataset preparation, the dataset structure would be:
103 | ```
104 | .xdecoder_data
105 | └── PascalVOC/
106 |     ├── Annotations/
107 |     ├── ImageSets
108 |     ├── JPEGImages/
109 |     ├── SegmentationClass/
110 |     └── SegmentationObject/
111 | ```
112 | 
113 | 
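The registration/mapper split described at the top of DATASET.md maps onto Detectron2's `DatasetCatalog`/`MetadataCatalog` API. Below is a minimal sketch of how a custom split could be registered so that a mapper from `datasets/dataset_mappers` can consume it; the dataset name, directory layout, and loader function here are hypothetical illustrations, not part of this repo.

```python
# Hypothetical example of Detectron2-style dataset registration; the dataset
# name and paths are placeholders, not datasets shipped with this repo.
import os
from detectron2.data import DatasetCatalog, MetadataCatalog

def load_my_semseg_split(image_dir, gt_dir):
    # Return Detectron2-format records with the keys the semantic-segmentation
    # mappers in datasets/dataset_mappers expect.
    records = []
    for fname in sorted(os.listdir(image_dir)):
        if not fname.endswith(".jpg"):
            continue
        stem = os.path.splitext(fname)[0]
        records.append({
            "file_name": os.path.join(image_dir, fname),
            "sem_seg_file_name": os.path.join(gt_dir, stem + ".png"),
        })
    return records

root = os.getenv("DATASET", ".xdecoder_data")
DatasetCatalog.register(
    "my_semseg_val",
    lambda: load_my_semseg_split(os.path.join(root, "my_dataset/images"),
                                 os.path.join(root, "my_dataset/labels")),
)
MetadataCatalog.get("my_semseg_val").set(stuff_classes=["background", "object"])
```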


--------------------------------------------------------------------------------
/assets/readmes/INSTALL.md:
--------------------------------------------------------------------------------
 1 | # Installation Guide
 2 | 
 3 | **General Environment**
 4 | * Linux System
 5 | * CUDA-enabled GPU with memory > 8 GB (evaluation)
 6 | * CUDA-enabled GPU with memory > 12 GB (training)
 7 | 
 8 | **Installation**
 9 | 
10 | ```sh
11 | # Python Package Installation
12 | pip install -r assets/requirements/requirements.txt
13 | pip install -r assets/requirements/requirements_custom.txt
14 | 
15 | # Custom Operator [only needed for training the deformable vision encoder]
16 | cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../
17 | 
18 | # System Packages [only needed for the SEEM demo]
19 | sudo apt update
20 | sudo apt install ffmpeg
21 | ```
22 | 
23 | **Dataset Preparation**
24 | 
25 | Please refer to [DATASET.md](assets/readmes/DATASET.md).
26 | 
27 | **Evaluation Tool**
28 | ```sh
29 | # save coco_caption.zip to .xdecoder_data
30 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/coco_caption.zip
31 | unzip coco_caption.zip
32 | ```
33 | 
34 | **Environment Variables**
35 | ```sh
36 | export DETECTRON2_DATASETS=/pth/to/xdecoder_data
37 | export DATASET=/pth/to/xdecoder_data
38 | export DATASET2=/pth/to/xdecoder_data
39 | export VLDATASET=/pth/to/xdecoder_data
40 | export PATH=$PATH:/pth/to/xdecoder_data/coco_caption/jre1.8.0_321/bin
41 | export PYTHONPATH=$PYTHONPATH:/pth/to/xdecoder_data/coco_caption
42 | ```
43 | 
44 | **Pretrained Checkpoint**
45 | 
46 | X-Decoder:
47 | ```sh
48 | # Focal-T UniCL
49 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focalt_in21k_yfcc_gcc_xdecoder_unicl.pt
50 | 
51 | # Focal-L UniCL
52 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focall_vision_focalb_lang_unicl.pt
53 | ```
54 | 
55 | SEEM:
56 | ```sh
57 | # Focal-T X-Decoder
58 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_last.pt
59 | 
60 | # Focal-L X-Decoder
61 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focall_last_oq101.pt
62 | 
63 | # Focal-B UniCL Language
64 | wget https://huggingface.co/xdecoder/X-Decoder/resolve/main/focalb_lang_unicl.pt
65 | 
66 | # ViT-B SAM
67 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
68 | 
69 | # ViT-L SAM
70 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
71 | 
72 | ```
73 | 
74 | 
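As a quick sanity check after the setup above, a short snippet (illustrative only, not part of the repo) can confirm that the environment variables point at real directories and that a CUDA GPU with enough memory is visible:

```python
# Illustrative sanity check for the setup above; not part of the repo.
import os
import torch

for var in ("DETECTRON2_DATASETS", "DATASET", "DATASET2", "VLDATASET"):
    path = os.environ.get(var)
    ok = path is not None and os.path.isdir(path)
    print(f"{var}={path!r} ({'ok' if ok else 'missing or not a directory'})")

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU 0: {props.name}, {props.total_memory / 1024**3:.1f} GB")
```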


--------------------------------------------------------------------------------
/assets/requirements/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch==2.1.0
 2 | torchvision==0.16.0
 3 | pillow==9.4.0
 4 | opencv-python==4.8.1.78
 5 | pyyaml==6.0.1
 6 | json_tricks==3.17.3
 7 | yacs==0.1.8
 8 | scikit-learn==1.3.1
 9 | pandas==2.0.3
10 | timm==0.4.12
11 | numpy==1.23.1
12 | einops==0.7.0
13 | fvcore==0.1.5.post20221221
14 | transformers==4.34.0
15 | sentencepiece==0.1.99
16 | ftfy==6.1.1
17 | regex==2023.10.3
18 | nltk==3.8.1
19 | mpi4py==3.1.5
20 | vision-datasets==0.2.2
21 | cython==3.0.2
22 | pycocotools==2.0.7
23 | diffdist==0.1
24 | pyarrow==13.0.0
25 | cityscapesscripts==2.2.2
26 | shapely==1.8.0
27 | scikit-image==0.21.0
28 | mup==1.0.0
29 | accelerate==0.23.0
30 | kornia==0.7.0
31 | deepspeed==0.10.3
32 | wandb==0.15.12
33 | infinibatch==0.1.1
34 | gradio==3.42.0


--------------------------------------------------------------------------------
/assets/requirements/requirements_custom.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/arogozhnikov/einops.git
2 | git+https://github.com/MaureenZOU/detectron2-xyz.git
3 | git+https://github.com/openai/whisper.git


--------------------------------------------------------------------------------
/assets/scripts/run_demo.sh:
--------------------------------------------------------------------------------
1 | sudo apt update
2 | sudo apt install ffmpeg
3 | pip install -r assets/requirements/requirements.txt
4 | pip install -r assets/requirements/requirements_custom.txt
5 | python demo/seem/app.py


--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import registration
2 | from .build import build_train_dataloader, build_eval_dataloader, build_evaluator


--------------------------------------------------------------------------------
/datasets/dataset_mappers/__init__.py:
--------------------------------------------------------------------------------
 1 | from .coco_panoptic_interactive_dataset_mapper import COCOPanopticInteractiveDatasetMapper
 2 | from .coco_instance_new_baseline_dataset_mapper import COCOInstanceNewBaselineDatasetMapper
 3 | from .coco_panoptic_new_baseline_dataset_mapper import COCOPanopticNewBaselineDatasetMapper
 4 | from .mask_former_instance_dataset_mapper import MaskFormerInstanceDatasetMapper
 5 | from .mask_former_panoptic_dataset_mapper import MaskFormerPanopticDatasetMapper
 6 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
 7 | from .imagenet_dataset_mapper import ImageNetDatasetMapper
 8 | from .vlp_dataset_mapper import VLPreDatasetMapper
 9 | from .sunrgbd_dataset_mapper import SunRGBDSegDatasetMapper
10 | from .scannet_dataset_mapper import ScanNetSegDatasetMapper
11 | from .bdd_semseg_dataset_mapper import BDDSemDatasetMapper
12 | from .scannet_pano_dataset_mapper import ScanNetPanoDatasetMapper
13 | from .refcoco_dataset_mapper import RefCOCODatasetMapper
14 | from .pascalvoc_dataset_mapper_ix import PascalVOCSegDatasetMapperIX


--------------------------------------------------------------------------------
/datasets/dataset_mappers/bdd_semseg_dataset_mapper.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | # Copyright (c) Facebook, Inc. and its affiliates.
  8 | import copy
  9 | 
 10 | import scipy.io
 11 | import numpy as np
 12 | import torch
 13 | from PIL import Image
 14 | 
 15 | from torchvision import transforms
 16 | from modeling.utils import configurable
 17 | 
 18 | __all__ = ["BDDSemDatasetMapper"]
 19 | 
 20 | 
 21 | # Adapted from the COCO dataset mappers for BDD semantic segmentation.
 22 | class BDDSemDatasetMapper:
 23 |     """
 24 |     A callable which takes a dataset dict in Detectron2 Dataset format,
 25 |     and map it into a format used by MaskFormer.
 26 | 
 27 |     This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
 28 | 
 29 |     The callable currently does the following:
 30 | 
 31 |     1. Read the image from "file_name"
 32 |     2. Applies geometric transforms to the image and annotation
 33 |     3. Find and applies suitable cropping to the image and annotation
 34 |     4. Prepare image and annotation to Tensors
 35 |     """
 36 | 
 37 |     @configurable
 38 |     def __init__(
 39 |         self,
 40 |         is_train=True,
 41 |         min_size_test=None,
 42 |         max_size_test=None,
 43 |         mean=None,
 44 |         std=None,
 45 |     ):
 46 |         """
 47 |         NOTE: this interface is experimental.
 48 |         Args:
 49 |             is_train: for training or inference
 50 |             augmentations: a list of augmentations or deterministic transforms to apply
 51 |             tfm_gens: data augmentation
 52 |             image_format: an image format supported by :func:`detection_utils.read_image`.
 53 |         """
 54 |         self.is_train = is_train
 55 |         self.min_size_test = min_size_test
 56 |         self.max_size_test = max_size_test
 57 |         self.pixel_mean = torch.tensor(mean)[:,None,None]
 58 |         self.pixel_std = torch.tensor(std)[:,None,None]
 59 | 
 60 |         t = []
 61 |         t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
 62 |         self.transform = transforms.Compose(t)
 63 |     
 64 |     @classmethod
 65 |     def from_config(cls, cfg, is_train=True):
 66 |         ret = {
 67 |             "is_train": is_train,
 68 |             "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
 69 |             "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
 70 |             "mean": cfg['INPUT']['PIXEL_MEAN'],
 71 |             "std": cfg['INPUT']['PIXEL_STD'],
 72 |         }
 73 |         return ret
 74 |     
 75 |     def read_semseg(self, file_name):
 76 |         if '.png' in file_name:
 77 |             semseg = np.asarray(Image.open(file_name))
 78 |         elif '.mat' in file_name:
 79 |             semseg = scipy.io.loadmat(file_name)['LabelMap']
 80 |         return semseg
 81 | 
 82 |     def __call__(self, dataset_dict):
 83 |         """
 84 |         Args:
 85 |             dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
 86 | 
 87 |         Returns:
 88 |             dict: a format that builtin models in detectron2 accept
 89 |         """
 90 |         dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
 91 |         file_name = dataset_dict['file_name']
 92 |         semseg_name = dataset_dict['sem_seg_file_name']
 93 |         image = Image.open(file_name).convert('RGB')
 94 | 
 95 |         dataset_dict['width'] = image.size[0]
 96 |         dataset_dict['height'] = image.size[1]
 97 | 
 98 |         if self.is_train == False:
 99 |             image = self.transform(image)
100 |             image = torch.from_numpy(np.asarray(image).copy())
101 |             image = image.permute(2,0,1)
102 |             
103 |         semseg = self.read_semseg(semseg_name)
104 |         semseg = torch.from_numpy(semseg.astype(np.int32))
105 |         dataset_dict['image'] = image
106 |         dataset_dict['semseg'] = semseg
107 |         return dataset_dict
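For context, the mapper above (and its near-identical ScanNet/SUN RGB-D siblings) is driven by `from_config` and then called on one Detectron2-style record at a time. The sketch below shows that flow with placeholder config values and file paths; it assumes the repo's `@configurable` wrapper passes plain keyword arguments through, as Detectron2's does.

```python
# Hypothetical usage of BDDSemDatasetMapper at evaluation time; config values
# and file paths are placeholders for illustration only.
from datasets.dataset_mappers.bdd_semseg_dataset_mapper import BDDSemDatasetMapper

cfg = {
    "INPUT": {
        "MIN_SIZE_TEST": 512,
        "MAX_SIZE_TEST": 1024,
        "PIXEL_MEAN": [123.675, 116.280, 103.530],
        "PIXEL_STD": [58.395, 57.120, 57.375],
    }
}

# Build keyword arguments via from_config, then construct the mapper directly.
mapper = BDDSemDatasetMapper(**BDDSemDatasetMapper.from_config(cfg, is_train=False))

record = {
    "file_name": "/path/to/bdd100k/images/example.jpg",          # placeholder
    "sem_seg_file_name": "/path/to/bdd100k/labels/example.png",  # placeholder
}
out = mapper(record)
print(out["image"].shape, out["semseg"].shape, out["height"], out["width"])
```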


--------------------------------------------------------------------------------
/datasets/dataset_mappers/imagenet_dataset_mapper.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | # Copyright (c) Facebook, Inc. and its affiliates.
 8 | import copy
 9 | from PIL import Image
10 | # import logging
11 | 
12 | import cv2
13 | import numpy as np
14 | 
15 | import torch
16 | from torchvision import transforms
17 | 
18 | from modeling.utils import configurable
19 | 
20 | __all__ = ["ImageNetDatasetMapper"]
21 | 
22 | 
23 | # Adapted from the COCO dataset mappers for ImageNet classification.
24 | class ImageNetDatasetMapper:
25 |     """
26 |     A callable which takes a dataset dict in Detectron2 Dataset format,
27 |     and map it into a format used by MaskFormer.
28 | 
29 |     This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
30 | 
31 |     The callable currently does the following:
32 | 
33 |     1. Read the image from "file_name"
34 |     2. Applies geometric transforms to the image and annotation
35 |     3. Find and applies suitable cropping to the image and annotation
36 |     4. Prepare image and annotation to Tensors
37 |     """
38 | 
39 |     @configurable
40 |     def __init__(
41 |         self,
42 |         is_train=True,
43 |         size_train=None,
44 |         size_test=None,
45 |         size_crop=None,
46 |     ):
47 |         """
48 |         NOTE: this interface is experimental.
49 |         Args:
50 |             is_train: for training or inference
51 |             augmentations: a list of augmentations or deterministic transforms to apply
52 |             tfm_gens: data augmentation
53 |             image_format: an image format supported by :func:`detection_utils.read_image`.
54 |         """
55 |         self.is_train = is_train
56 |         self.size_train = size_train
57 |         self.size_test = size_test
58 |         self.size_crop = size_crop
59 | 
60 |         t = []
61 |         t.append(transforms.Resize(size_crop, interpolation=Image.BICUBIC))
62 |         t.append(transforms.CenterCrop(size_test))
63 |         self.transform = transforms.Compose(t)
64 |         
65 |     @classmethod
66 |     def from_config(cls, cfg, is_train=True):
67 |         ret = {
68 |             "is_train": is_train,
69 |             "size_train": cfg['INPUT']['SIZE_TRAIN'],
70 |             "size_test": cfg['INPUT']['SIZE_TEST'],
71 |             "size_crop": cfg['INPUT']['SIZE_CROP']
72 |         }
73 |         return ret
74 | 
75 |     def __call__(self, dataset_dict):
76 |         """
77 |         Args:
78 |             dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
79 | 
80 |         Returns:
81 |             dict: a format that builtin models in detectron2 accept
82 |         """
83 |         dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
84 |         file_name = dataset_dict['file_name']
85 |         image = Image.open(file_name).convert('RGB')
86 | 
87 |         if self.is_train == False:
88 |             image = self.transform(image)
89 |             image = torch.from_numpy(np.asarray(image).copy())            
90 |             image = image.permute(2,0,1)
91 | 
92 |         dataset_dict['image'] = image
93 |         dataset_dict['height'] = image.shape[1]
94 |         dataset_dict['width'] = image.shape[2]
95 |         return dataset_dict


--------------------------------------------------------------------------------
/datasets/dataset_mappers/scannet_dataset_mapper.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | # Copyright (c) Facebook, Inc. and its affiliates.
  8 | import copy
  9 | 
 10 | import scipy.io
 11 | import numpy as np
 12 | import torch
 13 | from PIL import Image
 14 | 
 15 | from torchvision import transforms
 16 | from modeling.utils import configurable
 17 | 
 18 | __all__ = ["ScanNetSegDatasetMapper"]
 19 | 
 20 | 
 21 | # Adapted from the COCO dataset mappers for ScanNet semantic segmentation.
 22 | class ScanNetSegDatasetMapper:
 23 |     """
 24 |     A callable which takes a dataset dict in Detectron2 Dataset format,
 25 |     and map it into a format used by MaskFormer.
 26 | 
 27 |     This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
 28 | 
 29 |     The callable currently does the following:
 30 | 
 31 |     1. Read the image from "file_name"
 32 |     2. Applies geometric transforms to the image and annotation
 33 |     3. Find and applies suitable cropping to the image and annotation
 34 |     4. Prepare image and annotation to Tensors
 35 |     """
 36 | 
 37 |     @configurable
 38 |     def __init__(
 39 |         self,
 40 |         is_train=True,
 41 |         min_size_test=None,
 42 |         max_size_test=None,
 43 |         mean=None,
 44 |         std=None,
 45 |     ):
 46 |         """
 47 |         NOTE: this interface is experimental.
 48 |         Args:
 49 |             is_train: for training or inference
 50 |             augmentations: a list of augmentations or deterministic transforms to apply
 51 |             tfm_gens: data augmentation
 52 |             image_format: an image format supported by :func:`detection_utils.read_image`.
 53 |         """
 54 |         self.is_train = is_train
 55 |         self.min_size_test = min_size_test
 56 |         self.max_size_test = max_size_test
 57 |         self.pixel_mean = torch.tensor(mean)[:,None,None]
 58 |         self.pixel_std = torch.tensor(std)[:,None,None]
 59 | 
 60 |         t = []
 61 |         t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
 62 |         self.transform = transforms.Compose(t)
 63 |     
 64 |     @classmethod
 65 |     def from_config(cls, cfg, is_train=True):
 66 |         ret = {
 67 |             "is_train": is_train,
 68 |             "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
 69 |             "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
 70 |             "mean": cfg['INPUT']['PIXEL_MEAN'],
 71 |             "std": cfg['INPUT']['PIXEL_STD'],
 72 |         }
 73 |         return ret
 74 |     
 75 |     def read_semseg(self, file_name):
 76 |         if '.png' in file_name:
 77 |             semseg = np.asarray(Image.open(file_name))
 78 |         elif '.mat' in file_name:
 79 |             semseg = scipy.io.loadmat(file_name)['LabelMap']
 80 |         return semseg
 81 | 
 82 |     def __call__(self, dataset_dict):
 83 |         """
 84 |         Args:
 85 |             dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
 86 | 
 87 |         Returns:
 88 |             dict: a format that builtin models in detectron2 accept
 89 |         """
 90 |         dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
 91 |         file_name = dataset_dict['file_name']
 92 |         semseg_name = dataset_dict['sem_seg_file_name']
 93 |         image = Image.open(file_name).convert('RGB')
 94 |         
 95 |         dataset_dict['width'] = image.size[0]
 96 |         dataset_dict['height'] = image.size[1]
 97 | 
 98 |         if self.is_train == False:
 99 |             image = self.transform(image)
100 |             image = torch.from_numpy(np.asarray(image).copy())
101 |             image = image.permute(2,0,1)
102 |             
103 |         semseg = self.read_semseg(semseg_name)
104 |         semseg = torch.from_numpy(semseg.astype(np.int32))
105 |         dataset_dict['image'] = image
106 |         dataset_dict['semseg'] = semseg
107 |         return dataset_dict


--------------------------------------------------------------------------------
/datasets/dataset_mappers/scannet_pano_dataset_mapper.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | # Copyright (c) Facebook, Inc. and its affiliates.
 8 | import copy
 9 | 
10 | import scipy.io
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 | 
15 | from torchvision import transforms
16 | from modeling.utils import configurable
17 | 
18 | __all__ = ["ScanNetPanoDatasetMapper"]
19 | 
20 | 
21 | # Adapted from the COCO dataset mappers for ScanNet panoptic evaluation.
22 | class ScanNetPanoDatasetMapper:
23 |     """
24 |     A callable which takes a dataset dict in Detectron2 Dataset format,
25 |     and map it into a format used by MaskFormer.
26 | 
27 |     This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
28 | 
29 |     The callable currently does the following:
30 | 
31 |     1. Read the image from "file_name"
32 |     2. Applies geometric transforms to the image and annotation
33 |     3. Find and applies suitable cropping to the image and annotation
34 |     4. Prepare image and annotation to Tensors
35 |     """
36 | 
37 |     @configurable
38 |     def __init__(
39 |         self,
40 |         is_train=True,
41 |         min_size_test=None,
42 |         max_size_test=None,
43 |         mean=None,
44 |         std=None,
45 |     ):
46 |         """
47 |         NOTE: this interface is experimental.
48 |         Args:
49 |             is_train: for training or inference
50 |             augmentations: a list of augmentations or deterministic transforms to apply
51 |             tfm_gens: data augmentation
52 |             image_format: an image format supported by :func:`detection_utils.read_image`.
53 |         """
54 |         self.is_train = is_train
55 |         self.min_size_test = min_size_test
56 |         self.max_size_test = max_size_test
57 |         self.pixel_mean = torch.tensor(mean)[:,None,None]
58 |         self.pixel_std = torch.tensor(std)[:,None,None]
59 | 
60 |         t = []
61 |         t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
62 |         self.transform = transforms.Compose(t)
63 |     
64 |     @classmethod
65 |     def from_config(cls, cfg, is_train=True):
66 |         ret = {
67 |             "is_train": is_train,
68 |             "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
69 |             "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
70 |             "mean": cfg['INPUT']['PIXEL_MEAN'],
71 |             "std": cfg['INPUT']['PIXEL_STD'],
72 |         }
73 |         return ret
74 |     
75 |     def __call__(self, dataset_dict):
76 |         """
77 |         Args:
78 |             dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
79 | 
80 |         Returns:
81 |             dict: a format that builtin models in detectron2 accept
82 |         """
83 |         dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
84 |         file_name = dataset_dict['file_name']
85 |         image = Image.open(file_name).convert('RGB')
86 | 
87 |         dataset_dict['file_name'] = '_'.join(file_name.split('/')[-3:]) # HACK for /tmp file storage on predictions.
88 |         dataset_dict['width'] = image.size[0]
89 |         dataset_dict['height'] = image.size[1]
90 | 
91 |         image = self.transform(image)
92 |         image = torch.from_numpy(np.asarray(image).copy())
93 |         image = image.permute(2,0,1)
94 |         dataset_dict['image'] = image
95 |         return dataset_dict


--------------------------------------------------------------------------------
/datasets/dataset_mappers/sunrgbd_dataset_mapper.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | # Copyright (c) Facebook, Inc. and its affiliates.
  8 | import copy
  9 | 
 10 | import scipy.io
 11 | import numpy as np
 12 | import torch
 13 | from PIL import Image
 14 | 
 15 | from torchvision import transforms
 16 | from modeling.utils import configurable
 17 | 
 18 | __all__ = ["SunRGBDSegDatasetMapper"]
 19 | 
 20 | 
 21 | # Adapted from the COCO dataset mappers for SUN RGB-D semantic segmentation.
 22 | class SunRGBDSegDatasetMapper:
 23 |     """
 24 |     A callable which takes a dataset dict in Detectron2 Dataset format,
 25 |     and map it into a format used by MaskFormer.
 26 | 
 27 |     This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
 28 | 
 29 |     The callable currently does the following:
 30 | 
 31 |     1. Read the image from "file_name"
 32 |     2. Applies geometric transforms to the image and annotation
 33 |     3. Find and applies suitable cropping to the image and annotation
 34 |     4. Prepare image and annotation to Tensors
 35 |     """
 36 | 
 37 |     @configurable
 38 |     def __init__(
 39 |         self,
 40 |         is_train=True,
 41 |         min_size_test=None,
 42 |         max_size_test=None,
 43 |         mean=None,
 44 |         std=None,
 45 |     ):
 46 |         """
 47 |         NOTE: this interface is experimental.
 48 |         Args:
 49 |             is_train: for training or inference
 50 |             augmentations: a list of augmentations or deterministic transforms to apply
 51 |             tfm_gens: data augmentation
 52 |             image_format: an image format supported by :func:`detection_utils.read_image`.
 53 |         """
 54 |         self.is_train = is_train
 55 |         self.min_size_test = min_size_test
 56 |         self.max_size_test = max_size_test
 57 |         self.pixel_mean = torch.tensor(mean)[:,None,None]
 58 |         self.pixel_std = torch.tensor(std)[:,None,None]
 59 | 
 60 |         t = []
 61 |         t.append(transforms.Resize(self.min_size_test, interpolation=Image.BICUBIC))
 62 |         self.transform = transforms.Compose(t)
 63 |     
 64 |     @classmethod
 65 |     def from_config(cls, cfg, is_train=True):
 66 |         ret = {
 67 |             "is_train": is_train,
 68 |             "min_size_test": cfg['INPUT']['MIN_SIZE_TEST'],
 69 |             "max_size_test": cfg['INPUT']['MAX_SIZE_TEST'],
 70 |             "mean": cfg['INPUT']['PIXEL_MEAN'],
 71 |             "std": cfg['INPUT']['PIXEL_STD'],
 72 |         }
 73 |         return ret
 74 |     
 75 |     def read_semseg(self, file_name):
 76 |         if '.png' in file_name:
 77 |             semseg = np.asarray(Image.open(file_name))
 78 |         elif '.mat' in file_name:
 79 |             semseg = scipy.io.loadmat(file_name)['LabelMap']
 80 |         return semseg
 81 | 
 82 |     def __call__(self, dataset_dict):
 83 |         """
 84 |         Args:
 85 |             dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
 86 | 
 87 |         Returns:
 88 |             dict: a format that builtin models in detectron2 accept
 89 |         """
 90 |         dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
 91 |         file_name = dataset_dict['file_name']
 92 |         semseg_name = dataset_dict['sem_seg_file_name']
 93 |         image = Image.open(file_name).convert('RGB')
 94 | 
 95 |         dataset_dict['width'] = image.size[0]
 96 |         dataset_dict['height'] = image.size[1]
 97 | 
 98 |         if self.is_train == False:
 99 |             image = self.transform(image)
100 |             image = torch.from_numpy(np.asarray(image).copy())
101 |             image = image.permute(2,0,1)
102 |             
103 |         semseg = self.read_semseg(semseg_name)
104 |         semseg = torch.from_numpy(semseg.astype(np.int32))
105 |         dataset_dict['image'] = image
106 |         dataset_dict['semseg'] = semseg
107 |         return dataset_dict


--------------------------------------------------------------------------------
/datasets/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .instance_evaluation import *
2 | from .classification_evaluation import *
3 | from .segmentation_evaluation import *
4 | from .retrieval_evaluation import *
5 | from .captioning_evaluation import *
6 | from .panoptic_evaluation import *
7 | from .grounding_evaluation import *
8 | from .interactive_evaluation import *


--------------------------------------------------------------------------------
/datasets/evaluation/classification_evaluation.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | # --------------------------------------------------------
 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 4 | # Copyright (c) 2022 Microsoft
 5 | # Licensed under The MIT License [see LICENSE for details]
 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 7 | # --------------------------------------------------------
 8 | 
 9 | import torch
10 | import logging
11 | 
12 | from detectron2.evaluation.evaluator import DatasetEvaluator
13 | 
14 | from utils.misc import AverageMeter
15 | from utils.distributed import get_world_size
16 | 
17 | 
18 | @torch.no_grad()
19 | def accuracy(output, target, topk=(1,)):
20 |     """Computes the precision@k for the specified values of k"""
21 |     if isinstance(output, list):
22 |         output = output[-1]
23 | 
24 |     n_classes = output.size()[1]
25 |     maxk = min(max(topk), n_classes)
26 |     batch_size = target.size(0)
27 |     _, pred = output.topk(maxk, 1, True, True)
28 |     pred = pred.t()
29 |     correct = pred.eq(target.reshape(1, -1).expand_as(pred))
30 | 
31 |     res = []
32 |     for k in topk:
33 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 |         res.append(correct_k.mul_(100.0 / batch_size).item())
35 |     return res
36 | 
37 | class ClassificationEvaluator(DatasetEvaluator):
38 |     def __init__(self, *args):
39 |         self.top1 = AverageMeter()
40 |         self.top5 = AverageMeter()
41 |         self._logger = logging.getLogger(__name__)
42 | 
43 |     def reset(self):
44 |         self.top1.reset()
45 |         self.top5.reset()
46 | 
47 |     def process(self, inputs, outputs):
48 |         logits = torch.stack([o['pred_class'] for o in outputs])
49 |         y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
50 |         prec1, prec5 = accuracy(logits, y, (1, 5))
51 |         self.top1.update(prec1, y.size(0))
52 |         self.top5.update(prec5, y.size(0))
53 | 
54 |     def evaluate(self):
55 |         if get_world_size() > 1:
56 |             tmp_tensor = torch.tensor(
57 |                 [self.top1.sum, self.top5.sum, self.top1.count],
58 |                 device=torch.cuda.current_device()
59 |             )
60 |             torch.distributed.all_reduce(
61 |                 tmp_tensor, torch.distributed.ReduceOp.SUM
62 |             )
63 |             top1_sum, top5_sum, count = tmp_tensor.tolist()
64 |         else:
65 |             top1_sum = self.top1.sum
66 |             top5_sum = self.top5.sum
67 |             count = self.top1.count
68 | 
69 |         results = {}
70 |         scores = {
71 |             'top1': top1_sum / count,
72 |             "top5": top5_sum / count
73 |         }
74 |         results['class'] = scores
75 |         self._logger.info(results)
76 |         return results
77 | 
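A tiny synthetic example of the `accuracy` helper above (made-up tensors, assuming the repo root is on `PYTHONPATH`): with a batch of 3 and 10 classes where the top-1 prediction is correct for two samples and the label of the third is ranked second, top-1 accuracy is roughly 66.7% and top-5 is 100%.

```python
# Synthetic illustration of the accuracy() helper defined above.
import torch
from datasets.evaluation.classification_evaluation import accuracy

logits = torch.zeros(3, 10)
logits[0, 2] = 5.0   # sample 0: argmax is class 2, matches the label
logits[1, 7] = 5.0   # sample 1: argmax is class 7, matches the label
logits[2, 1] = 5.0   # sample 2: argmax is class 1, but the label is 4 ...
logits[2, 4] = 4.0   # ... class 4 is ranked second, so it still counts for top-5
target = torch.tensor([2, 7, 4])

top1, top5 = accuracy(logits, target, topk=(1, 5))
print(top1, top5)  # ~66.67, 100.0
```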


--------------------------------------------------------------------------------
/datasets/evaluation/grounding_evaluation.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | import logging
  8 | import torch
  9 | from torchvision.ops import box_iou
 10 | 
 11 | from detectron2.structures import BoxMode
 12 | from detectron2.data import MetadataCatalog
 13 | from detectron2.utils.comm import all_gather, is_main_process, synchronize
 14 | from detectron2.evaluation.evaluator import DatasetEvaluator
 15 | 
 16 | 
 17 | class GroundingEvaluator(DatasetEvaluator):
 18 |     """
 19 |     Evaluate grounding segmentation metrics.
 20 |     """
 21 | 
 22 |     def __init__(
 23 |         self,
 24 |         dataset_name,
 25 |         compute_box=False,
 26 |         distributed=True,
 27 |     ):
 28 |         self._logger = logging.getLogger(__name__)
 29 |         self._dataset_name = dataset_name
 30 |         self._distributed = distributed
 31 |         self._cpu_device = torch.device("cpu")
 32 |         self._compute_box = compute_box
 33 |         meta = MetadataCatalog.get(dataset_name)
 34 | 
 35 |     def reset(self):
 36 |         self.cum_I = 0
 37 |         self.cum_U = 0
 38 |         self.mIoU = 0
 39 |         self.eval_seg_iou_list = [.5, .6, .7, .8, .9]
 40 |         self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
 41 |         self.seg_total = 0
 42 |         if self._compute_box:
 43 |             self.mIoU_box = 0
 44 |             self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
 45 | 
 46 |     @staticmethod
 47 |     def computeIoU(pred_seg, gd_seg):
 48 |         I = (pred_seg & gd_seg)
 49 |         U = (pred_seg | gd_seg)
 50 |         return I, U
 51 | 
 52 |     def process(self, inputs, outputs):
 53 |         for input, output in zip(inputs, outputs):
 54 |             pred = output['grounding_mask'].sigmoid() > 0.5
 55 |             gt = input['groundings']['masks'].bool()
 56 |             bsi = len(pred)
 57 |             I, U = self.computeIoU(pred, gt)
 58 |             self.cum_I += I.sum().cpu()
 59 |             self.cum_U += U.sum().cpu()
 60 |             IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6)
 61 |             self.mIoU += IoU.sum().cpu()
 62 | 
 63 |             if self._compute_box:
 64 |                 pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
 65 |                 gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu()
 66 |                 IoU_box = box_iou(pred_box, gt_box).diagonal()
 67 |                 self.mIoU_box += IoU_box.sum()
 68 | 
 69 |             for idx in range(len(self.eval_seg_iou_list)):
 70 |                 eval_seg_iou = self.eval_seg_iou_list[idx]
 71 |                 self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu()
 72 |                 if self._compute_box:
 73 |                     self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu()
 74 |             self.seg_total += bsi
 75 | 
 76 |     def evaluate(self):
 77 |         if self._distributed:
 78 |             synchronize()
 79 |             self.cum_I = torch.stack(all_gather(self.cum_I)).sum()
 80 |             self.cum_U = torch.stack(all_gather(self.cum_U)).sum()
 81 |             self.mIoU = torch.stack(all_gather(self.mIoU)).sum()
 82 |             self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0)
 83 |             self.seg_total = sum(all_gather(self.seg_total))
 84 | 
 85 |             if self._compute_box:
 86 |                 self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum()
 87 |                 self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0)
 88 |             if not is_main_process():
 89 |                 return
 90 | 
 91 |         results = {}
 92 |         for idx in range(len(self.eval_seg_iou_list)):
 93 |             result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx])
 94 |             results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item()
 95 |         results['cIoU'] = (self.cum_I*100./self.cum_U).item()
 96 |         results['mIoU'] = (self.mIoU*100./self.seg_total).item()
 97 | 
 98 |         if self._compute_box:
 99 |             for idx in range(len(self.eval_seg_iou_list)):
100 |                 result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx])
101 |                 results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item()
102 |             results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item()
103 | 
104 |         self._logger.info(results)
105 |         return {'grounding': results}
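
A small illustration (not part of the repo) of the two mask metrics accumulated above: cIoU pools intersections and unions over all samples before dividing, while mIoU averages the per-sample ratios, so the two can differ on the same predictions.

    import torch

    pred = torch.tensor([[1, 1, 0, 0],
                         [1, 1, 1, 1]], dtype=torch.bool)      # two flattened predicted masks
    gt   = torch.tensor([[1, 0, 0, 0],
                         [1, 1, 1, 0]], dtype=torch.bool)

    I, U = pred & gt, pred | gt
    cIoU = I.sum().float() / U.sum()                 # (1 + 3) / (2 + 4) = 0.667
    mIoU = (I.sum(-1).float() / U.sum(-1)).mean()    # mean of [0.5, 0.75] = 0.625
    print(cIoU.item(), mIoU.item())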


--------------------------------------------------------------------------------
/datasets/evaluation/instance_evaluation.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates.
  2 | import contextlib
  3 | import copy
  4 | import io
  5 | import itertools
  6 | import json
  7 | import logging
  8 | import numpy as np
  9 | import os
 10 | import pickle
 11 | from collections import OrderedDict
 12 | import pycocotools.mask as mask_util
 13 | import torch
 14 | from pycocotools.coco import COCO
 15 | from pycocotools.cocoeval import COCOeval
 16 | from tabulate import tabulate
 17 | 
 18 | import detectron2.utils.comm as comm
 19 | from detectron2.config import CfgNode
 20 | from detectron2.data import MetadataCatalog
 21 | from detectron2.data.datasets.coco import convert_to_coco_json
 22 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
 23 | from detectron2.evaluation.fast_eval_api import COCOeval_opt
 24 | from detectron2.structures import Boxes, BoxMode, pairwise_iou
 25 | from detectron2.utils.file_io import PathManager
 26 | from detectron2.utils.logger import create_small_table
 27 | 
 28 | 
 29 | # modified from COCOEvaluator for instance segmentation
 30 | class InstanceSegEvaluator(COCOEvaluator):
 31 |     """
 32 |     Evaluate AR for object proposals, AP for instance detection/segmentation, AP
 33 |     for keypoint detection outputs using COCO's metrics.
 34 |     See http://cocodataset.org/#detection-eval and
 35 |     http://cocodataset.org/#keypoints-eval to understand its metrics.
 36 |     The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
 37 |     the metric cannot be computed (e.g. due to no predictions made).
 38 | 
 39 |     In addition to COCO, this evaluator is able to support any bounding box detection,
 40 |     instance segmentation, or keypoint detection dataset.
 41 |     """
 42 | 
 43 |     def _eval_predictions(self, predictions, img_ids=None):
 44 |         """
 45 |         Evaluate predictions. Fill self._results with the metrics of the tasks.
 46 |         """
 47 |         self._logger.info("Preparing results for COCO format ...")
 48 |         coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
 49 |         tasks = self._tasks or self._tasks_from_predictions(coco_results)
 50 | 
 51 |         # unmap the category ids for COCO
 52 |         if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
 53 |             dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
 54 |             # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
 55 |             # num_classes = len(all_contiguous_ids)
 56 |             # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
 57 | 
 58 |             reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
 59 |             for result in coco_results:
 60 |                 category_id = result["category_id"]
 61 |                 # assert category_id < num_classes, (
 62 |                 #     f"A prediction has class={category_id}, "
 63 |                 #     f"but the dataset only has {num_classes} classes and "
 64 |                 #     f"predicted class id should be in [0, {num_classes - 1}]."
 65 |                 # )
 66 |                 assert category_id in reverse_id_mapping, (
 67 |                     f"A prediction has class={category_id}, "
 68 |                     f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
 69 |                 )
 70 |                 result["category_id"] = reverse_id_mapping[category_id]
 71 | 
 72 |         if self._output_dir:
 73 |             file_path = os.path.join(self._output_dir, "coco_instances_results.json")
 74 |             self._logger.info("Saving results to {}".format(file_path))
 75 |             with PathManager.open(file_path, "w") as f:
 76 |                 f.write(json.dumps(coco_results))
 77 |                 f.flush()
 78 | 
 79 |         if not self._do_evaluation:
 80 |             self._logger.info("Annotations are not available for evaluation.")
 81 |             return
 82 | 
 83 |         self._logger.info(
 84 |             "Evaluating predictions with {} COCO API...".format(
 85 |                 "unofficial" if self._use_fast_impl else "official"
 86 |             )
 87 |         )
 88 |         for task in sorted(tasks):
 89 |             assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
 90 |             coco_eval = (
 91 |                 _evaluate_predictions_on_coco(
 92 |                     self._coco_api,
 93 |                     coco_results,
 94 |                     task,
 95 |                     kpt_oks_sigmas=self._kpt_oks_sigmas,
 96 |                     use_fast_impl=self._use_fast_impl,
 97 |                     img_ids=img_ids,
 98 |                     max_dets_per_image=self._max_dets_per_image,
 99 |                 )
100 |                 if len(coco_results) > 0
101 |                 else None  # cocoapi does not handle empty results very well
102 |             )
103 | 
104 |             res = self._derive_coco_results(
105 |                 coco_eval, task, class_names=self._metadata.get("thing_classes")
106 |             )
107 |             self._results[task] = res
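
A toy illustration (not part of the repo) of the id unmapping in `_eval_predictions`: detectron2 trains on contiguous ids in [0, N), while the COCO json keeps the original, possibly non-contiguous category ids, so each prediction is mapped back through the inverse dictionary before evaluation.

    dataset_id_to_contiguous_id = {1: 0, 2: 1, 3: 2, 5: 3}   # toy thing_dataset_id_to_contiguous_id
    reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}

    prediction = {"category_id": 3}                          # contiguous id produced by the model
    prediction["category_id"] = reverse_id_mapping[prediction["category_id"]]
    assert prediction["category_id"] == 5                    # original dataset id restored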
108 | 


--------------------------------------------------------------------------------
/datasets/evaluation/interactive_evaluation.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates.
  2 | import logging
  3 | import os
  4 | 
  5 | import numpy as np
  6 | import torch
  7 | from torchvision.ops import box_iou
  8 | 
  9 | from detectron2.structures import BoxMode
 10 | from detectron2.data import MetadataCatalog
 11 | from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
 12 | from detectron2.evaluation.evaluator import DatasetEvaluator
 13 | 
 14 | 
 15 | class InteractiveEvaluator(DatasetEvaluator):
 16 |     """
 17 |     Evaluate point interactive IoU metrics.
 18 |     """
 19 | 
 20 |     def __init__(
 21 |         self,
 22 |         dataset_name,
 23 |         output_dir,
 24 |         max_clicks=20,
 25 |         iou_iter=1,
 26 |         compute_box=False,
 27 |         distributed=True,
 28 |     ):
 29 |         self._logger = logging.getLogger(__name__)
 30 |         self._dataset_name = dataset_name
 31 |         self._distributed = distributed
 32 |         self._cpu_device = torch.device("cpu")
 33 |         self._output_dir = output_dir
 34 | 
 35 |         self.max_clicks = max_clicks
 36 |         self.iou_iter = iou_iter
 37 |         meta = MetadataCatalog.get(dataset_name)
 38 | 
 39 |     def reset(self):
 40 |         self.iou_list = []
 41 |         self.num_samples = 0
 42 |         self.all_ious = [0.5, 0.8, 0.85, 0.9]
 43 | 
 44 |     def process(self, inputs, outputs):
 45 |         self.iou_list += [o['mask_iou'] for o in outputs]
 46 |         self.num_samples += len(outputs)
 47 | 
 48 |     def compute_noc(self):
 49 |         def _get_noc(iou_arr, iou_thr):
 50 |             vals = iou_arr >= iou_thr
 51 |             return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks
 52 | 
 53 |         noc_list = {}
 54 |         for iou_thr in self.all_ious:
 55 |             scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list]
 56 |             noc_list[str(iou_thr)] = scores_arr
 57 | 
 58 |         iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1]
 59 |         noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()}
 60 | 
 61 |         if self._distributed:
 62 |             num_samples = sum(all_gather(self.num_samples))
 63 |             noc_list_sum_gather = all_gather(noc_list_sum)
 64 |             iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu())
 65 | 
 66 |             noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]}
 67 |             for nlg in noc_list_sum_gather:
 68 |                 for key, value in nlg.items():
 69 |                     noc_list_sum[key] += value
 70 | 
 71 |         pred_noc = {}
 72 |         if self._distributed and (not is_main_process()):
 73 |             return pred_noc
 74 |         if not self._distributed:  # single-process run: fall back to the local totals
 75 |             num_samples, iou_before_max_gather = self.num_samples, [iou_before_max_iter.sum().cpu()]
 76 |         for key, value in noc_list_sum.items():
 77 |             pred_noc[key] = value / num_samples
 78 |         pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples
 79 |         return pred_noc
 80 | 
 81 |     def evaluate(self):
 82 |         pred_noc = self.compute_noc()
 83 | 
 84 |         if self._distributed and (not is_main_process()):
 85 |             return
 86 | 
 87 |         def draw_iou_curve(iou_list, save_dir):
 88 |             iou_list = torch.stack(iou_list, dim=0)
 89 |             iou_list = iou_list.mean(dim=0).cpu().numpy()
 90 |             # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib
 91 |             import matplotlib.pyplot as plt
 92 |             plt.figure()
 93 |             plt.plot(range(1, self.max_clicks+1), iou_list)
 94 |             plt.xlabel('Number of clicks')
 95 |             plt.ylabel('IoU')
 96 | 
 97 | 
 98 |             # create the output directory if it does not exist
 99 |             # (os is already imported at module level)
100 |             output_dir = os.path.join(save_dir, 'iou_by_clicks')
101 |             if not os.path.exists(output_dir):
102 |                 os.makedirs(output_dir)
103 | 
104 |             # name the file with the current unix timestamp
105 |             import time
106 |             current_time = str(int(time.time()))
107 | 
108 |             # save the iou curve and release the figure
109 |             plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time)))
110 |             plt.close()
111 | 
112 | 
113 |         draw_iou_curve(self.iou_list, self._output_dir)
114 |         results = {}
115 |         for idx in range(len(self.all_ious)):
116 |             result_str = 'noc@{}'.format(self.all_ious[idx])
117 |             results[result_str] = pred_noc[str(self.all_ious[idx])]
118 |         
119 |         results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']
120 | 
121 |         self._logger.info(results)
122 |         return {'interactive': results}
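
A sketch (not part of the repo) of the NoC ("number of clicks") statistic computed in `compute_noc`: for one sample, the inner helper returns the 1-based index of the first click whose IoU clears the threshold, or `max_clicks` if no click ever does.

    import torch

    max_clicks = 20
    iou_per_click = torch.tensor([0.42, 0.61, 0.78, 0.83, 0.91])  # IoU after clicks 1..5

    def first_click_reaching(iou_arr, thr):
        vals = iou_arr >= thr
        return vals.max(dim=0)[1].item() + 1 if vals.any() else max_clicks

    print(first_click_reaching(iou_per_click, 0.80))  # 4
    print(first_click_reaching(iou_per_click, 0.95))  # 20 (threshold never reached)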


--------------------------------------------------------------------------------
/datasets/registration/__init__.py:
--------------------------------------------------------------------------------
 1 | from . import (
 2 |     register_refcoco_dataset,
 3 |     register_ade20k_full,
 4 |     register_ade20k_panoptic,
 5 |     register_coco_stuff_10k,
 6 |     register_coco_panoptic_annos_semseg,
 7 |     register_coco_panoptic_annos_caption,
 8 |     register_coco_panoptic_annos_caption_grounding,
 9 |     register_coco_lvis_panoptic_annos_caption_grounding,
10 |     register_ade20k_instance,
11 |     register_vlp_datasets,
12 |     register_sunrgbd_semseg,
13 |     register_scannet_semseg,
14 |     register_bdd100k_semseg,
15 |     register_scannet_panoptic,
16 |     register_bdd100k_panoseg,
17 |     register_pascalvoc_eval,
18 | )


--------------------------------------------------------------------------------
/datasets/registration/register_ade20k_instance.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | import json
 3 | import logging
 4 | import numpy as np
 5 | import os
 6 | from PIL import Image
 7 | 
 8 | from detectron2.data import DatasetCatalog, MetadataCatalog
 9 | from detectron2.data.datasets.coco import load_coco_json, register_coco_instances
10 | from detectron2.utils.file_io import PathManager
11 | 
12 | ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}]
13 | 
14 | 
15 | _PREDEFINED_SPLITS = {
16 |     # point annotations without masks
17 |     "ade20k_instance_train": (
18 |         "ADEChallengeData2016/images/training",
19 |         "ADEChallengeData2016/ade20k_instance_train.json",
20 |     ),
21 |     "ade20k_instance_val": (
22 |         "ADEChallengeData2016/images/validation",
23 |         "ADEChallengeData2016/ade20k_instance_val.json",
24 |     ),
25 | }
26 | 
27 | 
28 | def _get_ade_instances_meta():
29 |     thing_ids = [k["id"] for k in ADE_CATEGORIES]
30 |     assert len(thing_ids) == 100, len(thing_ids)
31 |     # Mapping from the incontiguous ADE category id to an id in [0, 99]
32 |     thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
33 |     thing_classes = [k["name"] for k in ADE_CATEGORIES]
34 |     ret = {
35 |         "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
36 |         "thing_classes": thing_classes,
37 |     }
38 |     return ret
39 | 
40 | 
41 | def register_all_ade20k_instance(root):
42 |     for key, (image_root, json_file) in _PREDEFINED_SPLITS.items():
43 |         # Assume pre-defined datasets live in `./datasets`.
44 |         register_coco_instances(
45 |             key,
46 |             _get_ade_instances_meta(),
47 |             os.path.join(root, json_file) if "://" not in json_file else json_file,
48 |             os.path.join(root, image_root),
49 |         )
50 | 
51 | 
52 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
53 | register_all_ade20k_instance(_root)
54 | 


--------------------------------------------------------------------------------
/datasets/registration/register_bdd100k_semseg.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import numpy as np
 9 | import os
10 | import glob
11 | from typing import List, Tuple, Union
12 | 
13 | from detectron2.data import DatasetCatalog, MetadataCatalog
14 | from detectron2.utils.file_io import PathManager
15 | 
16 | from utils.constants import BDD_SEM
17 | 
18 | __all__ = ["load_bdd_instances", "register_bdd_context"]
19 | 
20 | 
21 | def load_bdd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
22 |     """
23 |     Load BDD annotations to Detectron2 format.
24 | 
25 |     Args:
26 |         dirname: BDD100K root, containing "images/10k/<split>" and "labels/sem_seg/masks/<split>"
27 |         split (str): one of "train", "test", "val", "trainval"
28 |         class_names: list or tuple of class names
29 |     """
30 |     img_folder = os.path.join(dirname, 'images', '10k', split)
31 |     img_pths = sorted(glob.glob(os.path.join(img_folder, '*.jpg')))
32 |     
33 |     sem_folder = os.path.join(dirname, 'labels', 'sem_seg', 'masks', split)
34 |     sem_pths = sorted(glob.glob(os.path.join(sem_folder, '*.png')))
35 | 
36 |     assert len(img_pths) == len(sem_pths)
37 |         
38 |     dicts = []
39 |     for img_pth, sem_pth in zip(img_pths, sem_pths):
40 |         r = {
41 |             "file_name": img_pth,
42 |             "sem_seg_file_name": sem_pth,
43 |             "image_id": img_pth.split('/')[-1].split('.')[0],
44 |         }
45 |         dicts.append(r)
46 |     return dicts
47 | 
48 | 
49 | def register_bdd_context(name, dirname, split, class_names=BDD_SEM):
50 |     DatasetCatalog.register(name, lambda: load_bdd_instances(name, dirname, split, class_names))
51 |     MetadataCatalog.get(name).set(
52 |         stuff_classes=class_names,
53 |         dirname=dirname,
54 |         split=split,
55 |         ignore_label=[255],
56 |         thing_dataset_id_to_contiguous_id={},
57 |         class_offset=0,
58 |         keep_sem_bgd=False
59 |     )
60 | 
61 | 
62 | def register_all_bdd_semseg(root):
63 |     SPLITS = [
64 |             ("bdd10k_val_sem_seg", "bdd100k", "val"),
65 |         ]
66 |         
67 |     for name, dirname, split in SPLITS:
68 |         register_bdd_context(name, os.path.join(root, dirname), split)
69 |         MetadataCatalog.get(name).evaluator_type = "sem_seg"
70 | 
71 | 
72 | _root = os.getenv("DATASET", "datasets")
73 | register_all_bdd_semseg(_root)


--------------------------------------------------------------------------------
/datasets/registration/register_imagenet_cls.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import os
 9 | import glob
10 | from typing import List, Tuple, Union
11 | 
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 | from detectron2.utils.file_io import PathManager
15 | 
16 | from utils.constants import IMAGENET_CLASSES, IMAGENET_FOLDER_NAMES
17 | 
18 | __all__ = ["load_imagenet_images", "register_imagenet"]
19 | 
20 | 
21 | def load_imagenet_images(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
22 |     """
23 |     Load ImageNet annotations to Detectron2 format.
24 | 
25 |     Args:
26 |         dirname: ImageNet root containing one folder per synset, e.g. "val/n01440764"
27 |         split (str): one of "train", "test", "val", "trainval"
28 |         class_names: list or tuple of class names
29 |     """
30 |     image_folders = sorted(glob.glob(os.path.join(dirname, split, 'n*')))
31 | 
32 |     dicts = []
33 |     for image_folder in image_folders:
34 |         folder_name = image_folder.split('/')[-1]
35 |         image_pths = sorted(glob.glob(os.path.join(image_folder, "*.JPEG")))
36 |         for img_pth in image_pths:
37 |             r = {
38 |                 "file_name": img_pth,
39 |                 "class_name": IMAGENET_CLASSES[IMAGENET_FOLDER_NAMES.index(folder_name)],
40 |                 "class_id": IMAGENET_FOLDER_NAMES.index(folder_name),
41 |             }
42 |             dicts.append(r)
43 |     return dicts
44 | 
45 | 
46 | def register_imagenet(name, dirname, split, year, class_names=IMAGENET_CLASSES):
47 |     DatasetCatalog.register(name, lambda: load_imagenet_images(dirname, split, class_names))
48 |     MetadataCatalog.get(name).set(
49 |         thing_classes=list(class_names), dirname=dirname, year=year, split=split
50 |     )
51 | 
52 | 
53 | def register_all_imagenet(root):
54 |     SPLITS = [
55 |             ("imagenet_val", "imagenet", "val", "2012"),
56 |         ]
57 |     for name, dirname, split, year in SPLITS:
58 |         register_imagenet(name, os.path.join(root, dirname), split, year)
59 |         MetadataCatalog.get(name).evaluator_type = "classification"
60 | 
61 | 
62 | _root = os.getenv("DATASET", "datasets")
63 | register_all_imagenet(_root)
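
A usage sketch (not part of the file): once this module has been imported and `register_all_imagenet` has run, the split can be pulled back out of detectron2's catalogs. This assumes the `DATASET` root really contains `imagenet/val/n*` folders.

    from detectron2.data import DatasetCatalog, MetadataCatalog

    dicts = DatasetCatalog.get("imagenet_val")   # list of {"file_name", "class_name", "class_id"}
    meta = MetadataCatalog.get("imagenet_val")
    print(len(dicts), dicts[0]["class_name"], meta.evaluator_type)  # e.g. 50000 ... classification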


--------------------------------------------------------------------------------
/datasets/registration/register_pascalvoc_eval.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) Facebook, Inc. and its affiliates.
 3 | import os
 4 | import glob
 5 | from typing import List, Tuple, Union
 6 | import xml.etree.ElementTree as ET
 7 | 
 8 | import cv2
 9 | import numpy as np
10 | from scipy.io import loadmat
11 | 
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 | from detectron2.utils.file_io import PathManager
15 | 
16 | 
17 | __all__ = ["load_pascalvoc_instances", "register_pascalvoc_context"]
18 | 
19 | def get_labels_with_sizes(x):
20 |     obj_sizes = np.bincount(x.flatten())
21 |     labels = np.nonzero(obj_sizes)[0].tolist()
22 |     labels = [x for x in labels if x != 0]
23 |     return labels, obj_sizes[labels].tolist()
24 | 
25 | def load_pascalvoc_instances(name: str, dirname: str, mode: str, split: str):
26 |     """
27 |     Load Pascal VOC detection annotations to Detectron2 format.
28 | 
29 |     Args:
30 |         dirname: Contain "Annotations", "ImageSets", "JPEGImages"
31 |         split (str): one of "train", "test", "val", "trainval"
32 |         class_names: list or tuple of class names
33 |     """
34 |     with PathManager.open(os.path.join(dirname, 'ImageSets', 'Segmentation', split + ".txt")) as f:
35 |         fileids = np.loadtxt(f, dtype=str)  # np.str was removed in NumPy 1.24
36 | 
37 |     dicts = []
38 |     for field in fileids:
39 |         anno_path = os.path.join(dirname, "Annotations", "{}.xml".format(field))
40 |         image_path = os.path.join(dirname, "JPEGImages", "{}.jpg".format(field))
41 |         inst_path = os.path.join(dirname, "SegmentationObject", "{}.png".format(field))
42 |         semseg_path = os.path.join(dirname, "SegmentationClass", "{}.png".format(field))
43 | 
44 |         instances_mask = cv2.imread(inst_path)
45 |         instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
46 | 
47 |         objects_ids = np.unique(instances_mask)
48 |         objects_ids = [x for x in objects_ids if x != 0 and x != 220]
49 | 
50 |         slice_size = 5
51 |         for i in range(0, len(objects_ids), slice_size):
52 |             r = {
53 |                 "file_name": image_path,
54 |                 "inst_name": inst_path,
55 |                 "semseg_name": semseg_path,
56 |                 "objects_ids": objects_ids[i:i+slice_size],
57 |             }
58 |             dicts.append(r)
59 |     return dicts
60 | 
61 | def register_pascalvoc_context(name, dirname, mode, split):
62 |     DatasetCatalog.register("{}_{}".format(name, mode), lambda: load_pascalvoc_instances(name, dirname, mode, split))
63 |     MetadataCatalog.get("{}_{}".format(name, mode)).set(
64 |         dirname=dirname,
65 |         thing_dataset_id_to_contiguous_id={},
66 |     )
67 | 
68 | def register_all_sbd(root):
69 |     SPLITS = [
70 |             ("pascalvoc_val", "PascalVOC", "Point", "val"),
71 |             ("pascalvoc_val", "PascalVOC", "Scribble", "val"),
72 |             ("pascalvoc_val", "PascalVOC", "Polygon", "val"),
73 |             ("pascalvoc_val", "PascalVOC", "Circle", "val"),
74 |             ("pascalvoc_val", "PascalVOC", "Box", "val"),
75 |         ]
76 |         
77 |     for name, dirname, mode, split in SPLITS:
78 |         register_pascalvoc_context(name, os.path.join(root, dirname), mode, split)
79 |         MetadataCatalog.get("{}_{}".format(name, mode)).evaluator_type = "interactive"
80 | 
81 | _root = os.getenv("DATASET", "datasets")
82 | register_all_sbd(_root)
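
A small illustration (not part of the repo) of the `slice_size` chunking above: each image is emitted once per group of up to five object ids, so images with many instances become several evaluation samples.

    objects_ids = [1, 2, 3, 4, 5, 6, 7]
    slice_size = 5
    groups = [objects_ids[i:i + slice_size] for i in range(0, len(objects_ids), slice_size)]
    print(groups)  # [[1, 2, 3, 4, 5], [6, 7]]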


--------------------------------------------------------------------------------
/datasets/registration/register_refcoco_dataset.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | import json
  8 | import os
  9 | import collections
 10 | 
 11 | from detectron2.data import DatasetCatalog, MetadataCatalog
 12 | from detectron2.data.datasets import load_sem_seg
 13 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
 14 | from detectron2.utils.file_io import PathManager
 15 | 
 16 | 
 17 | _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION = {
 18 |     # "refcocog_train_umd": (
 19 |     #     "coco/train2017", # image_root
 20 |     #     "coco/annotations/refcocog_umd_train.json", # annot_root
 21 |     # ),
 22 |     # "refcocog_val_google": (
 23 |     #     "coco/train2017", # image_root
 24 |     #     "coco/annotations/refcocog_google.json", # annot_root
 25 |     # ),
 26 |     # "refcocop_val_unc": (
 27 |     #     "coco/train2017", # image_root
 28 |     #     "coco/annotations/refcocop_unc.json", # annot_root
 29 |     # ),
 30 |     # "refcoco_val_unc": (
 31 |     #     "coco/train2017", # image_root
 32 |     #     "coco/annotations/refcoco_unc.json", # annot_root
 33 |     # ),
 34 |     "refcocog_val_umd": (
 35 |         "coco/train2017", # image_root
 36 |         "coco/annotations/refcocog_umd_val.json", # annot_root
 37 |     ),
 38 | }
 39 | 
 40 | 
 41 | def get_metadata():
 42 |     meta = {}
 43 |     return meta
 44 | 
 45 | 
 46 | def load_refcoco_json(image_root, annot_json, metadata):
 47 |     """
 48 |     Args:
 49 |         image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
 50 |         gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
 51 |         json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
 52 |     Returns:
 53 |         list[dict]: a list of dicts in Detectron2 standard format. (See
 54 |         `Using Custom Datasets </tutorials/datasets.html>`_ )
 55 |     """
 56 | 
 57 |     with PathManager.open(annot_json) as f:
 58 |         json_info = json.load(f)
 59 |         
 60 |     # build dictionary for grounding
 61 |     grd_dict = collections.defaultdict(list)
 62 |     for grd_ann in json_info['annotations']:
 63 |         image_id = int(grd_ann["image_id"])
 64 |         grd_dict[image_id].append(grd_ann)
 65 | 
 66 |     ret = []
 67 |     for image in json_info["images"]:
 68 |         image_id = int(image["id"])
 69 |         image_file = os.path.join(image_root, image['file_name'])
 70 |         grounding_anno = grd_dict[image_id]
 71 |         ret.append(
 72 |             {
 73 |                 "file_name": image_file,
 74 |                 "image_id": image_id,
 75 |                 "grounding_info": grounding_anno,
 76 |             }
 77 |         )
 78 |     assert len(ret), f"No images found in {image_root}!"
 79 |     assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
 80 |     return ret
 81 | 
 82 | 
 83 | def register_refcoco(
 84 |     name, metadata, image_root, annot_json):
 85 |     DatasetCatalog.register(
 86 |         name,
 87 |         lambda: load_refcoco_json(image_root, annot_json, metadata),
 88 |     )
 89 |     MetadataCatalog.get(name).set(
 90 |         image_root=image_root,
 91 |         json_file=annot_json,
 92 |         evaluator_type="grounding_refcoco",
 93 |         ignore_label=255,
 94 |         label_divisor=1000,
 95 |         **metadata,
 96 |     )
 97 | 
 98 | 
 99 | def register_all_refcoco(root):
100 |     for (
101 |         prefix,
102 |         (image_root, annot_root),
103 |     ) in _PREDEFINED_SPLITS_COCO_PANOPTIC_CAPTION.items():
104 |         register_refcoco(
105 |             prefix,
106 |             get_metadata(),
107 |             os.path.join(root, image_root),
108 |             os.path.join(root, annot_root),
109 |         )
110 | 
111 | 
112 | _root = os.getenv("DATASET", "datasets")
113 | register_all_refcoco(_root)
114 | 


--------------------------------------------------------------------------------
/datasets/registration/register_scannet_semseg.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | import numpy as np
 8 | import os
 9 | import glob
10 | from typing import List, Tuple, Union
11 | 
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 | from detectron2.utils.file_io import PathManager
15 | 
16 | from utils.constants import SCAN_37, SCAN_40, SCAN_20
17 | 
18 | __all__ = ["load_scannet_instances", "register_scannet_context"]
19 | 
20 | name2folder = {"scannet_41_val_seg": "label41",
21 |                "scannet_38_val_seg": "label38",
22 |                "scannet_21_val_seg": "label21",}
23 | 
24 | name2class = {"scannet_41_val_seg": SCAN_40,
25 |               "scannet_38_val_seg": SCAN_37,
26 |               "scannet_21_val_seg": SCAN_20}
27 | 
28 | 
29 | def load_scannet_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
30 |     """
31 |     Load ScanNet annotations to Detectron2 format.
32 | 
33 |     Args:
34 |         dirname: ScanNet frames root containing "images" and "meta"
35 |         split (str): one of "train", "test", "val", "trainval"
36 |         class_names: list or tuple of class names
37 |     """
38 |     with PathManager.open(os.path.join(dirname, "meta", split + ".txt")) as f:
39 |         fileids = np.loadtxt(f, dtype=str)  # np.str was removed in NumPy 1.24
40 |         
41 |     dicts = []
42 |     for field in fileids:
43 |         image_dir = os.path.join(dirname, 'images', field[0])
44 |         semseg_dir = image_dir.replace('color', name2folder[name]).replace('jpg', 'png')
45 |         r = {
46 |             "file_name": image_dir,
47 |             "sem_seg_file_name": semseg_dir,
48 |             "image_id": semseg_dir.split('/')[-3] + semseg_dir.split('/')[-1].split('.')[0],
49 |         }
50 |         dicts.append(r)
51 |     return dicts
52 | 
53 | 
54 | def register_scannet_context(name, dirname, split, class_names=name2class):
55 |     DatasetCatalog.register(name, lambda: load_scannet_instances(name, dirname, split, class_names))
56 |     MetadataCatalog.get(name).set(
57 |         stuff_classes=class_names[name],
58 |         dirname=dirname,
59 |         split=split,
60 |         ignore_label=[0],
61 |         thing_dataset_id_to_contiguous_id={},
62 |         class_offset=1,
63 |         keep_sem_bgd=False
64 |     )
65 | 
66 | 
67 | def register_all_scannet_seg(root):
68 |     SPLITS = [
69 |             ("scannet_41_val_seg", "scannet_frames_25k", "val"),
70 |             ("scannet_38_val_seg", "scannet_frames_25k", "val"),
71 |             ("scannet_21_val_seg", "scannet_frames_25k", "val"),
72 |         ]
73 |         
74 |     for name, dirname, split in SPLITS:
75 |         register_scannet_context(name, os.path.join(root, dirname), split)
76 |         MetadataCatalog.get(name).evaluator_type = "sem_seg"
77 | 
78 | 
79 | _root = os.getenv("DATASET", "datasets")
80 | register_all_scannet_seg(_root)


--------------------------------------------------------------------------------
/datasets/registration/register_sunrgbd_semseg.py:
--------------------------------------------------------------------------------
 1 | 
 2 | # --------------------------------------------------------
 3 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 4 | # Copyright (c) 2022 Microsoft
 5 | # Licensed under The MIT License [see LICENSE for details]
 6 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 7 | # --------------------------------------------------------
 8 | import numpy as np
 9 | import os
10 | import glob
11 | from typing import List, Tuple, Union
12 | 
13 | from detectron2.data import DatasetCatalog, MetadataCatalog
14 | from detectron2.structures import BoxMode
15 | from detectron2.utils.file_io import PathManager
16 | 
17 | from utils.constants import SUN_RGBD_37
18 | 
19 | __all__ = ["load_sunrgbd_instances", "register_sunrgbd_context"]
20 | 
21 | def load_sunrgbd_instances(name: str, dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
22 |     """
23 |     Load SUN-RGBD detection annotations to Detectron2 format.
24 | 
25 |     Args:
26 |         dirname: SUN RGB-D root containing "image/<split>" and "label37/<split>"
27 |         split (str): one of "train", "test", "val", "trainval"
28 |         class_names: list or tuple of class names
29 |     """
30 |     if split == 'val':
31 |         split = 'test'
32 |         
33 |     # Needs to read many small annotation files. Makes sense at local
34 |     image_pths = sorted(glob.glob(os.path.join(dirname, 'image', split, '*.jpg')))
35 |     semseg_pths = sorted(glob.glob(os.path.join(dirname, 'label37', split, '*.png')))
36 |     
37 |     assert len(image_pths) == len(semseg_pths)
38 |     
39 |     dicts = []
40 |     for image_dir, semseg_dir in zip(image_pths, semseg_pths):
41 |         r = {
42 |             "file_name": image_dir,
43 |             "sem_seg_file_name": semseg_dir,
44 |             "image_id": semseg_dir.split('/')[-1].split('.')[0],
45 |         }
46 |         dicts.append(r)
47 |     return dicts
48 | 
49 | 
50 | def register_sun_context(name, dirname, split, class_names=SUN_RGBD_37):
51 |     DatasetCatalog.register(name, lambda: load_sunrgbd_instances(name, dirname, split, class_names))
52 |     MetadataCatalog.get(name).set(
53 |         stuff_classes=class_names,
54 |         dirname=dirname,
55 |         split=split,
56 |         ignore_label=[0],
57 |         thing_dataset_id_to_contiguous_id={},
58 |         class_offset=1,
59 |         keep_sem_bgd=False
60 |     )
61 | 
62 | 
63 | def register_all_sunrgbd_seg(root):
64 |     SPLITS = [
65 |             ("sunrgbd_37_val_seg", "sun_rgbd", "val"),
66 |         ]
67 |         
68 |     for name, dirname, split in SPLITS:
69 |         register_sun_context(name, os.path.join(root, dirname), split)
70 |         MetadataCatalog.get(name).evaluator_type = "sem_seg"
71 | 
72 | 
73 | _root = os.getenv("DATASET", "datasets")
74 | register_all_sunrgbd_seg(_root)


--------------------------------------------------------------------------------
/datasets/semseg_loader.py:
--------------------------------------------------------------------------------
 1 | from PIL import Image
 2 | import scipy.io
 3 | import numpy as np
 4 | 
 5 | def load_semseg(filename, loader_type):
 6 |     if loader_type == 'PIL':
 7 |         semseg = np.array(Image.open(filename), dtype=int)  # np.int was removed in NumPy 1.24
 8 |     elif loader_type == 'MAT':
 9 |         semseg = scipy.io.loadmat(filename)['LabelMap']
10 |     return semseg
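
A usage sketch (not part of the file): the `PIL` branch reads indexed label PNGs and the `MAT` branch reads MATLAB files whose variable is named `LabelMap`. The import assumes the repository root is on PYTHONPATH, and the path below is a placeholder.

    from datasets.semseg_loader import load_semseg

    semseg = load_semseg("datasets/sun_rgbd/label37/test/0000001.png", loader_type="PIL")
    print(semseg.shape, semseg.dtype)  # (H, W) integer array of class ids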


--------------------------------------------------------------------------------
/datasets/utils/refcoco2json.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | from refer import REFER
 4 | 
 5 | coco_root = '/pth/to/coco'
 6 | ref_root = '/pth/to/refcocoseg'
 7 | 
 8 | coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json')))
 9 | coco_train_id = []
10 | image_annot = {}
11 | for i in range(len(coco_train_annot['images'])):
12 |     coco_train_id.append(coco_train_annot['images'][i]['id'])
13 |     image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i]
14 | 
15 | refg = REFER(data_root=ref_root,
16 |                 dataset='refcocog', splitBy='umd')
17 | refg_val_ids = refg.getRefIds(split='val')
18 | 
19 | full_anno = []
20 | for ref_id in refg_val_ids:
21 |     ref = refg.loadRefs(ref_id)[0]
22 |     anno = refg.refToAnn[ref_id]
23 |     anno.update(ref)
24 |     full_anno.append(anno)
25 | 
26 | imageid_list = []
27 | final_anno = {}
28 | for anno in full_anno:
29 |     imageid_list += [anno['image_id']]
30 |     final_anno[anno['ann_id']] = anno
31 |     
32 | annotations = [value for key, value in final_anno.items()]
33 | 
34 | images = []
35 | for image_id in list(set(imageid_list)):
36 |     images += [image_annot[image_id]]
37 | 
38 | outputs = {'images': images, 'annotations': annotations}
39 | print(len(images))
40 | print(len(annotations))
41 | json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_val.json'), 'w'))  # matches the val split collected above
42 | 


--------------------------------------------------------------------------------
/datasets/visual_sampler/__init__.py:
--------------------------------------------------------------------------------
 1 | from .sampler import ShapeSampler
 2 | from .simpleclick_sampler import SimpleClickSampler
 3 | 
 4 | 
 5 | def build_shape_sampler(cfg, **kwargs):
 6 |     sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE']
 7 |     if sampler_name == 'random':
 8 |         return ShapeSampler(cfg, **kwargs)
 9 |     elif sampler_name in ['best', 'best_random']:
10 |         return SimpleClickSampler(cfg, **kwargs)
11 |     else:
12 |         raise NotImplementedError(f"Unknown STROKE_SAMPLER EVAL MODE: {sampler_name}")
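
A sketch (not part of the file) of the minimal `STROKE_SAMPLER` keys this factory and `ShapeSampler.from_config` read; the values are illustrative rather than the project's defaults, and the repo's modules are assumed importable.

    from datasets.visual_sampler import build_shape_sampler

    cfg = {
        'STROKE_SAMPLER': {
            'MAX_CANDIDATE': 1,
            'CANDIDATE_NAMES': ['Point', 'Scribble'],
            'CANDIDATE_PROBS': [0.5, 0.5],
            'EVAL': {'MODE': 'random', 'MAX_ITER': 20},
            'POINT': {'NUM_POINTS': 20},
            'SCRIBBLE': {'NUM_STROKES': 5, 'STROKE_PRESET': ['rand_curve'], 'STROKE_PROB': [1.0]},
        }
    }
    sampler = build_shape_sampler(cfg, is_train=True)  # 'random' mode -> ShapeSampler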


--------------------------------------------------------------------------------
/datasets/visual_sampler/circle.py:
--------------------------------------------------------------------------------
  1 | import random
  2 | import torch
  3 | 
  4 | from .mask_generators import get_mask_by_input_strokes
  5 | 
  6 | class Circle:
  7 |     def __init__(self, cfg, is_train=True):
  8 |         self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES']
  9 |         self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET']
 10 |         self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB']
 11 |         self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
 12 |         self.is_train = is_train
 13 | 
 14 |     @staticmethod
 15 |     def get_stroke_preset(stroke_preset):
 16 |         if stroke_preset == 'object_like':
 17 |             return {
 18 |                 "nVertexBound": [5, 30],
 19 |                 "maxHeadSpeed": 15,
 20 |                 "maxHeadAcceleration": (10, 1.5),
 21 |                 "brushWidthBound": (20, 50),
 22 |                 "nMovePointRatio": 0.5,
 23 |                 "maxPiontMove": 10,
 24 |                 "maxLineAcceleration": (5, 0.5),
 25 |                 "boarderGap": None,
 26 |                 "maxInitSpeed": 10,
 27 |             }
 28 |         elif stroke_preset == 'object_like_middle':
 29 |             return {
 30 |                 "nVertexBound": [5, 15],
 31 |                 "maxHeadSpeed": 8,
 32 |                 "maxHeadAcceleration": (4, 1.5),
 33 |                 "brushWidthBound": (20, 50),
 34 |                 "nMovePointRatio": 0.5,
 35 |                 "maxPiontMove": 5,
 36 |                 "maxLineAcceleration": (5, 0.5),
 37 |                 "boarderGap": None,
 38 |                 "maxInitSpeed": 10,
 39 |             }
 40 |         elif stroke_preset == 'object_like_small':
 41 |             return {
 42 |                 "nVertexBound": [5, 20],
 43 |                 "maxHeadSpeed": 7,
 44 |                 "maxHeadAcceleration": (3.5, 1.5),
 45 |                 "brushWidthBound": (10, 30),
 46 |                 "nMovePointRatio": 0.5,
 47 |                 "maxPiontMove": 5,
 48 |                 "maxLineAcceleration": (3, 0.5),
 49 |                 "boarderGap": None,
 50 |                 "maxInitSpeed": 4,
 51 |             }
 52 |         else:
 53 |             raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
 54 | 
 55 |     def get_random_points_from_mask(self, mask, n=5):
 56 |         h,w = mask.shape
 57 |         view_mask = mask.reshape(h*w)
 58 |         non_zero_idx = view_mask.nonzero()[:,0]
 59 |         selected_idx = torch.randperm(len(non_zero_idx))[:n]
 60 |         non_zero_idx = non_zero_idx[selected_idx]
 61 |         y = (non_zero_idx // w)*1.0
 62 |         x = (non_zero_idx % w)*1.0
 63 |         return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
 64 | 
 65 |     def draw(self, mask=None, box=None):
 66 |         if mask.sum() < 10: # if mask is nearly empty
 67 |             return torch.zeros(mask.shape).bool()
 68 |         if not self.is_train:
 69 |             return self.draw_eval(mask=mask, box=box)
 70 |         stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
 71 |         preset = Circle.get_stroke_preset(stroke_preset_name)
 72 |         nStroke = min(random.randint(1, self.num_stroke), mask.sum().item())
 73 |         h,w = mask.shape
 74 |         points = self.get_random_points_from_mask(mask, n=nStroke)
 75 |         rand_mask = get_mask_by_input_strokes(
 76 |             init_points=points,
 77 |             imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
 78 |         rand_mask = (~torch.from_numpy(rand_mask)) * mask
 79 |         return rand_mask
 80 | 
 81 |     def draw_eval(self, mask=None, box=None):
 82 |         stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
 83 |         preset = Circle.get_stroke_preset(stroke_preset_name)
 84 |         nStroke = min(self.max_eval, mask.sum().item())
 85 |         h,w = mask.shape
 86 |         points = self.get_random_points_from_mask(mask, n=nStroke)
 87 |         rand_masks = []
 88 |         for i in range(len(points)):
 89 |             rand_mask = get_mask_by_input_strokes(
 90 |                 init_points=points[:i+1],
 91 |                 imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset)
 92 |             rand_masks += [(~torch.from_numpy(rand_mask)) * mask]
 93 |         return torch.stack(rand_masks)
 94 | 
 95 |     @staticmethod
 96 |     def draw_by_points(points, mask, h, w):
 97 |         stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use
 98 |         preset = Circle.get_stroke_preset(stroke_preset_name)
 99 |         rand_mask = get_mask_by_input_strokes(
100 |             init_points=points,
101 |             imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
102 |         rand_masks = (~torch.from_numpy(rand_mask)) * mask
103 |         return rand_masks
104 | 
105 |     def __repr__(self,):
106 |         return 'circle'
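
A small illustration (not part of the file) of the flat-index arithmetic in `get_random_points_from_mask`: non-zero positions of an h x w mask are converted back to column (x = idx % w) and row (y = idx // w) coordinates.

    import torch

    mask = torch.zeros(3, 4, dtype=torch.bool)
    mask[1, 2] = True
    mask[2, 0] = True

    idx = mask.reshape(-1).nonzero()[:, 0]    # tensor([6, 8])
    x, y = (idx % 4) * 1.0, (idx // 4) * 1.0
    print(torch.stack([x, y], dim=1))         # [[2., 1.], [0., 2.]]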


--------------------------------------------------------------------------------
/datasets/visual_sampler/point.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | import torch
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | from scipy import ndimage
 6 | 
 7 | 
 8 | class Point:
 9 |     def __init__(self, cfg, is_train=True):
10 |         self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
11 |         self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
12 |         self.is_train = is_train
13 | 
14 |     def draw(self, mask=None, box=None):
15 |         if mask.sum() < 10:
16 |             return torch.zeros(mask.shape).bool() # if mask is empty
17 |         if not self.is_train:
18 |             return self.draw_eval(mask=mask, box=box)
19 |         max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number
20 |         num_points = random.randint(1, max_points) # get a random number of points 
21 |         h,w = mask.shape
22 |         view_mask = mask.view(-1)
23 |         non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask
24 |         selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id
25 |         non_zero_idx = non_zero_idx[selected_idx] # select non-zero index
26 |         rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask
27 |         rand_mask[non_zero_idx] = True # get non zero place to zero
28 |         # dilate
29 |         # struct = ndimage.generate_binary_structure(2, 2)
30 |         # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
31 |         # return rand_mask
32 |         return rand_mask.reshape(h, w)
33 |     
34 |     def draw_eval(self, mask=None, box=None):
35 |         background = ~mask
36 |         neg_num = min(self.max_eval // 2, background.sum().item())
37 |         pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1
38 | 
39 |         h,w = mask.shape
40 |         view_mask = mask.view(-1)
41 |         non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask
42 |         selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id
43 |         non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index
44 |         pos_idx = torch.ones(non_zero_idx_pos.shape)
45 | 
46 |         view_background = background.view(-1)
47 |         non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask
48 |         selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id
49 |         non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index
50 |         neg_idx = torch.ones(non_zero_idx_neg.shape) * -1
51 | 
52 |         non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg])
53 |         idx = torch.cat([pos_idx, neg_idx])
54 |         rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long()
55 |         non_zero_idx = non_zero_idx[rand_idx]
56 |         idx = idx[rand_idx]
57 | 
58 |         rand_masks = []
59 |         for i in range(0, len(non_zero_idx)):
60 |             rand_mask = torch.zeros(view_mask.shape) # init rand mask
61 |             rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero
62 |             # struct = ndimage.generate_binary_structure(2, 2)
63 |             # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
64 |             rand_masks += [rand_mask.reshape(h, w)]
65 | 
66 |         # kernel_size = 3
67 |         rand_masks = torch.stack(rand_masks)
68 |         # rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0]
69 |         # rand_masks[rand_masks>0] = 1
70 |         # rand_masks[rand_masks<0] = -1
71 |         return rand_masks
72 |     
73 |     def __repr__(self,):
74 |         return 'point'


--------------------------------------------------------------------------------
/datasets/visual_sampler/sampler.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import random
 3 | 
 4 | import torch
 5 | import torch.nn as nn
 6 | 
 7 | from .point import Point
 8 | from .polygon import Polygon
 9 | from .scribble import Scribble
10 | from .circle import Circle
11 | 
12 | from modeling.utils import configurable
13 | 
14 | 
15 | class ShapeSampler(nn.Module):
16 |     @configurable
17 |     def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
18 |         super().__init__()
19 |         self.max_candidate = max_candidate
20 |         self.shape_prob = shape_prob
21 |         self.shape_candidate = shape_candidate
22 |         self.is_train = is_train
23 | 
24 |     @classmethod
25 |     def from_config(cls, cfg, is_train=True, mode=None):
26 |         max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
27 |         candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
28 |         candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']
29 | 
30 |         if mode == 'hack_train':
31 |             candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]        
32 |         else:
33 |             # overwrite candidate_probs so only the requested mode is drawn
34 |             if not is_train:
35 |                 candidate_probs = [0.0 for x in range(len(candidate_names))]
36 |                 candidate_probs[candidate_names.index(mode)] = 1.0
37 |             candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]
38 | 
39 |         # Build augmentation
40 |         return {
41 |             "max_candidate": max_candidate,
42 |             "shape_prob": candidate_probs,
43 |             "shape_candidate": candidate_classes,
44 |             "is_train": is_train,
45 |         }
46 | 
47 |     def forward(self, instances):
48 |         masks = instances.gt_masks.tensor
49 |         boxes = instances.gt_boxes.tensor
50 | 
51 |         if len(masks) == 0:
52 |             gt_masks = torch.zeros(masks.shape[-2:]).bool()
53 |             rand_masks = torch.zeros(masks.shape[-2:]).bool()
54 |             return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
55 |         indices = [x for x in range(len(masks))]
56 |  
57 |         if self.is_train:
58 |             random.shuffle(indices)
59 |             candidate_mask = masks[indices[:self.max_candidate]]
60 |             candidate_box = boxes[indices[:self.max_candidate]]
61 |         else:
62 |             candidate_mask = masks
63 |             candidate_box = boxes
64 |         
65 |         draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
66 |         rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)]
67 |         types = [repr(x) for x in draw_funcs]
68 |         for i in range(0, len(rand_shapes)):
69 |             if rand_shapes[i].sum() == 0:
70 |                 candidate_mask[i] = candidate_mask[i] * 0
71 |                 types[i] = 'none'
72 | 
73 |         # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
74 |         return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
75 | 
76 | def build_shape_sampler(cfg, **kwargs):
77 |     return ShapeSampler(cfg, **kwargs)
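
For reference, a minimal sketch of how the sampler might be built, assuming a plain-dict config that carries only the STROKE_SAMPLER keys read above and that the project's configurable decorator routes the dict through from_config (all values here are illustrative, not the shipped defaults):

# Illustrative config: only the keys that ShapeSampler/Scribble read above are filled in.
cfg = {
    'STROKE_SAMPLER': {
        'MAX_CANDIDATE': 1,
        'CANDIDATE_PROBS': [1.0],
        'CANDIDATE_NAMES': ['Scribble'],
        'SCRIBBLE': {
            'NUM_STROKES': 5,
            'STROKE_PRESET': ['rand_curve', 'rand_curve_small'],
            'STROKE_PROB': [0.5, 0.5],
        },
        'EVAL': {'MAX_ITER': 20},
    },
}

# Training-time sampler: shapes are drawn according to CANDIDATE_PROBS.
train_sampler = build_shape_sampler(cfg, is_train=True)

# Eval-time sampler: the probabilities are overwritten so that only `mode` is sampled.
eval_sampler = build_shape_sampler(cfg, is_train=False, mode='Scribble')

forward() then expects detectron2 Instances carrying gt_masks and gt_boxes and returns the gt_masks / rand_shape / types dictionary described in the comment above.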


--------------------------------------------------------------------------------
/datasets/visual_sampler/scribble.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | 
 3 | import torch
 4 | 
 5 | from .mask_generators import get_mask_by_input_strokes
 6 | 
 7 | class Scribble:
 8 |     def __init__(self, cfg, is_train):
 9 |         self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES']
10 |         self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET']
11 |         self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB']
12 |         self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
13 |         self.is_train = is_train
14 | 
15 |     @staticmethod
16 |     def get_stroke_preset(stroke_preset):
17 |         if stroke_preset == 'rand_curve':
18 |             return {
19 |                 "nVertexBound": [10, 30],
20 |                 "maxHeadSpeed": 20,
21 |                 "maxHeadAcceleration": (15, 0.5),
22 |                 "brushWidthBound": (3, 10),
23 |                 "nMovePointRatio": 0.5,
24 |                 "maxPiontMove": 3,
25 |                 "maxLineAcceleration": (5, 0.5),
26 |                 "boarderGap": None,
27 |                 "maxInitSpeed": 6
28 |             }
29 |         elif stroke_preset == 'rand_curve_small':
30 |             return {
31 |                 "nVertexBound": [6, 22],
32 |                 "maxHeadSpeed": 12,
33 |                 "maxHeadAcceleration": (8, 0.5),
34 |                 "brushWidthBound": (2.5, 5),
35 |                 "nMovePointRatio": 0.5,
36 |                 "maxPiontMove": 1.5,
37 |                 "maxLineAcceleration": (3, 0.5),
38 |                 "boarderGap": None,
39 |                 "maxInitSpeed": 3
40 |             }
41 |         else:
42 |             raise NotImplementedError(f'The stroke preset "{stroke_preset}" does not exist.')
43 | 
44 |     def get_random_points_from_mask(self, mask, n=5):
45 |         h,w = mask.shape
46 |         view_mask = mask.reshape(h*w)
47 |         non_zero_idx = view_mask.nonzero()[:,0]
48 |         selected_idx = torch.randperm(len(non_zero_idx))[:n]
49 |         non_zero_idx = non_zero_idx[selected_idx]
50 |         y = (non_zero_idx // w)*1.0
51 |         x = (non_zero_idx % w)*1.0
52 |         return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
53 | 
54 |     def draw(self, mask=None, box=None):
55 |         if mask.sum() < 10:
56 |             return torch.zeros(mask.shape).bool() # mask too small to sample strokes from
57 |         if not self.is_train:
58 |             return self.draw_eval(mask=mask, box=box)
59 |         stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
60 |         preset = Scribble.get_stroke_preset(stroke_preset_name)
61 |         nStroke = random.randint(1, min(self.num_stroke, mask.sum().item()))
62 |         h,w = mask.shape
63 |         points = self.get_random_points_from_mask(mask, n=nStroke)
64 |         rand_mask = get_mask_by_input_strokes(
65 |             init_points=points,
66 |             imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
67 |         rand_mask = (~torch.from_numpy(rand_mask)) * mask
68 |         return rand_mask
69 | 
70 |     def draw_eval(self, mask=None, box=None):
71 |         stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
72 |         preset = Scribble.get_stroke_preset(stroke_preset_name)
73 |         nStroke = min(self.eval_stroke, mask.sum().item())
74 |         h,w = mask.shape
75 |         points = self.get_random_points_from_mask(mask, n=nStroke)
76 |         rand_masks = []
77 |         for i in range(len(points)):
78 |             rand_mask = get_mask_by_input_strokes(
79 |                 init_points=points[:i+1],
80 |                 imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset)
81 |             rand_mask = (~torch.from_numpy(rand_mask)) * mask
82 |             rand_masks += [rand_mask]
83 |         return torch.stack(rand_masks)
84 | 
85 |     @staticmethod
86 |     def draw_by_points(points, mask, h, w):
87 |         stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0]
88 |         preset = Scribble.get_stroke_preset(stroke_preset_name)
89 |         rand_mask = get_mask_by_input_strokes(
90 |             init_points=points,
91 |             imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
92 |         rand_masks = (~torch.from_numpy(rand_mask)) * mask
93 |         return rand_masks
94 | 
95 |     def __repr__(self):
96 |         return 'scribble'
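
To make the training-time path concrete, a small sketch that draws one scribble over a toy square mask; the cfg fragment mirrors the keys read in __init__ above with illustrative values, and it assumes get_mask_by_input_strokes from .mask_generators is importable:

import torch

# Illustrative config fragment: only the keys Scribble.__init__ reads.
cfg = {'STROKE_SAMPLER': {
    'SCRIBBLE': {'NUM_STROKES': 5,
                 'STROKE_PRESET': ['rand_curve', 'rand_curve_small'],
                 'STROKE_PROB': [0.5, 0.5]},
    'EVAL': {'MAX_ITER': 20}}}

# Toy 64x64 binary mask with a filled square in the middle (illustrative only).
mask = torch.zeros(64, 64, dtype=torch.bool)
mask[16:48, 16:48] = True
box = torch.tensor([16., 16., 48., 48.])       # accepted by draw() but unused for scribbles

scribble = Scribble(cfg, is_train=True)
rand_mask = scribble.draw(mask=mask, box=box)  # stroke pixels restricted to the square
print(rand_mask.shape)                         # torch.Size([64, 64])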


--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/__init__.py


--------------------------------------------------------------------------------
/demo/seem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/__init__.py


--------------------------------------------------------------------------------
/demo/seem/examples/corgi1.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/corgi1.webp


--------------------------------------------------------------------------------
/demo/seem/examples/corgi2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/corgi2.jpg


--------------------------------------------------------------------------------
/demo/seem/examples/fries1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/fries1.png


--------------------------------------------------------------------------------
/demo/seem/examples/fries2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/fries2.png


--------------------------------------------------------------------------------
/demo/seem/examples/minecraft1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/minecraft1.jpg


--------------------------------------------------------------------------------
/demo/seem/examples/placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/placeholder.png


--------------------------------------------------------------------------------
/demo/seem/examples/ref_vase.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/ref_vase.JPG


--------------------------------------------------------------------------------
/demo/seem/examples/river1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/river1.png


--------------------------------------------------------------------------------
/demo/seem/examples/river1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/river1.wav


--------------------------------------------------------------------------------
/demo/seem/examples/river1_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/river1_mask.png


--------------------------------------------------------------------------------
/demo/seem/examples/river2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/river2.png


--------------------------------------------------------------------------------
/demo/seem/examples/vasedeck.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/vasedeck.mp4


--------------------------------------------------------------------------------
/demo/seem/examples/zebras1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/zebras1.jpg


--------------------------------------------------------------------------------
/demo/seem/examples/zebras2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/demo/seem/examples/zebras2.jpg


--------------------------------------------------------------------------------
/demo/seem/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .interactive import interactive_infer_video, interactive_infer_image


--------------------------------------------------------------------------------
/entry.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import os
 9 | import sys
10 | import torch
11 | import logging
12 | import wandb
13 | 
14 | from utils.arguments import load_opt_command
15 | 
16 | logging.basicConfig(level=logging.INFO)
17 | logger = logging.getLogger(__name__)
18 | 
19 | def init_wandb(args, job_dir, entity='xueyanz', project='xdecoder', job_name='tmp'):
20 |     wandb_dir = os.path.join(job_dir, 'wandb')
21 |     os.makedirs(wandb_dir, exist_ok=True)
22 |     runid = None
23 |     if os.path.exists(f"{wandb_dir}/runid.txt"):
24 |         runid = open(f"{wandb_dir}/runid.txt").read()
25 | 
26 |     wandb.init(project=project,
27 |             name=job_name,
28 |             dir=wandb_dir,
29 |             entity=entity,
30 |             resume="allow",
31 |             id=runid,
32 |             config={"hierarchical": True},)
33 | 
34 |     open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id)
35 |     wandb.config.update({k: args[k] for k in args if k not in wandb.config})
36 | 
37 | def main(args=None):
38 |     '''
39 |     [Main function for the entry point]
40 |     1. Set environment variables for distributed training.
41 |     2. Load the config file and set up the trainer.
42 |     '''
43 | 
44 |     opt, cmdline_args = load_opt_command(args)
45 |     command = cmdline_args.command
46 | 
47 |     if cmdline_args.user_dir:
48 |         absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
49 |         opt['base_path'] = absolute_user_dir
50 | 
51 |     # update_opt(opt, command)
52 |     world_size = 1
53 |     if 'OMPI_COMM_WORLD_SIZE' in os.environ:
54 |         world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
55 | 
56 |     if opt['TRAINER'] == 'xdecoder':
57 |         from trainer import XDecoder_Trainer as Trainer
58 |     else:
59 |         assert False, "The trainer type: {} is not defined!".format(opt['TRAINER'])
60 |     
61 |     trainer = Trainer(opt)
62 |     os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL'
63 | 
64 |     if command == "train":
65 |         if opt['rank'] == 0 and opt['WANDB']:
66 |             wandb.login(key=os.environ['WANDB_KEY'])
67 |             init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder)
68 |         trainer.train()
69 |     elif command == "evaluate":
70 |         trainer.eval()
71 |     else:
72 |         raise ValueError(f"Unknown command: {command}")
73 | 
74 | if __name__ == "__main__":
75 |     main()
76 |     sys.exit(0)
77 | 


--------------------------------------------------------------------------------
/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/__init__.py


--------------------------------------------------------------------------------
/inference/images/animals.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/animals.png


--------------------------------------------------------------------------------
/inference/images/apples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/apples.jpg


--------------------------------------------------------------------------------
/inference/images/coco/000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/coco/000.jpg


--------------------------------------------------------------------------------
/inference/images/coco/001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/coco/001.jpg


--------------------------------------------------------------------------------
/inference/images/coco/002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/coco/002.jpg


--------------------------------------------------------------------------------
/inference/images/coco/003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/coco/003.jpg


--------------------------------------------------------------------------------
/inference/images/fruit.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/fruit.jpg


--------------------------------------------------------------------------------
/inference/images/landscape.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/landscape.jpg


--------------------------------------------------------------------------------
/inference/images/mountain.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/mountain.jpeg


--------------------------------------------------------------------------------
/inference/images/owls.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/owls.jpeg


--------------------------------------------------------------------------------
/inference/images/penguin.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/penguin.jpeg


--------------------------------------------------------------------------------
/inference/images/region_retrieval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/region_retrieval.png


--------------------------------------------------------------------------------
/inference/images/rose.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/rose.webp


--------------------------------------------------------------------------------
/inference/images/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/street.jpg


--------------------------------------------------------------------------------
/inference/images/teaser_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/inference/images/teaser_new.png


--------------------------------------------------------------------------------
/inference/xdecoder/infer_captioning.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import os
 9 | import sys
10 | import logging
11 | 
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 | 
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(0)
18 | import cv2
19 | 
20 | import torch
21 | from torchvision import transforms
22 | 
23 | from utils.arguments import load_opt_command
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.structures import BitMasks
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from detectron2.utils.colormap import random_color
29 | from utils.visualizer import Visualizer
30 | from utils.distributed import init_distributed
31 | 
32 | logger = logging.getLogger(__name__)
33 | 
34 | 
35 | def main(args=None):
36 |     '''
37 |     Entry point for X-Decoder image captioning inference.
38 |     '''
39 |     
40 |     opt, cmdline_args = load_opt_command(args)
41 |     if cmdline_args.user_dir:
42 |         absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
43 |         opt['base_path'] = absolute_user_dir
44 |     opt = init_distributed(opt)
45 | 
46 |     # META DATA
47 |     pretrained_pth = os.path.join(opt['RESUME_FROM'])
48 |     if 'novg' not in pretrained_pth:
49 |         assert False, "Please use a 'novg' checkpoint: the model trained without Visual Genome data produces much better captions."
50 |     output_root = './output'
51 |     image_pth = 'inference/images/mountain.jpeg'
52 | 
53 |     model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
54 |     model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background"], is_eval=False)
55 | 
56 |     t = []
57 |     t.append(transforms.Resize(224, interpolation=Image.BICUBIC))
58 |     transform = transforms.Compose(t)
59 | 
60 |     with torch.no_grad():
61 |         image_ori = Image.open(image_pth).convert("RGB")
62 |         width = image_ori.size[0]
63 |         height = image_ori.size[1]
64 |         image = transform(image_ori)
65 |         image = np.asarray(image)
66 |         image_ori = np.asarray(image_ori)
67 |         images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
68 | 
69 |         batch_inputs = [{'image': images, 'height': height, 'width': width, 'image_id': 0}]
70 |         outputs = model.model.evaluate_captioning(batch_inputs)
71 |         text = outputs[-1]['captioning_text']
72 | 
73 |         image_ori = image_ori[:,:,::-1].copy()
74 |         cv2.rectangle(image_ori, (0, 0), (width, 60), (0,0,0), -1)
75 |         font                   = cv2.FONT_HERSHEY_DUPLEX
76 |         fontScale              = 1.2
77 |         thickness              = 2
78 |         lineType               = 2
79 |         bottomLeftCornerOfText = (10, 40)
80 |         fontColor              = [255,255,255]
81 |         cv2.putText(image_ori, text,
82 |             bottomLeftCornerOfText,
83 |             font, 
84 |             fontScale,
85 |             fontColor,
86 |             thickness,
87 |             lineType)
88 | 
89 |         if not os.path.exists(output_root):
90 |             os.makedirs(output_root)
91 |         cv2.imwrite(os.path.join(output_root, 'captioning.png'), image_ori)
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     main()
96 |     sys.exit(0)


--------------------------------------------------------------------------------
/inference/xdecoder/infer_instseg.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import os
 9 | import sys
10 | import logging
11 | 
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 | 
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(2)
18 | 
19 | import torch
20 | from torchvision import transforms
21 | 
22 | from utils.arguments import load_opt_command
23 | 
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.structures import BitMasks
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from detectron2.utils.colormap import random_color
29 | from utils.visualizer import Visualizer
30 | from utils.distributed import init_distributed
31 | 
32 | logger = logging.getLogger(__name__)
33 | 
34 | 
35 | def main(args=None):
36 |     '''
37 |     Entry point for X-Decoder open-vocabulary instance segmentation inference.
38 |     '''
39 |     opt, cmdline_args = load_opt_command(args)
40 |     if cmdline_args.user_dir:
41 |         absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
42 |         opt['base_path'] = absolute_user_dir
43 | 
44 |     opt = init_distributed(opt)
45 | 
46 |     # META DATA
47 |     pretrained_pth = os.path.join(opt['RESUME_FROM'])
48 |     output_root = './output'
49 |     image_pth = 'inference/images/owls.jpeg'
50 | 
51 |     model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
52 | 
53 |     t = []
54 |     t.append(transforms.Resize(800, interpolation=Image.BICUBIC))
55 |     transform = transforms.Compose(t)
56 | 
57 |     thing_classes = ["owl"]
58 |     thing_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(thing_classes))]
59 |     thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))}
60 | 
61 |     MetadataCatalog.get("demo").set(
62 |         thing_colors=thing_colors,
63 |         thing_classes=thing_classes,
64 |         thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
65 |     )
66 |     model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=False)
67 |     metadata = MetadataCatalog.get('demo')
68 |     model.model.metadata = metadata
69 |     model.model.sem_seg_head.num_classes = len(thing_classes)
70 | 
71 |     with torch.no_grad():
72 |         image_ori = Image.open(image_pth).convert('RGB')
73 |         width = image_ori.size[0]
74 |         height = image_ori.size[1]
75 |         image = transform(image_ori)
76 |         image = np.asarray(image)
77 |         image_ori = np.asarray(image_ori)
78 |         images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
79 | 
80 |         batch_inputs = [{'image': images, 'height': height, 'width': width}]
81 |         outputs = model.forward(batch_inputs)
82 |         visual = Visualizer(image_ori, metadata=metadata)
83 | 
84 |         inst_seg = outputs[-1]['instances']
85 |         inst_seg.pred_masks = inst_seg.pred_masks.cpu()
86 |         inst_seg.pred_boxes = BitMasks(inst_seg.pred_masks > 0).get_bounding_boxes()
87 |         demo = visual.draw_instance_predictions(inst_seg) # rgb Image
88 | 
89 |         if not os.path.exists(output_root):
90 |             os.makedirs(output_root)
91 |         demo.save(os.path.join(output_root, 'inst.png'))
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     main()
96 |     sys.exit(0)


--------------------------------------------------------------------------------
/inference/xdecoder/infer_panoseg.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | 
  8 | import os
  9 | import sys
 10 | import logging
 11 | 
 12 | pth = '/'.join(sys.path[0].split('/')[:-1])
 13 | sys.path.insert(0, pth)
 14 | 
 15 | from PIL import Image
 16 | import numpy as np
 17 | np.random.seed(1)
 18 | 
 19 | import torch
 20 | from torchvision import transforms
 21 | 
 22 | from utils.arguments import load_opt_command
 23 | 
 24 | from detectron2.data import MetadataCatalog
 25 | from detectron2.utils.colormap import random_color
 26 | from modeling.BaseModel import BaseModel
 27 | from modeling import build_model
 28 | from utils.visualizer import Visualizer
 29 | from utils.distributed import init_distributed
 30 | 
 31 | logger = logging.getLogger(__name__)
 32 | 
 33 | 
 34 | def main(args=None):
 35 |     '''
 36 |     Entry point for X-Decoder open-vocabulary panoptic segmentation inference.
 37 |     '''
 38 |     opt, cmdline_args = load_opt_command(args)
 39 |     if cmdline_args.user_dir:
 40 |         absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
 41 |         opt['base_path'] = absolute_user_dir
 42 | 
 43 |     opt = init_distributed(opt)
 44 | 
 45 |     # META DATA
 46 |     pretrained_pth = os.path.join(opt['RESUME_FROM'])
 47 |     output_root = './output'
 48 |     image_pth = 'inference/images/street.jpg'
 49 | 
 50 |     model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
 51 | 
 52 |     t = []
 53 |     t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
 54 |     transform = transforms.Compose(t)
 55 | 
 56 |     thing_classes = ['car','person','traffic light', 'truck', 'motorcycle']
 57 |     stuff_classes = ['building','sky','street','tree','rock','sidewalk']
 58 |     thing_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(thing_classes))]
 59 |     stuff_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(stuff_classes))]
 60 |     thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))}
 61 |     stuff_dataset_id_to_contiguous_id = {x+len(thing_classes):x for x in range(len(stuff_classes))}
 62 | 
 63 |     MetadataCatalog.get("demo").set(
 64 |         thing_colors=thing_colors,
 65 |         thing_classes=thing_classes,
 66 |         thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
 67 |         stuff_colors=stuff_colors,
 68 |         stuff_classes=stuff_classes,
 69 |         stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id,
 70 |     )
 71 |     model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + stuff_classes + ["background"], is_eval=False)
 72 |     metadata = MetadataCatalog.get('demo')
 73 |     model.model.metadata = metadata
 74 |     model.model.sem_seg_head.num_classes = len(thing_classes + stuff_classes)
 75 | 
 76 |     with torch.no_grad():
 77 |         image_ori = Image.open(image_pth).convert("RGB")
 78 |         width = image_ori.size[0]
 79 |         height = image_ori.size[1]
 80 |         image = transform(image_ori)
 81 |         image = np.asarray(image)
 82 |         image_ori = np.asarray(image_ori)
 83 |         images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
 84 | 
 85 |         batch_inputs = [{'image': images, 'height': height, 'width': width}]
 86 |         outputs = model.forward(batch_inputs)
 87 |         visual = Visualizer(image_ori, metadata=metadata)
 88 | 
 89 |         pano_seg = outputs[-1]['panoptic_seg'][0]
 90 |         pano_seg_info = outputs[-1]['panoptic_seg'][1]
 91 | 
 92 |         for i in range(len(pano_seg_info)):
 93 |             if pano_seg_info[i]['category_id'] in metadata.thing_dataset_id_to_contiguous_id.keys():
 94 |                 pano_seg_info[i]['category_id'] = metadata.thing_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']]
 95 |             else:
 96 |                 pano_seg_info[i]['isthing'] = False
 97 |                 pano_seg_info[i]['category_id'] = metadata.stuff_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']]
 98 | 
 99 |         demo = visual.draw_panoptic_seg(pano_seg.cpu(), pano_seg_info) # rgb Image
100 | 
101 |         if not os.path.exists(output_root):
102 |             os.makedirs(output_root)
103 |         demo.save(os.path.join(output_root, 'pano.png'))
104 | 
105 | 
106 | if __name__ == "__main__":
107 |     main()
108 |     sys.exit(0)


--------------------------------------------------------------------------------
/inference/xdecoder/infer_refseg.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import os
 9 | import sys
10 | import json
11 | import logging
12 | 
13 | pth = '/'.join(sys.path[0].split('/')[:-1])
14 | sys.path.insert(0, pth)
15 | 
16 | from PIL import Image
17 | import numpy as np
18 | np.random.seed(27)
19 | 
20 | import torch
21 | from torchvision import transforms
22 | 
23 | from utils.arguments import load_opt_command
24 | 
25 | from detectron2.data import MetadataCatalog
26 | from detectron2.utils.colormap import random_color
27 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
28 | from modeling.BaseModel import BaseModel
29 | from modeling import build_model
30 | from utils.visualizer import Visualizer
31 | from utils.distributed import init_distributed
32 | 
33 | # logging.basicConfig(level = logging.INFO)
34 | logger = logging.getLogger(__name__)
35 | 
36 | 
37 | def main(args=None):
38 |     '''
39 |     Entry point for X-Decoder referring segmentation inference.
40 |     '''
41 |     opt, cmdline_args = load_opt_command(args)
42 |     if cmdline_args.user_dir:
43 |         absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
44 |         opt['base_path'] = absolute_user_dir
45 | 
46 |     opt = init_distributed(opt)
47 | 
48 |     # META DATA
49 |     pretrained_pth = os.path.join(opt['RESUME_FROM'])
50 |     output_root = './output'
51 |     image_pth = 'inference/images/fruit.jpg'
52 | 
53 |     text = [['The larger watermelon.'], ['The front white flower.'], ['White tea pot.'], ['Flower bunch.'], ['white vase.'], ['The left peach.'], ['The brown knife.']]
54 | 
55 |     model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
56 |     model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=False)
57 | 
58 |     t = []
59 |     t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
60 |     transform = transforms.Compose(t)
61 | 
62 |     metadata = MetadataCatalog.get('ade20k_panoptic_train')
63 |     model.model.metadata = metadata
64 | 
65 |     with torch.no_grad():
66 |         image_ori = Image.open(image_pth)
67 |         width = image_ori.size[0]
68 |         height = image_ori.size[1]
69 |         image = transform(image_ori)
70 |         image = np.asarray(image)
71 |         image_ori = np.asarray(image_ori)
72 |         images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
73 | 
74 |         batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': text}}]
75 |         outputs = model.model.evaluate_grounding(batch_inputs, None)
76 |         visual = Visualizer(image_ori, metadata=metadata)
77 | 
78 |         grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
79 |         for idx, mask in enumerate(grd_mask):
80 |             demo = visual.draw_binary_mask(mask, color=random_color(rgb=True, maximum=1).astype(int).tolist(), text=text[idx], alpha=0.3)
81 | 
82 |         output_folder = output_root
83 |         if not os.path.exists(output_folder):
84 |             os.makedirs(output_folder)
85 |         demo.save(os.path.join(output_folder, 'refseg.png'))
86 | 
87 | 
88 | if __name__ == "__main__":
89 |     main()
90 |     sys.exit(0)


--------------------------------------------------------------------------------
/inference/xdecoder/infer_semseg.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | 
 8 | import os
 9 | import sys
10 | import logging
11 | 
12 | pth = '/'.join(sys.path[0].split('/')[:-1])
13 | sys.path.insert(0, pth)
14 | 
15 | from PIL import Image
16 | import numpy as np
17 | np.random.seed(1)
18 | 
19 | import torch
20 | from torchvision import transforms
21 | 
22 | from utils.arguments import load_opt_command
23 | 
24 | from detectron2.data import MetadataCatalog
25 | from detectron2.utils.colormap import random_color
26 | from modeling.BaseModel import BaseModel
27 | from modeling import build_model
28 | from utils.visualizer import Visualizer
29 | from utils.distributed import init_distributed
30 | 
31 | logger = logging.getLogger(__name__)
32 | 
33 | 
34 | def main(args=None):
35 |     '''
36 |     Entry point for X-Decoder open-vocabulary semantic segmentation inference.
37 |     '''
38 |     opt, cmdline_args = load_opt_command(args)
39 |     if cmdline_args.user_dir:
40 |         absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
41 |         opt['base_path'] = absolute_user_dir
42 |     opt = init_distributed(opt)
43 | 
44 |     # META DATA
45 |     pretrained_pth = os.path.join(opt['RESUME_FROM'])
46 |     output_root = './output'
47 |     image_pth = 'inference/images/animals.png'
48 | 
49 |     model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
50 | 
51 |     t = []
52 |     t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
53 |     transform = transforms.Compose(t)
54 | 
55 |     stuff_classes = ['zebra','antelope','giraffe','ostrich','sky','water','grass','sand','tree']
56 |     stuff_colors = [random_color(rgb=True, maximum=255).astype(int).tolist() for _ in range(len(stuff_classes))]
57 |     stuff_dataset_id_to_contiguous_id = {x:x for x in range(len(stuff_classes))}
58 | 
59 |     MetadataCatalog.get("demo").set(
60 |         stuff_colors=stuff_colors,
61 |         stuff_classes=stuff_classes,
62 |         stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id,
63 |     )
64 |     model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(stuff_classes + ["background"], is_eval=True)
65 |     metadata = MetadataCatalog.get('demo')
66 |     model.model.metadata = metadata
67 |     model.model.sem_seg_head.num_classes = len(stuff_classes)
68 | 
69 |     with torch.no_grad():
70 |         image_ori = Image.open(image_pth).convert("RGB")
71 |         width = image_ori.size[0]
72 |         height = image_ori.size[1]
73 |         image = transform(image_ori)
74 |         image = np.asarray(image)
75 |         image_ori = np.asarray(image_ori)
76 |         images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
77 | 
78 |         batch_inputs = [{'image': images, 'height': height, 'width': width}]
79 |         outputs = model.forward(batch_inputs)
80 |         visual = Visualizer(image_ori, metadata=metadata)
81 | 
82 |         sem_seg = outputs[-1]['sem_seg'].max(0)[1]
83 |         demo = visual.draw_sem_seg(sem_seg.cpu(), alpha=0.5) # rgb Image
84 | 
85 |         if not os.path.exists(output_root):
86 |             os.makedirs(output_root)
87 |         demo.save(os.path.join(output_root, 'sem.png'))
88 | 
89 | 
90 | if __name__ == "__main__":
91 |     main()
92 |     sys.exit(0)


--------------------------------------------------------------------------------
/modeling/BaseModel.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import logging
 3 | 
 4 | import torch
 5 | import torch.nn as nn
 6 | 
 7 | from utils.model import align_and_update_state_dicts
 8 | 
 9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class BaseModel(nn.Module):
13 |     def __init__(self, opt, module: nn.Module):
14 |         super(BaseModel, self).__init__()
15 |         self.opt = opt
16 |         self.model = module
17 | 
18 |     def forward(self, *inputs, **kwargs):
19 |         outputs = self.model(*inputs, **kwargs)
20 |         return outputs
21 | 
22 |     def save_pretrained(self, save_dir):
23 |         torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))
24 | 
25 |     def from_pretrained(self, load_dir):
26 |         state_dict = torch.load(load_dir, map_location=self.opt['device'])
27 |         state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
28 |         self.model.load_state_dict(state_dict, strict=False)
29 |         return self
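
The inference scripts below use this wrapper in exactly this way; a condensed sketch, wrapped in a helper so it stays self-contained (`opt` stands in for a parsed config dict and `ckpt_path` for a real checkpoint file):

from modeling.BaseModel import BaseModel
from modeling import build_model

def load_for_inference(opt, ckpt_path):
    """Build the architecture from `opt` and load `ckpt_path` (sketch).

    `opt` must provide the MODEL.* keys consumed by build_model plus 'device',
    exactly as in the inference scripts further down.
    """
    return BaseModel(opt, build_model(opt)).from_pretrained(ckpt_path).eval().cuda()

Saving mirrors loading: save_pretrained writes the wrapped module's weights to <save_dir>/model_state_dict.pt.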


--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .architectures import build_model


--------------------------------------------------------------------------------
/modeling/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | from .xdecoder_model import *
2 | from .seem_model_v0 import *
3 | from .seem_model_v1 import *
4 | from .seem_model_demo import *
5 | from .build import build_model


--------------------------------------------------------------------------------
/modeling/architectures/build.py:
--------------------------------------------------------------------------------
 1 | _model_entrypoints = {}
 2 | 
 3 | 
 4 | def build_model(config, **kwargs):
 5 |     model_name = config['MODEL']['NAME']
 6 | 
 7 |     if not is_model(model_name):
 8 |         raise ValueError(f'Unknown model: {model_name}')
 9 | 
10 |     return model_entrypoints(model_name)(config, **kwargs)
11 | 
12 | def register_model(fn):
13 |     module_name_split = fn.__module__.split('.')
14 |     model_name = module_name_split[-1]
15 |     _model_entrypoints[model_name] = fn
16 |     return fn
17 | 
18 | def model_entrypoints(model_name):
19 |     return _model_entrypoints[model_name]
20 | 
21 | def is_model(model_name):
22 |     return model_name in _model_entrypoints
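
The registry key is the defining module's file name, so adding an architecture means decorating a factory in its own module and importing that module in architectures/__init__.py; a minimal sketch with a hypothetical my_arch.py:

# modeling/architectures/my_arch.py (hypothetical)
from .build import register_model

@register_model
def get_my_arch(cfg, **kwargs):
    # Registered under the key 'my_arch' (the module name), so a config with
    # MODEL.NAME == 'my_arch' resolves here through build_model(). The module
    # must also be imported in architectures/__init__.py for the decorator to run.
    ...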


--------------------------------------------------------------------------------
/modeling/body/__init__.py:
--------------------------------------------------------------------------------
 1 | from .xdecoder_head import *
 2 | from .build import *
 3 | 
 4 | def build_xdecoder_head(config, *args, **kwargs):
 5 |     model_name = config['MODEL']['HEAD']
 6 |     if not is_model(model_name):
 7 |         raise ValueError(f'Unknown model: {model_name}')
 8 | 
 9 |     body = model_entrypoints(model_name)(config, *args, **kwargs)
10 |     return body


--------------------------------------------------------------------------------
/modeling/body/build.py:
--------------------------------------------------------------------------------
 1 | _model_entrypoints = {}
 2 | 
 3 | def register_body(fn):
 4 |     module_name_split = fn.__module__.split('.')
 5 |     model_name = module_name_split[-1]
 6 |     _model_entrypoints[model_name] = fn
 7 |     return fn
 8 | 
 9 | def model_entrypoints(model_name):
10 |     return _model_entrypoints[model_name]
11 | 
12 | def is_model(model_name):
13 |     return model_name in _model_entrypoints


--------------------------------------------------------------------------------
/modeling/body/xdecoder_head.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | # Copyright (c) Facebook, Inc. and its affiliates.
  8 | from typing import Dict
  9 | 
 10 | from torch import nn
 11 | 
 12 | from detectron2.layers import ShapeSpec
 13 | 
 14 | from .build import register_body
 15 | from ..vision.encoder import build_encoder
 16 | from ..interface import build_decoder
 17 | from ..utils import configurable
 18 | 
 19 | 
 20 | class XdecoderHead(nn.Module):
 21 | 
 22 |     @configurable
 23 |     def __init__(
 24 |         self,
 25 |         input_shape: Dict[str, ShapeSpec],
 26 |         *,
 27 |         num_classes: int,
 28 |         pixel_decoder: nn.Module,
 29 |         loss_weight: float = 1.0,
 30 |         ignore_value: int = -1,
 31 |         # extra parameters
 32 |         transformer_predictor: nn.Module,
 33 |         transformer_in_feature: str,
 34 |     ):
 35 |         """
 36 |         NOTE: this interface is experimental.
 37 |         Args:
 38 |             input_shape: shapes (channels and stride) of the input features
 39 |             num_classes: number of classes to predict
 40 |             pixel_decoder: the pixel decoder module
 41 |             loss_weight: loss weight
 42 |             ignore_value: category id to be ignored during training.
 43 |             transformer_predictor: the transformer decoder that makes prediction
 44 |             transformer_in_feature: input feature name to the transformer_predictor
 45 |         """
 46 |         super().__init__()
 47 | 
 48 |         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
 49 |         self.in_features = [k for k, v in input_shape]
 50 |         feature_strides = [v.stride for k, v in input_shape]
 51 |         feature_channels = [v.channels for k, v in input_shape]
 52 | 
 53 |         self.ignore_value = ignore_value
 54 |         self.common_stride = 4
 55 |         self.loss_weight = loss_weight
 56 | 
 57 |         self.pixel_decoder = pixel_decoder
 58 |         self.predictor = transformer_predictor
 59 |         self.transformer_in_feature = transformer_in_feature
 60 | 
 61 |         self.num_classes = num_classes
 62 | 
 63 |     @classmethod
 64 |     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
 65 | 
 66 |         in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
 67 |         enc_cfg = cfg['MODEL']['ENCODER']
 68 |         dec_cfg = cfg['MODEL']['DECODER']
 69 | 
 70 |         # figure out in_channels to transformer predictor
 71 |         if in_features_type == "transformer_encoder":
 72 |             transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
 73 |         elif in_features_type == "pixel_embedding":
 74 |             transformer_predictor_in_channels = enc_cfg['MASK_DIM']
 75 |         elif in_features_type == "multi_scale_pixel_decoder":
 76 |             transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
 77 |         else:
 78 |             transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels
 79 | 
 80 |         return {
 81 |             "input_shape": {
 82 |                 k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
 83 |             },
 84 |             "ignore_value": enc_cfg['IGNORE_VALUE'],
 85 |             "num_classes": enc_cfg.get('NUM_CLASSES', None),
 86 |             "pixel_decoder": build_encoder(cfg, input_shape),
 87 |             "loss_weight": enc_cfg['LOSS_WEIGHT'],
 88 |             "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
 89 |             "transformer_predictor": build_decoder(
 90 |                 cfg,
 91 |                 transformer_predictor_in_channels,
 92 |                 lang_encoder,
 93 |                 mask_classification=True,
 94 |                 extra=extra,
 95 |             ),
 96 |         }
 97 | 
 98 |     def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
 99 |         return self.layers(features, mask, target_queries, target_vlp, task, extra)
100 | 
101 |     def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
102 |         mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
103 |         
104 |         if self.transformer_in_feature == "multi_scale_pixel_decoder":
105 |             predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
106 |         else:
107 |             if self.transformer_in_feature == "transformer_encoder":
108 |                 assert (
109 |                     transformer_encoder_features is not None
110 |                 ), "Please use the TransformerEncoderPixelDecoder."
111 |                 predictions = self.predictor(transformer_encoder_features, mask_features, mask)
112 |             elif self.transformer_in_feature == "pixel_embedding":
113 |                 predictions = self.predictor(mask_features, mask_features, mask)
114 |             else:
115 |                 predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
116 |         return predictions
117 | 
118 | 
119 | @register_body
120 | def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
121 |     return XdecoderHead(cfg, input_shape, lang_encoder, extra)
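
For orientation, the config keys that from_config reads can be sketched as the fragment below; the feature names and numbers are placeholders only (see the files under configs/ for the real settings):

# Illustrative fragment of the keys XdecoderHead.from_config consumes.
cfg_fragment = {
    'MODEL': {
        'ENCODER': {
            'IN_FEATURES': ['res2', 'res3', 'res4', 'res5'],  # placeholder feature names
            'CONVS_DIM': 512,
            'MASK_DIM': 512,
            'IGNORE_VALUE': 255,
            'NUM_CLASSES': 133,
            'LOSS_WEIGHT': 1.0,
        },
        'DECODER': {
            'TRANSFORMER_IN_FEATURE': 'multi_scale_pixel_decoder',
            # ...plus whatever keys build_decoder needs for the chosen decoder NAME
        },
    },
}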


--------------------------------------------------------------------------------
/modeling/interface/__init__.py:
--------------------------------------------------------------------------------
 1 | from .xdecoder import *
 2 | from .seem_v0 import *
 3 | from .seem_v1 import *
 4 | from .seem_demo import *
 5 | from .build import *
 6 | 
 7 | def build_decoder(config, *args, **kwargs):
 8 |     model_name = config['MODEL']['DECODER']['NAME']
 9 | 
10 |     if not is_model(model_name):
11 |         raise ValueError(f'Unknown model: {model_name}')
12 | 
13 |     return model_entrypoints(model_name)(config, *args, **kwargs)


--------------------------------------------------------------------------------
/modeling/interface/build.py:
--------------------------------------------------------------------------------
 1 | _model_entrypoints = {}
 2 | 
 3 | 
 4 | def register_decoder(fn):
 5 |     module_name_split = fn.__module__.split('.')
 6 |     model_name = module_name_split[-1]
 7 |     _model_entrypoints[model_name] = fn
 8 |     return fn
 9 | 
10 | def model_entrypoints(model_name):
11 |     return _model_entrypoints[model_name]
12 | 
13 | def is_model(model_name):
14 |     return model_name in _model_entrypoints


--------------------------------------------------------------------------------
/modeling/interface/prototype/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/modeling/interface/prototype/__init__.py


--------------------------------------------------------------------------------
/modeling/language/LangEncoder/__init__.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from transformers import CLIPTokenizer, CLIPTokenizerFast, AutoTokenizer
 3 | 
 4 | from .transformer import *
 5 | from .build import *
 6 | 
 7 | 
 8 | def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
 9 |     model_name = config_encoder['NAME']
10 | 
11 |     if not is_lang_encoder(model_name):
12 |         raise ValueError(f'Unknown model: {model_name}')
13 | 
14 |     return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
15 | 
16 | def build_tokenizer(config_encoder):
17 |     tokenizer = None
18 |     os.environ['TOKENIZERS_PARALLELISM'] = 'true'
19 |     if config_encoder['TOKENIZER'] == 'clip':
20 |         pretrained_tokenizer = config_encoder.get(
21 |             'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
22 |         )
23 |         tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
24 |         tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
25 |     elif config_encoder['TOKENIZER'] == 'clip-fast':
26 |         pretrained_tokenizer = config_encoder.get(
27 |             'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
28 |         )
29 |         tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
30 |     else:
31 |         tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
32 | 
33 |     return tokenizer
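
For example, building the default CLIP tokenizer only needs the TOKENIZER key; a minimal sketch in which the plain dict stands in for the real config node:

config_encoder = {'TOKENIZER': 'clip'}  # PRETRAINED_TOKENIZER defaults to openai/clip-vit-base-patch32
tokenizer = build_tokenizer(config_encoder)
print(tokenizer.cls_token)              # the EOS token registered as the CLS token above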


--------------------------------------------------------------------------------
/modeling/language/LangEncoder/build.py:
--------------------------------------------------------------------------------
 1 | _lang_encoders = {}
 2 | 
 3 | 
 4 | def register_lang_encoder(fn):
 5 |     module_name_split = fn.__module__.split('.')
 6 |     model_name = module_name_split[-1]
 7 | 
 8 |     _lang_encoders[model_name] = fn
 9 | 
10 |     return fn
11 | 
12 | def lang_encoders(model_name):
13 |     return _lang_encoders[model_name]
14 | 
15 | def is_lang_encoder(model_name):
16 |     return model_name in _lang_encoders
17 | 


--------------------------------------------------------------------------------
/modeling/language/__init__.py:
--------------------------------------------------------------------------------
 1 | from .vlpencoder import *
 2 | from .build import *
 3 | 
 4 | def build_language_encoder(config, **kwargs):
 5 |     model_name = config['MODEL']['TEXT']['ARCH']
 6 | 
 7 |     if not is_model(model_name):
 8 |         raise ValueError(f'Unknown model: {model_name}')
 9 | 
10 |     return model_entrypoints(model_name)(config, **kwargs)


--------------------------------------------------------------------------------
/modeling/language/build.py:
--------------------------------------------------------------------------------
 1 | _model_entrypoints = {}
 2 | 
 3 | 
 4 | def register_model(fn):
 5 |     module_name_split = fn.__module__.split('.')
 6 |     model_name = module_name_split[-1]
 7 |     _model_entrypoints[model_name] = fn
 8 |     return fn
 9 | 
10 | def model_entrypoints(model_name):
11 |     return _model_entrypoints[model_name]
12 | 
13 | def is_model(model_name):
14 |     return model_name in _model_entrypoints


--------------------------------------------------------------------------------
/modeling/language/misc.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | 
 3 | import torch
 4 | import nltk
 5 | import numpy as np
 6 | 
 7 | from utils.constants import IMAGENET_DEFAULT_TEMPLATES
 8 | 
 9 | nltk.download('punkt', quiet=True)
10 | nltk.download('averaged_perceptron_tagger', quiet=True)
11 | 
12 | def get_tag(tokenized, tags):
13 |     if not isinstance(tags, (list, tuple)):
14 |         tags = [tags]
15 |     ret = []
16 |     for (word, pos) in nltk.pos_tag(tokenized):
17 |         for tag in tags:
18 |             if pos == tag:
19 |                 ret.append(word)
20 |     return ret
21 | 
22 | def get_noun_phrase(tokenized):
23 |     # Taken from Su Nam Kim Paper...
24 |     grammar = r"""
25 |         NBAR:
26 |             {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
27 | 
28 |         NP:
29 |             {<NBAR>}
30 |             {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
31 |     """
32 |     chunker = nltk.RegexpParser(grammar)
33 | 
34 |     chunked = chunker.parse(nltk.pos_tag(tokenized))
35 |     continuous_chunk = []
36 |     current_chunk = []
37 | 
38 |     for subtree in chunked:
39 |         if isinstance(subtree, nltk.Tree):
40 |             current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
41 |         elif current_chunk:
42 |             named_entity = ' '.join(current_chunk)
43 |             if named_entity not in continuous_chunk:
44 |                 continuous_chunk.append(named_entity)
45 |                 current_chunk = []
46 |         else:
47 |             continue
48 | 
49 |     return continuous_chunk
50 | 
51 | def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
52 |     tokenized = nltk.word_tokenize(text)
53 |     
54 |     if random.random() >= phrase_prob:
55 |         nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
56 |     else:
57 |         nouns = get_noun_phrase(tokenized)
58 | 
59 | 
60 |     prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
61 |     
62 |     if append_text:
63 |         prompt_texts += [text]
64 |         nouns += [text]
65 |     
66 |     return prompt_texts, nouns
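
A small usage sketch of text_noun_with_prompt_all, assuming the repo root is on PYTHONPATH so that utils.constants (which defines IMAGENET_DEFAULT_TEMPLATES, not shown here) and the nltk data are available; the printed values are illustrative:

from modeling.language.misc import text_noun_with_prompt_all

caption = "a black cat sleeping on a wooden chair"

# phrase_prob=0.0 -> always extract single nouns (NN/NNS/NNP) rather than noun phrases
prompts, nouns = text_noun_with_prompt_all(caption, phrase_prob=0.0, append_text=True)

print(nouns)    # e.g. ['cat', 'chair', 'a black cat sleeping on a wooden chair']
print(prompts)  # each noun wrapped in a random IMAGENET_DEFAULT_TEMPLATES prompt, plus the caption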


--------------------------------------------------------------------------------
/modeling/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .point_features import *
2 | from .position_encoding import *
3 | from .postprocessing import *
4 | from .attention import *
5 | from .criterion import *
6 | from .matcher import *


--------------------------------------------------------------------------------
/modeling/modules/position_encoding.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
 3 | """
 4 | Various positional encodings for the transformer.
 5 | """
 6 | import math
 7 | 
 8 | import torch
 9 | from torch import nn
10 | 
11 | 
12 | class PositionEmbeddingSine(nn.Module):
13 |     """
14 |     This is a more standard version of the position embedding, very similar to the one
15 |     used by the Attention is all you need paper, generalized to work on images.
16 |     """
17 | 
18 |     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19 |         super().__init__()
20 |         self.num_pos_feats = num_pos_feats
21 |         self.temperature = temperature
22 |         self.normalize = normalize
23 |         if scale is not None and normalize is False:
24 |             raise ValueError("normalize should be True if scale is passed")
25 |         if scale is None:
26 |             scale = 2 * math.pi
27 |         self.scale = scale
28 | 
29 |     def forward(self, x, mask=None):
30 |         if mask is None:
31 |             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
32 |         not_mask = ~mask
33 |         y_embed = not_mask.cumsum(1, dtype=x.dtype)
34 |         x_embed = not_mask.cumsum(2, dtype=x.dtype)
35 |         if self.normalize:
36 |             eps = 1e-6
37 |             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38 |             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39 | 
40 |         dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
41 |         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42 | 
43 |         pos_x = x_embed[:, :, :, None] / dim_t
44 |         pos_y = y_embed[:, :, :, None] / dim_t
45 |         pos_x = torch.stack(
46 |             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
47 |         ).flatten(3)
48 |         pos_y = torch.stack(
49 |             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
50 |         ).flatten(3)
51 |         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52 |         return pos
53 |     
54 |     def __repr__(self, _repr_indent=4):
55 |         head = "Positional encoding " + self.__class__.__name__
56 |         body = [
57 |             "num_pos_feats: {}".format(self.num_pos_feats),
58 |             "temperature: {}".format(self.temperature),
59 |             "normalize: {}".format(self.normalize),
60 |             "scale: {}".format(self.scale),
61 |         ]
62 |         # _repr_indent = 4
63 |         lines = [head] + [" " * _repr_indent + line for line in body]
64 |         return "\n".join(lines)
65 | 
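
A quick shape check for PositionEmbeddingSine, assuming the module is importable from the repo root; the channel dimension of the returned encoding is 2 * num_pos_feats (interleaved sine/cosine for x and y):

import torch
from modeling.modules.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feat = torch.randn(2, 256, 32, 32)   # N x C x H x W feature map

pos = pe(feat)        # mask=None -> no padding assumed
print(pos.shape)      # torch.Size([2, 256, 32, 32]): channels = 2 * num_pos_feats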


--------------------------------------------------------------------------------
/modeling/modules/postprocessing.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates.
  2 | import torch
  3 | from torch.nn import functional as F
  4 | 
  5 | from detectron2.structures import Instances, ROIMasks
  6 | 
  7 | 
  8 | # perhaps should rename to "resize_instance"
  9 | def detector_postprocess(
 10 |     results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
 11 | ):
 12 |     """
 13 |     Resize the output instances.
 14 |     The input images are often resized when entering an object detector.
 15 |     As a result, we often need the outputs of the detector in a different
 16 |     resolution from its inputs.
 17 | 
 18 |     This function will resize the raw outputs of an R-CNN detector
 19 |     to produce outputs according to the desired output resolution.
 20 | 
 21 |     Args:
 22 |         results (Instances): the raw outputs from the detector.
 23 |             `results.image_size` contains the input image resolution the detector sees.
 24 |             This object might be modified in-place.
 25 |         output_height, output_width: the desired output resolution.
 26 | 
 27 |     Returns:
 28 |         Instances: the resized output from the model, based on the output resolution
 29 |     """
 30 |     if isinstance(output_width, torch.Tensor):
 31 |         # This shape might (but not necessarily) be tensors during tracing.
 32 |         # Converts integer tensors to float temporaries to ensure true
 33 |         # division is performed when computing scale_x and scale_y.
 34 |         output_width_tmp = output_width.float()
 35 |         output_height_tmp = output_height.float()
 36 |         new_size = torch.stack([output_height, output_width])
 37 |     else:
 38 |         new_size = (output_height, output_width)
 39 |         output_width_tmp = output_width
 40 |         output_height_tmp = output_height
 41 | 
 42 |     scale_x, scale_y = (
 43 |         output_width_tmp / results.image_size[1],
 44 |         output_height_tmp / results.image_size[0],
 45 |     )
 46 |     results = Instances(new_size, **results.get_fields())
 47 | 
 48 |     if results.has("pred_boxes"):
 49 |         output_boxes = results.pred_boxes
 50 |     elif results.has("proposal_boxes"):
 51 |         output_boxes = results.proposal_boxes
 52 |     else:
 53 |         output_boxes = None
 54 |     assert output_boxes is not None, "Predictions must contain boxes!"
 55 | 
 56 |     output_boxes.scale(scale_x, scale_y)
 57 |     output_boxes.clip(results.image_size)
 58 | 
 59 |     results = results[output_boxes.nonempty()]
 60 | 
 61 |     if results.has("pred_masks"):
 62 |         if isinstance(results.pred_masks, ROIMasks):
 63 |             roi_masks = results.pred_masks
 64 |         else:
 65 |             # pred_masks is a tensor of shape (N, 1, M, M)
 66 |             roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])
 67 |         results.pred_masks = roi_masks.to_bitmasks(
 68 |             results.pred_boxes, output_height, output_width, mask_threshold
 69 |         ).tensor  # TODO return ROIMasks/BitMask object in the future
 70 | 
 71 |     if results.has("pred_keypoints"):
 72 |         results.pred_keypoints[:, :, 0] *= scale_x
 73 |         results.pred_keypoints[:, :, 1] *= scale_y
 74 | 
 75 |     return results
 76 | 
 77 | def bbox_postprocess(result, input_size, img_size, output_height, output_width):
 78 |     """
 79 |     result: [xc, yc, w, h] boxes in the normalized [0, 1] range -> [x1, y1, x2, y2] boxes in pixel coordinates of the output resolution
 80 |     """
 81 |     if result is None:
 82 |         return None
 83 |     
 84 |     scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device)
 85 |     result = result.sigmoid() * scale
 86 |     x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2
 87 |     h,w = img_size
 88 | 
 89 |     x1 = x1.clamp(min=0, max=w)
 90 |     y1 = y1.clamp(min=0, max=h)
 91 |     x2 = x2.clamp(min=0, max=w)
 92 |     y2 = y2.clamp(min=0, max=h)
 93 | 
 94 |     box = torch.stack([x1,y1,x2,y2]).permute(1,0)
 95 |     scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device)
 96 |     box = box*scale
 97 |     return box
 98 | 
 99 | def sem_seg_postprocess(result, img_size, output_height, output_width):
100 |     """
101 |     Return semantic segmentation predictions in the original resolution.
102 | 
103 |     The input images are often resized when entering the semantic segmentor. Moreover, in some
104 |     cases, they are also padded inside the segmentor to be divisible by the maximum network stride.
105 |     As a result, we often need the predictions of the segmentor in a different
106 |     resolution from its inputs.
107 | 
108 |     Args:
109 |         result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),
110 |             where C is the number of classes, and H, W are the height and width of the prediction.
111 |         img_size (tuple): image size that segmentor is taking as input.
112 |         output_height, output_width: the desired output resolution.
113 | 
114 |     Returns:
115 |         semantic segmentation prediction (Tensor): A tensor of the shape
116 |             (C, output_height, output_width) that contains per-pixel soft predictions.
117 |     """
118 |     result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)
119 |     result = F.interpolate(
120 |         result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True
121 |     )[0]
122 |     return result
123 | 
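
A usage sketch of sem_seg_postprocess, assuming the module is importable from the repo root; the class count and sizes are illustrative:

import torch
from modeling.modules.postprocessing import sem_seg_postprocess

logits = torch.randn(133, 512, 704)   # C x H_pad x W_pad, as produced by the network
valid_hw = (512, 640)                 # image size before padding inside the segmentor
out = sem_seg_postprocess(logits, valid_hw, output_height=768, output_width=1024)
print(out.shape)                      # torch.Size([133, 768, 1024])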


--------------------------------------------------------------------------------
/modeling/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import *
2 | from .misc import *
3 | from .interactive import *
4 | from .attention import *


--------------------------------------------------------------------------------
/modeling/utils/box_ops.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 2 | """
 3 | Utilities for bounding box manipulation and GIoU.
 4 | """
 5 | import torch
 6 | from torchvision.ops.boxes import box_area
 7 | 
 8 | 
 9 | def box_cxcywh_to_xyxy(x):
10 |     x_c, y_c, w, h = x.unbind(-1)
11 |     b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 |          (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 |     return torch.stack(b, dim=-1)
14 | 
15 | 
16 | def box_xyxy_to_cxcywh(x):
17 |     x0, y0, x1, y1 = x.unbind(-1)
18 |     b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 |          (x1 - x0), (y1 - y0)]
20 |     return torch.stack(b, dim=-1)
21 | 
22 | def box_xywh_to_xyxy(x):
23 |     x0, y0, w, h = x.unbind(-1)
24 |     b = [x0, y0, (x0 + w), (y0 + h)]
25 |     return torch.stack(b, dim=-1)
26 | 
27 | 
28 | # modified from torchvision to also return the union
29 | def box_iou(boxes1, boxes2):
30 |     area1 = box_area(boxes1)
31 |     area2 = box_area(boxes2)
32 | 
33 |     lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
34 |     rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
35 | 
36 |     wh = (rb - lt).clamp(min=0)  # [N,M,2]
37 |     inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
38 | 
39 |     union = area1[:, None] + area2 - inter
40 | 
41 |     iou = inter / (union+1e-6)
42 |     return iou, union
43 | 
44 | 
45 | def generalized_box_iou(boxes1, boxes2):
46 |     """
47 |     Generalized IoU from https://giou.stanford.edu/
48 | 
49 |     The boxes should be in [x0, y0, x1, y1] format
50 | 
51 |     Returns a [N, M] pairwise matrix, where N = len(boxes1)
52 |     and M = len(boxes2)
53 |     """
54 |     # degenerate boxes gives inf / nan results
55 |     # so do an early check
56 |     assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
57 |     assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
58 |     iou, union = box_iou(boxes1, boxes2)
59 | 
60 |     lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
61 |     rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
62 | 
63 |     wh = (rb - lt).clamp(min=0)  # [N,M,2]
64 |     area = wh[:, :, 0] * wh[:, :, 1]
65 | 
66 |     return iou - (area - union) / (area+1e-6)
67 | 
68 | 
69 | def masks_to_boxes(masks):
70 |     """Compute the bounding boxes around the provided masks
71 | 
72 |     The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
73 | 
74 |     Returns a [N, 4] tensors, with the boxes in xyxy format
75 |     """
76 |     if masks.numel() == 0:
77 |         return torch.zeros((0, 4), device=masks.device)
78 | 
79 |     h, w = masks.shape[-2:]
80 | 
81 |     y = torch.arange(0, h, dtype=torch.float)
82 |     x = torch.arange(0, w, dtype=torch.float)
83 |     y, x = torch.meshgrid(y, x)
84 | 
85 |     x_mask = (masks * x.unsqueeze(0))
86 |     x_max = x_mask.flatten(1).max(-1)[0]
87 |     x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
88 | 
89 |     y_mask = (masks * y.unsqueeze(0))
90 |     y_max = y_mask.flatten(1).max(-1)[0]
91 |     y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
92 | 
93 |     return torch.stack([x_min, y_min, x_max, y_max], 1)
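
A usage sketch of the box utilities, assuming the module is importable from the repo root; box values are illustrative:

import torch
from modeling.utils.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou

boxes1 = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4]]))   # one predicted box
boxes2 = torch.tensor([[0.3, 0.3, 0.7, 0.7],                        # two reference boxes (xyxy)
                       [0.0, 0.0, 0.2, 0.2]])

iou, _ = box_iou(boxes1, boxes2)
giou = generalized_box_iou(boxes1, boxes2)
print(iou.shape, giou.shape)   # both [1, 2]; GIoU can be negative for disjoint boxes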


--------------------------------------------------------------------------------
/modeling/utils/interactive.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import copy
 3 | import math
 4 | 
 5 | import torch
 6 | from torch import nn, Tensor
 7 | import torch.nn.functional as F
 8 | 
 9 | 
10 | def rand_sample(x, divisor, max_len):
11 |     # non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
12 |     if len(x.nonzero()) == 0:
13 |         return x.nonzero().t()
14 | 
15 |     non_zero_point_index = (x.nonzero()/divisor).t()
16 |     mask_ids = non_zero_point_index[0].unique().long()
17 | 
18 |     # compute a sampling probability for each point (each mask gets equal total weight)
19 |     probs = torch.zeros_like(non_zero_point_index[0])
20 |     for idx in mask_ids:
21 |         prob = 1./(len(mask_ids)*((non_zero_point_index[0:1]==idx).sum()))
22 |         probs[non_zero_point_index[0]==idx] = prob
23 |     
24 |     indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0]
25 |     non_zero_point_index = non_zero_point_index[:,indices]
26 |     return non_zero_point_index # [n, 512]
27 | 
28 | def rand_sample_plain(x, max_len):
29 |     if x.shape[1] <= max_len:
30 |         return x
31 |     else:
32 |         rand_idx = torch.randperm(x.shape[1])[:max_len]
33 |         return x[:,rand_idx]
34 | 
35 | def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
36 |     src = []
37 |     pos = []
38 |     size_list = []
39 | 
40 |     # disable the mask; it does not affect performance
41 |     for i in range(num_feature_levels):
42 |         size_list.append(x[i].shape[-2:])
43 |         pos.append(pe_layer(x[i], None).flatten(2))
44 |         src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None])
45 | 
46 |         # flatten NxCxHxW to HWxNxC
47 |         pos[-1] = pos[-1].permute(2, 0, 1)
48 |         src[-1] = src[-1].permute(2, 0, 1)
49 |     return src, pos, size_list
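
A usage sketch of rand_sample, assuming the module is importable from the repo root. The divisor rescales the sampled coordinates (e.g. from image resolution down to a feature-map stride); here it is 1.0 for simplicity:

import torch
from modeling.utils.interactive import rand_sample

masks = (torch.rand(2, 64, 64) > 0.9).float()   # two binary spatial masks
points = rand_sample(masks, divisor=1.0, max_len=512)
print(points.shape)   # [3, <=512]; rows are (mask index, y, x), each mask gets equal total probability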


--------------------------------------------------------------------------------
/modeling/vision/backbone/__init__.py:
--------------------------------------------------------------------------------
 1 | from .focal import *
 2 | from .focal_dw import *
 3 | from .davit import *
 4 | from .vit import *
 5 | from .backbone import *
 6 | from .build import *
 7 | 
 8 | 
 9 | def build_backbone(config, **kwargs):
10 |     model_name = config['MODEL']['BACKBONE']['NAME']
11 |     if not is_model(model_name):
12 |         raise ValueError(f'Unknown model: {model_name}')
13 | 
14 |     return model_entrypoints(model_name)(config, **kwargs)


--------------------------------------------------------------------------------
/modeling/vision/backbone/backbone.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | import torch.nn as nn
 3 | 
 4 | from detectron2.modeling import ShapeSpec
 5 | 
 6 | # from ..layers import ShapeSpec
 7 | 
 8 | __all__ = ["Backbone"]
 9 | 
10 | 
11 | class Backbone(nn.Module):
12 |     """
13 |     Abstract base class for network backbones.
14 |     """
15 | 
16 |     def __init__(self):
17 |         """
18 |         The `__init__` method of any subclass can specify its own set of arguments.
19 |         """
20 |         super().__init__()
21 | 
22 |     def forward(self):
23 |         """
24 |         Subclasses must override this method, but adhere to the same return type.
25 | 
26 |         Returns:
27 |             dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
28 |         """
29 |         pass
30 | 
31 |     @property
32 |     def size_divisibility(self) -> int:
33 |         """
34 |         Some backbones require the input height and width to be divisible by a
35 |         specific integer. This is typically true for encoder / decoder type networks
36 |         with lateral connection (e.g., FPN) for which feature maps need to match
37 |         dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
38 |         input size divisibility is required.
39 |         """
40 |         return 0
41 | 
42 |     def output_shape(self):
43 |         """
44 |         Returns:
45 |             dict[str->ShapeSpec]
46 |         """
47 |         # this is a backward-compatible default
48 |         return {
49 |             name: ShapeSpec(
50 |                 channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
51 |             )
52 |             for name in self._out_features
53 |         }
54 | 
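
The concrete backbones in this package are expected to fill in the contract that output_shape() relies on. A toy subclass sketch (not a real backbone from the repo), just to make that contract explicit:

import torch
import torch.nn as nn
from modeling.vision.backbone.backbone import Backbone

class ToyBackbone(Backbone):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1)
        # the attributes output_shape() reads:
        self._out_features = ["res2"]
        self._out_feature_channels = {"res2": 64}
        self._out_feature_strides = {"res2": 4}

    def forward(self, x):
        # return a {feature name: tensor} dict, as documented above
        return {"res2": self.stem(x)}

net = ToyBackbone()
print(net.output_shape())                             # {'res2': ShapeSpec(channels=64, ..., stride=4)}
print(net(torch.randn(1, 3, 64, 64))["res2"].shape)   # torch.Size([1, 64, 16, 16])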


--------------------------------------------------------------------------------
/modeling/vision/backbone/build.py:
--------------------------------------------------------------------------------
 1 | _model_entrypoints = {}
 2 | 
 3 | 
 4 | def register_backbone(fn):
 5 |     module_name_split = fn.__module__.split('.')
 6 |     model_name = module_name_split[-1]
 7 |     _model_entrypoints[model_name] = fn
 8 |     return fn
 9 | 
10 | def model_entrypoints(model_name):
11 |     return _model_entrypoints[model_name]
12 | 
13 | def is_model(model_name):
14 |     return model_name in _model_entrypoints


--------------------------------------------------------------------------------
/modeling/vision/backbone/common.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | import torch
 8 | import torch.nn as nn
 9 | 
10 | from typing import Type
11 | 
12 | 
13 | class MLPBlock(nn.Module):
14 |     def __init__(
15 |         self,
16 |         embedding_dim: int,
17 |         mlp_dim: int,
18 |         act: Type[nn.Module] = nn.GELU,
19 |     ) -> None:
20 |         super().__init__()
21 |         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 |         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 |         self.act = act()
24 | 
25 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
26 |         return self.lin2(self.act(self.lin1(x)))
27 | 
28 | 
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
31 | class LayerNorm2d(nn.Module):
32 |     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 |         super().__init__()
34 |         self.weight = nn.Parameter(torch.ones(num_channels))
35 |         self.bias = nn.Parameter(torch.zeros(num_channels))
36 |         self.eps = eps
37 | 
38 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
39 |         u = x.mean(1, keepdim=True)
40 |         s = (x - u).pow(2).mean(1, keepdim=True)
41 |         x = (x - u) / torch.sqrt(s + self.eps)
42 |         x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 |         return x
44 | 
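
A usage sketch of the two helpers, assuming the module is importable from the repo root. LayerNorm2d normalizes over the channel dimension of an NCHW tensor, unlike nn.LayerNorm, which expects channels-last input:

import torch
from modeling.vision.backbone.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 256, 14, 14)                    # N x C x H x W
y = LayerNorm2d(256)(x)
print(y.shape, y.mean(dim=1).abs().max() < 1e-4)   # shape preserved, per-position channel mean ~ 0

tokens = torch.randn(2, 196, 256)                  # channels-last tokens for MLPBlock
print(MLPBlock(embedding_dim=256, mlp_dim=1024)(tokens).shape)   # torch.Size([2, 196, 256])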


--------------------------------------------------------------------------------
/modeling/vision/encoder/__init__.py:
--------------------------------------------------------------------------------
 1 | from .transformer_encoder_fpn import *
 2 | try:
 3 |     from .transformer_encoder_deform import *
 4 | except:
 5 |     print('Deformable Transformer Encoder is not available.')
 6 | from .build import *
 7 | 
 8 | 
 9 | def build_encoder(config, *args, **kwargs):
10 |     model_name = config['MODEL']['ENCODER']['NAME']
11 | 
12 |     if not is_model(model_name):
13 |         raise ValueError(f'Unknown model: {model_name}')
14 | 
15 |     return model_entrypoints(model_name)(config, *args, **kwargs)


--------------------------------------------------------------------------------
/modeling/vision/encoder/build.py:
--------------------------------------------------------------------------------
 1 | _model_entrypoints = {}
 2 | 
 3 | 
 4 | def register_encoder(fn):
 5 |     module_name_split = fn.__module__.split('.')
 6 |     model_name = module_name_split[-1]
 7 |     _model_entrypoints[model_name] = fn
 8 |     return fn
 9 | 
10 | def model_entrypoints(model_name):
11 |     return _model_entrypoints[model_name]
12 | 
13 | def is_model(model_name):
14 |     return model_name in _model_entrypoints


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/functions/__init__.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------------------------
 2 | # Deformable DETR
 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 5 | # ------------------------------------------------------------------------------------------------
 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 7 | # ------------------------------------------------------------------------------------------------
 8 | 
 9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 | 
12 | from .ms_deform_attn_func import MSDeformAttnFunction
13 | 
14 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------------------------
 2 | # Deformable DETR
 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 5 | # ------------------------------------------------------------------------------------------------
 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 7 | # ------------------------------------------------------------------------------------------------
 8 | 
 9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 | 
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 | 
16 | import torch
17 | import torch.nn.functional as F
18 | from torch.autograd import Function
19 | from torch.autograd.function import once_differentiable
20 | 
21 | try:
22 |     import MultiScaleDeformableAttention as MSDA
23 | except ModuleNotFoundError as e:
24 |     info_string = (
25 |         "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26 |         "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
27 |         "\t`sh make.sh`\n"
28 |     )
29 |     raise ModuleNotFoundError(info_string)
30 | 
31 | 
32 | class MSDeformAttnFunction(Function):
33 |     @staticmethod
34 |     def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35 |         ctx.im2col_step = im2col_step
36 |         output = MSDA.ms_deform_attn_forward(
37 |             value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38 |         ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39 |         return output
40 | 
41 |     @staticmethod
42 |     @once_differentiable
43 |     def backward(ctx, grad_output):
44 |         value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45 |         grad_value, grad_sampling_loc, grad_attn_weight = \
46 |             MSDA.ms_deform_attn_backward(
47 |                 value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48 | 
49 |         return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50 | 
51 | 
52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53 |     # for debug and test only,
54 |     # need to use cuda version instead
55 |     N_, S_, M_, D_ = value.shape
56 |     _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57 |     value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58 |     sampling_grids = 2 * sampling_locations - 1
59 |     sampling_value_list = []
60 |     for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61 |         # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62 |         value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63 |         # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64 |         sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65 |         # N_*M_, D_, Lq_, P_
66 |         sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67 |                                           mode='bilinear', padding_mode='zeros', align_corners=False)
68 |         sampling_value_list.append(sampling_value_l_)
69 |     # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70 |     attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71 |     output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72 |     return output.transpose(1, 2).contiguous()
73 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/make.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | # ------------------------------------------------------------------------------------------------
 3 | # Deformable DETR
 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | # ------------------------------------------------------------------------------------------------
 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | # ------------------------------------------------------------------------------------------------
 9 | 
10 | # Copyright (c) Facebook, Inc. and its affiliates.
11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12 | 
13 | python setup.py build install
14 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/modules/__init__.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------------------------
 2 | # Deformable DETR
 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 5 | # ------------------------------------------------------------------------------------------------
 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 7 | # ------------------------------------------------------------------------------------------------
 8 | 
 9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 | 
12 | from .ms_deform_attn import MSDeformAttn
13 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/setup.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------------------------
 2 | # Deformable DETR
 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 5 | # ------------------------------------------------------------------------------------------------
 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 7 | # ------------------------------------------------------------------------------------------------
 8 | 
 9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 | 
12 | import os
13 | import glob
14 | 
15 | import torch
16 | 
17 | from torch.utils.cpp_extension import CUDA_HOME
18 | from torch.utils.cpp_extension import CppExtension
19 | from torch.utils.cpp_extension import CUDAExtension
20 | 
21 | from setuptools import find_packages
22 | from setuptools import setup
23 | 
24 | requirements = ["torch", "torchvision"]
25 | 
26 | def get_extensions():
27 |     this_dir = os.path.dirname(os.path.abspath(__file__))
28 |     extensions_dir = os.path.join(this_dir, "src")
29 | 
30 |     main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31 |     source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32 |     source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33 | 
34 |     sources = main_file + source_cpu
35 |     extension = CppExtension
36 |     extra_compile_args = {"cxx": []}
37 |     define_macros = []
38 | 
39 |     # Build the CUDA extension when FORCE_CUDA is set or CUDA is available, provided CUDA_HOME is found.
40 |     if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41 |         extension = CUDAExtension
42 |         sources += source_cuda
43 |         define_macros += [("WITH_CUDA", None)]
44 |         extra_compile_args["nvcc"] = [
45 |             "-DCUDA_HAS_FP16=1",
46 |             "-D__CUDA_NO_HALF_OPERATORS__",
47 |             "-D__CUDA_NO_HALF_CONVERSIONS__",
48 |             "-D__CUDA_NO_HALF2_OPERATORS__",
49 |         ]
50 |     else:
51 |         if CUDA_HOME is None:
52 |             raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
53 |         else:
54 |             raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
55 | 
56 |     sources = [os.path.join(extensions_dir, s) for s in sources]
57 |     include_dirs = [extensions_dir]
58 |     ext_modules = [
59 |         extension(
60 |             "MultiScaleDeformableAttention",
61 |             sources,
62 |             include_dirs=include_dirs,
63 |             define_macros=define_macros,
64 |             extra_compile_args=extra_compile_args,
65 |         )
66 |     ]
67 |     return ext_modules
68 | 
69 | setup(
70 |     name="MultiScaleDeformableAttention",
71 |     version="1.0",
72 |     author="Weijie Su",
73 |     url="https://github.com/fundamentalvision/Deformable-DETR",
74 |     description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75 |     packages=find_packages(exclude=("configs", "tests",)),
76 |     ext_modules=get_extensions(),
77 |     cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78 | )
79 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
 1 | /*!
 2 | **************************************************************************************************
 3 | * Deformable DETR
 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | **************************************************************************************************
 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | **************************************************************************************************
 9 | */
10 | 
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 | 
16 | #include <vector>
17 | 
18 | #include <ATen/ATen.h>
19 | #include <ATen/cuda/CUDAContext.h>
20 | 
21 | 
22 | at::Tensor
23 | ms_deform_attn_cpu_forward(
24 |     const at::Tensor &value, 
25 |     const at::Tensor &spatial_shapes,
26 |     const at::Tensor &level_start_index,
27 |     const at::Tensor &sampling_loc,
28 |     const at::Tensor &attn_weight,
29 |     const int im2col_step)
30 | {
31 |     AT_ERROR("Not implemented on the CPU");
32 | }
33 | 
34 | std::vector<at::Tensor>
35 | ms_deform_attn_cpu_backward(
36 |     const at::Tensor &value, 
37 |     const at::Tensor &spatial_shapes,
38 |     const at::Tensor &level_start_index,
39 |     const at::Tensor &sampling_loc,
40 |     const at::Tensor &attn_weight,
41 |     const at::Tensor &grad_output,
42 |     const int im2col_step)
43 | {
44 |     AT_ERROR("Not implemented on the CPU");
45 | }
46 | 
47 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
 1 | /*!
 2 | **************************************************************************************************
 3 | * Deformable DETR
 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | **************************************************************************************************
 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | **************************************************************************************************
 9 | */
10 | 
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 | 
16 | #pragma once
17 | #include <torch/extension.h>
18 | 
19 | at::Tensor
20 | ms_deform_attn_cpu_forward(
21 |     const at::Tensor &value, 
22 |     const at::Tensor &spatial_shapes,
23 |     const at::Tensor &level_start_index,
24 |     const at::Tensor &sampling_loc,
25 |     const at::Tensor &attn_weight,
26 |     const int im2col_step);
27 | 
28 | std::vector<at::Tensor>
29 | ms_deform_attn_cpu_backward(
30 |     const at::Tensor &value, 
31 |     const at::Tensor &spatial_shapes,
32 |     const at::Tensor &level_start_index,
33 |     const at::Tensor &sampling_loc,
34 |     const at::Tensor &attn_weight,
35 |     const at::Tensor &grad_output,
36 |     const int im2col_step);
37 | 
38 | 
39 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
 1 | /*!
 2 | **************************************************************************************************
 3 | * Deformable DETR
 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | **************************************************************************************************
 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | **************************************************************************************************
 9 | */
10 | 
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 | 
16 | #pragma once
17 | #include <torch/extension.h>
18 | 
19 | at::Tensor ms_deform_attn_cuda_forward(
20 |     const at::Tensor &value, 
21 |     const at::Tensor &spatial_shapes,
22 |     const at::Tensor &level_start_index,
23 |     const at::Tensor &sampling_loc,
24 |     const at::Tensor &attn_weight,
25 |     const int im2col_step);
26 | 
27 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28 |     const at::Tensor &value, 
29 |     const at::Tensor &spatial_shapes,
30 |     const at::Tensor &level_start_index,
31 |     const at::Tensor &sampling_loc,
32 |     const at::Tensor &attn_weight,
33 |     const at::Tensor &grad_output,
34 |     const int im2col_step);
35 | 
36 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
 1 | /*!
 2 | **************************************************************************************************
 3 | * Deformable DETR
 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | **************************************************************************************************
 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | **************************************************************************************************
 9 | */
10 | 
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 | 
16 | #pragma once
17 | 
18 | #include "cpu/ms_deform_attn_cpu.h"
19 | 
20 | #ifdef WITH_CUDA
21 | #include "cuda/ms_deform_attn_cuda.h"
22 | #endif
23 | 
24 | 
25 | at::Tensor
26 | ms_deform_attn_forward(
27 |     const at::Tensor &value, 
28 |     const at::Tensor &spatial_shapes,
29 |     const at::Tensor &level_start_index,
30 |     const at::Tensor &sampling_loc,
31 |     const at::Tensor &attn_weight,
32 |     const int im2col_step)
33 | {
34 |     if (value.type().is_cuda())
35 |     {
36 | #ifdef WITH_CUDA
37 |         return ms_deform_attn_cuda_forward(
38 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39 | #else
40 |         AT_ERROR("Not compiled with GPU support");
41 | #endif
42 |     }
43 |     AT_ERROR("Not implemented on the CPU");
44 | }
45 | 
46 | std::vector<at::Tensor>
47 | ms_deform_attn_backward(
48 |     const at::Tensor &value, 
49 |     const at::Tensor &spatial_shapes,
50 |     const at::Tensor &level_start_index,
51 |     const at::Tensor &sampling_loc,
52 |     const at::Tensor &attn_weight,
53 |     const at::Tensor &grad_output,
54 |     const int im2col_step)
55 | {
56 |     if (value.type().is_cuda())
57 |     {
58 | #ifdef WITH_CUDA
59 |         return ms_deform_attn_cuda_backward(
60 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61 | #else
62 |         AT_ERROR("Not compiled with GPU support");
63 | #endif
64 |     }
65 |     AT_ERROR("Not implemented on the CPU");
66 | }
67 | 
68 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/src/vision.cpp:
--------------------------------------------------------------------------------
 1 | /*!
 2 | **************************************************************************************************
 3 | * Deformable DETR
 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | **************************************************************************************************
 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | **************************************************************************************************
 9 | */
10 | 
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 | 
16 | #include "ms_deform_attn.h"
17 | 
18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19 |   m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20 |   m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21 | }
22 | 


--------------------------------------------------------------------------------
/modeling/vision/encoder/ops/test.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------------------------
 2 | # Deformable DETR
 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 5 | # ------------------------------------------------------------------------------------------------
 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 7 | # ------------------------------------------------------------------------------------------------
 8 | 
 9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 | 
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 | 
16 | import time
17 | import torch
18 | import torch.nn as nn
19 | from torch.autograd import gradcheck
20 | 
21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22 | 
23 | 
24 | N, M, D = 1, 2, 2
25 | Lq, L, P = 2, 2, 2
26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28 | S = sum([(H*W).item() for H, W in shapes])
29 | 
30 | 
31 | torch.manual_seed(3)
32 | 
33 | 
34 | @torch.no_grad()
35 | def check_forward_equal_with_pytorch_double():
36 |     value = torch.rand(N, S, M, D).cuda() * 0.01
37 |     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38 |     attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39 |     attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40 |     im2col_step = 2
41 |     output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42 |     output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43 |     fwdok = torch.allclose(output_cuda, output_pytorch)
44 |     max_abs_err = (output_cuda - output_pytorch).abs().max()
45 |     max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46 | 
47 |     print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48 | 
49 | 
50 | @torch.no_grad()
51 | def check_forward_equal_with_pytorch_float():
52 |     value = torch.rand(N, S, M, D).cuda() * 0.01
53 |     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54 |     attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55 |     attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56 |     im2col_step = 2
57 |     output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58 |     output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59 |     fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60 |     max_abs_err = (output_cuda - output_pytorch).abs().max()
61 |     max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62 | 
63 |     print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64 | 
65 | 
66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67 | 
68 |     value = torch.rand(N, S, M, channels).cuda() * 0.01
69 |     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70 |     attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71 |     attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72 |     im2col_step = 2
73 |     func = MSDeformAttnFunction.apply
74 | 
75 |     value.requires_grad = grad_value
76 |     sampling_locations.requires_grad = grad_sampling_loc
77 |     attention_weights.requires_grad = grad_attn_weight
78 | 
79 |     gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80 | 
81 |     print(f'* {gradok} check_gradient_numerical(D={channels})')
82 | 
83 | 
84 | if __name__ == '__main__':
85 |     check_forward_equal_with_pytorch_double()
86 |     check_forward_equal_with_pytorch_float()
87 | 
88 |     for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89 |         check_gradient_numerical(channels, True, True, True)
90 | 
91 | 
92 | 
93 | 


--------------------------------------------------------------------------------
/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/pipeline/__init__.py


--------------------------------------------------------------------------------
/pipeline/utils/misc.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import torch
 3 | 
 4 | logger = logging.getLogger(__name__)
 5 | 
 6 | def hook_opt(opt):
 7 | 
 8 |     try:
 9 |         grounding_flag = opt['REF']['INPUT']['SPATIAL']
10 |     except:
11 |         grounding_flag = False
12 | 
13 |     if grounding_flag:
14 |         opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['grounding'] = ['queries_grounding', 'tokens_grounding', 'tokens_spatial']
15 | 
16 |     try:
17 |         spatial_flag = opt['STROKE_SAMPLER']['EVAL']['GROUNDING']
18 |     except:
19 |         spatial_flag = False
20 | 
21 |     if spatial_flag:
22 |         opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['spatial'] = ['queries_spatial', 'tokens_spatial', 'memories_spatial', 'tokens_grounding']
23 | 
24 |     return opt
25 | 
26 | # HACK for evaluation
27 | def hook_metadata(metadata, name):
28 |     return metadata
29 | 
30 | # HACK for evaluation
31 | def hook_switcher(model, name):
32 |     mappings = {}
33 |     if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'context_59_val_seg', 'context_459_val_seg', 'voc_2012_val_seg', 'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']:
34 |         mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False}
35 |     elif name in ['cityscapes_fine_instance_seg_val'] or 'seginw' in name:
36 |         mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False}
37 |     elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']:
38 |         mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True}
39 |     elif name in ['coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev']:
40 |         mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True}
41 |     else:
42 |         if name not in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd", "pascalvoc_val_Point", "grounding_coco_entity_val", "vlp_coco_entity_val"]:
43 |             assert False, "dataset switcher is not defined"
44 | 
45 |     for key, value in mappings.items():
46 |         if key == 'SEMANTIC_ON':
47 |             model.model.semantic_on = value
48 |         if key == 'INSTANCE_ON':
49 |             model.model.instance_on = value
50 |         if key == 'PANOPTIC_ON':
51 |             model.model.panoptic_on = value
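
A minimal sketch of hook_opt on a hand-built option dict (assuming pipeline.utils.misc is importable from the repo root); only the keys the function reads are filled in:

from pipeline.utils.misc import hook_opt

opt = {
    'REF': {'INPUT': {'SPATIAL': True}},
    'STROKE_SAMPLER': {'EVAL': {'GROUNDING': False}},
    'ATTENTION_ARCH': {'SELF_ATTENTION': {'queries': {'grounding': [], 'spatial': []}}},
}
opt = hook_opt(opt)
# only the grounding route is rewired, because only that flag is truthy
print(opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['grounding'])
# ['queries_grounding', 'tokens_grounding', 'tokens_spatial']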


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = ["setuptools>=61.0"]
 3 | build-backend = "setuptools.build_meta"
 4 | 
 5 | [project]
 6 | name = "SEEM"
 7 | version = "1.0"
 8 | description = "Segment Everything Everywhere All at Once."
 9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 |     "Programming Language :: Python :: 3",
13 |     "License :: OSI Approved :: Apache Software License",
14 | ]
15 | 
16 | dependencies = [
17 |     "torch",
18 |     "torchvision",
19 |     "pillow==9.4.0",
20 |     "opencv-python==4.8.1.78",
21 |     "pyyaml==6.0.1",
22 |     "json_tricks==3.17.3",
23 |     "yacs==0.1.8",
24 |     "scikit-learn==1.3.1",
25 |     "pandas==2.0.3",
26 |     "timm==0.4.12",
27 |     "numpy==1.23.1",
28 |     "einops==0.7.0",
29 |     "fvcore==0.1.5.post20221221",
30 |     "transformers==4.34.0",
31 |     "sentencepiece==0.1.99",
32 |     "ftfy==6.1.1",
33 |     "regex==2023.10.3",
34 |     "nltk==3.8.1",
35 |     "vision-datasets==0.2.2",
36 |     "cython==3.0.2",
37 |     "pycocotools==2.0.7",
38 |     "diffdist==0.1",
39 |     "pyarrow==13.0.0",
40 |     "cityscapesscripts==2.2.2",
41 |     "shapely==1.8.0",
42 |     "scikit-image==0.21.0",
43 |     "mup==1.0.0",
44 |     "accelerate==0.23.0",
45 |     "kornia==0.7.0",
46 |     "deepspeed==0.10.3",
47 |     "wandb==0.15.12",
48 |     "infinibatch==0.1.1",
49 |     "gradio==3.42.0",     
50 |     "openai-whisper",    
51 | ]
52 | 
53 | [tool.poetry.dependencies]
54 | detectron2 = {git = "https://github.com/MaureenZOU/detectron2-xyz.git"}
55 | 
56 | 
57 | [project.urls]
58 | "Paper" = "https://arxiv.org/abs/2304.06718"
59 | "Code" = "https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/tree/v1.0"
60 | "Bug Tracker" = "https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/issues"
61 | 
62 | [tool.setuptools.packages.find]
63 | exclude = ["assets*"]
64 | 
65 | [tool.wheel]
66 | exclude = ["assets*"]
67 | 


--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .xdecoder_trainer import *


--------------------------------------------------------------------------------
/trainer/distributed_trainer.py:
--------------------------------------------------------------------------------
  1 | # --------------------------------------------------------
  2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
  3 | # Copyright (c) 2022 Microsoft
  4 | # Licensed under The MIT License [see LICENSE for details]
  5 | # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
  6 | # --------------------------------------------------------
  7 | 
  8 | import os
  9 | import logging
 10 | from mpi4py import MPI
 11 | 
 12 | import torch
 13 | 
 14 | from .utils.hook import add_hook
 15 | from .utils.mpi_adapter import MPIAdapter
 16 | from .utils.misc import save_opt_to_yaml
 17 | 
 18 | logger = logging.getLogger(__name__)
 19 | 
 20 | 
 21 | class DistributedTrainer:
 22 |     def __init__(self, opt):
 23 |         self.opt = opt
 24 | 
 25 |         # parse environment information for distributed training
 26 |         adapter = MPIAdapter(self.opt['PORT'])
 27 |         self.opt['world_size'] = adapter.world_size
 28 |         self.opt['local_size'] = adapter.local_size
 29 |         self.opt['rank'] = adapter.rank
 30 |         self.opt['local_rank'] = adapter.local_rank
 31 | 
 32 |         self.set_opt_hook()
 33 | 
 34 |         # set up device
 35 |         if not self.opt['CUDA']:
 36 |             self.opt['device'] = torch.device("cpu")
 37 |             logger.info("Using CPU")
 38 |         else:
 39 |             torch.cuda.set_device(self.opt['local_rank'])
 40 |             self.opt['device'] = torch.device("cuda", self.opt['local_rank'])
 41 |             logger.info("Using CUDA")
 42 | 
 43 |         # init distributed training
 44 |         adapter.log_info()
 45 |         if torch.distributed.is_available() and self.opt['world_size'] > 1:
 46 |             adapter.init_process_group(backend='nccl')
 47 | 
 48 |         # save config file
 49 |         self.save_folder = self.opt['SAVE_DIR']
 50 | 
 51 |         if self.opt['world_size'] > 1:
 52 |             torch.distributed.barrier()
 53 | 
 54 |         if self.opt['rank'] == 0:
 55 |             os.makedirs(self.save_folder, exist_ok=True)
 56 | 
 57 |             logger.info(f"Save config file to {os.path.join(self.save_folder, 'conf_copy.yaml')}")
 58 |             save_opt_to_yaml(self.opt, os.path.join(self.save_folder, 'conf_copy.yaml'))
 59 | 
 60 |         # ddp: log stats and update learning rate
 61 |         self.grad_acc_steps = self.opt['GRADIENT_ACCUMULATE_STEP']
 62 |         logger.info(f"Base learning rate: {self.opt['SOLVER']['BASE_LR']}")
 63 |         logger.info(f"Number of GPUs: {self.opt['world_size']}")
 64 |         logger.info(f"Gradient accumulation steps: {self.grad_acc_steps}")
 65 | 
 66 |         if self.opt['world_size'] > 1:
 67 |             add_hook()
 68 | 
 69 |         # prepare metadata for save folder
 70 |         conf_file = self.opt['conf_files'][0]
 71 |         if 'BASENAME' not in self.opt:
 72 |             self.opt['BASENAME'] = os.path.basename(conf_file)
 73 |         
 74 |         self.init_save_folder()
 75 | 
 76 |     def set_opt_hook(self):
 77 |         # Fill in the default values for required keywords
 78 |         self.opt['CUDA'] = self.opt.get('CUDA', True) and torch.cuda.is_available()
 79 |         self.opt['FP16'] = self.opt.get('FP16', False) and self.opt['CUDA']
 80 |         self.opt['GRADIENT_ACCUMULATE_STEP'] = int(self.opt.get('GRADIENT_ACCUMULATE_STEP', 1))
 81 |         self.opt['EVAL_PER_UPDATE_NUM'] = int(self.opt.get('EVAL_PER_UPDATE_NUM', 0))
 82 |         self.opt['LR_SCHEDULER_PARAMS'] = self.opt.get('LR_SCHEDULER_PARAMS', {})
 83 | 
 84 |         if 'SAVE_DIR' not in self.opt:
 85 |             assert False, "Please initialize SAVE_DIR in your config file."
 86 |         self.opt['SAVE_DIR'] = os.path.normpath(self.opt['SAVE_DIR'])
 87 |         logger.info(f"Setting SAVE_DIR as {self.opt['SAVE_DIR']}")
 88 | 
 89 |     def init_save_folder(self):
 90 |         """
 91 |         Initialize the save folder for logs, model, checkpoint, and evaluation.
 92 |         """
 93 |         runid = 1
 94 | 
 95 |         if self.opt['world_size'] > 1:
 96 |             torch.distributed.barrier()
 97 | 
 98 |         if self.opt['rank'] == 0:
 99 |             while True:
100 |                 save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
101 |                 try:
102 |                     os.makedirs(save_folder, exist_ok=False)
103 |                     break
104 |                 except FileExistsError:
105 |                     runid = runid + 1
106 | 
107 |         if self.opt['world_size'] > 1:
108 |             torch.distributed.barrier()
109 | 
110 |         if self.opt['world_size'] > 1:
111 |             runid = 1
112 |             while True:
113 |                 save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
114 |                 if not os.path.exists(save_folder):
115 |                     break
116 |                 else:
117 |                     runid += 1
118 | 
119 |             runid -= 1
120 |             save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
121 |             # this second os.makedirs() call on all ranks is to force sync the save_folder creation between blobFuse and local fs
122 |             os.makedirs(save_folder, exist_ok=True)
123 | 
124 |         self.save_folder = save_folder
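
Usage note (illustrative, not part of the repository): the run-folder logic above amounts to rank 0 claiming the first free run_<id> directory via an exclusive makedirs, after which all ranks re-scan and agree on the last existing one. A single-process sketch of that claiming pattern, with a hypothetical helper name and base path:

    import os

    def allocate_run_folder(base_dir):
        # Claim the first run_<id> directory that does not exist yet;
        # os.makedirs(exist_ok=False) raises FileExistsError if another
        # process got there first, so the claim is effectively atomic.
        runid = 1
        while True:
            save_folder = os.path.join(base_dir, f"run_{runid}")
            try:
                os.makedirs(save_folder, exist_ok=False)
                return save_folder
            except FileExistsError:
                runid += 1

    # allocate_run_folder("/tmp/seem_runs")  ->  ".../run_1", then ".../run_2", ...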


--------------------------------------------------------------------------------
/trainer/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once/7b2e76dbb17d0b7831c6813a921fe2bc8de22926/trainer/utils/__init__.py


--------------------------------------------------------------------------------
/trainer/utils/hook.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import logging
 3 | 
 4 | logger = logging.getLogger(__name__)
 5 | 
 6 | _orig_except_hook = None
 7 | 
 8 | 
 9 | def _global_except_hook(exctype, value, traceback):
10 |     """Catches an unhandled exception and calls MPI_Abort()."""
11 |     try:
12 |         if _orig_except_hook:
13 |             _orig_except_hook(exctype, value, traceback)
14 |         else:
15 |             sys.__excepthook__(exctype, value, traceback)
16 | 
17 |     finally:
18 |         import mpi4py.MPI
19 |         rank = mpi4py.MPI.COMM_WORLD.Get_rank()
20 |         logger.warning("******************************************")
21 |         logger.warning("DefaultTrainer:")
22 |         logger.warning(f"   Uncaught exception on rank {rank}.")
23 |         logger.warning("   Calling MPI_Abort() to shut down MPI...")
24 |         logger.warning("******************************************")
25 |         logging.shutdown()
26 | 
27 |         try:
28 |             import mpi4py.MPI
29 |             mpi4py.MPI.COMM_WORLD.Abort(1)
30 |         except Exception as e:
31 |             # Something is completely broken...
32 |             # There's nothing we can do any more
33 |             sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n")
34 |             sys.stderr.flush()
35 |             raise e
36 | 
37 | 
38 | def add_hook():
39 |     """
40 |     Add a global hook function that captures all unhandled exceptions.
41 |     The function calls MPI_Abort() to force all processes abort.
42 | 
43 |     An MPI runtime is expected to kill all of its child processes
44 |     if one of them exits abnormally or without calling `MPI_Finalize()`.
45 |     However, when a Python program runs on `mpi4py`, the MPI runtime
46 |     often fails to detect a process failure, and the remaining processes
47 |     hang indefinitely.
48 | 
49 |     See https://github.com/chainer/chainermn/issues/236 and
50 |     https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more
51 |     information.
52 |     """
53 |     global _orig_except_hook
54 | 
55 |     if _orig_except_hook is not None:
56 |         logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.")
57 |         return
58 | 
59 |     logger.info("Adding a global except hook for the distributed job to shut down MPI if an unhandled exception is raised on any rank.")
60 |     _orig_except_hook = sys.excepthook
61 |     sys.excepthook = _global_except_hook
62 | 
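
Usage note (illustrative, not part of the repository): add_hook() is meant to be called once per rank at the start of an MPI-launched job, before training begins; any uncaught exception then triggers MPI_Abort() instead of leaving the other ranks hanging. A minimal sketch, assuming the package is importable as trainer.utils.hook and a hypothetical entry point:

    # e.g. mpirun -np 4 python crash_demo.py   (hypothetical script name)
    from trainer.utils.hook import add_hook

    def main():
        add_hook()                      # install the global excepthook once per rank
        raise RuntimeError("boom")      # without the hook, the other ranks could hang

    if __name__ == "__main__":
        main()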


--------------------------------------------------------------------------------
/trainer/utils/serialization.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import numpy as np
 3 | from typing import Dict
 4 | 
 5 | 
 6 | class JSONEncoder(json.JSONEncoder):
 7 |     def default(self, obj):
 8 |         if isinstance(obj, np.integer):
 9 |             return int(obj)
10 |         elif isinstance(obj, np.floating):
11 |             return float(obj)
12 |         elif isinstance(obj, np.ndarray):
13 |             return obj.tolist()
14 |         else:
15 |             return super(JSONEncoder, self).default(obj)
16 | 
17 | 
18 | def is_jsonable(x, json_encoder=None):
19 |     try:
20 |         json.dumps(x, cls=json_encoder)
21 |         return True
22 |     except Exception:
23 |         return False
24 | 
25 | 
26 | def filter_jsonable(data: Dict, json_encoder=None) -> Dict:
27 |     return {k: v for k, v in data.items() if is_jsonable(k, json_encoder=json_encoder) and is_jsonable(v, json_encoder=json_encoder)}
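
Usage note (illustrative, not part of the repository): JSONEncoder converts numpy scalars and arrays into plain JSON types, and filter_jsonable drops dictionary entries that still cannot be serialized. A small example, assuming the module is importable as trainer.utils.serialization:

    import json
    import numpy as np
    from trainer.utils.serialization import JSONEncoder, filter_jsonable

    stats = {"epoch": np.int64(3), "loss": np.float32(0.25), "hook": object()}
    clean = filter_jsonable(stats, json_encoder=JSONEncoder)   # drops "hook"
    print(json.dumps(clean, cls=JSONEncoder))                  # {"epoch": 3, "loss": 0.25}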


--------------------------------------------------------------------------------
/utils/Config.py:
--------------------------------------------------------------------------------
 1 | from fvcore.common.config import CfgNode as _CfgNode
 2 | 
 3 | 
 4 | class CfgNode(_CfgNode):
 5 |     """
 6 |     The same as `fvcore.common.config.CfgNode`, but different in:
 7 | 
 8 |     1. Use unsafe yaml loading by default.
 9 |        Note that this may lead to arbitrary code execution: you must not
10 |        load a config file from untrusted sources before manually inspecting
11 |        the content of the file.
12 |     2. Support config versioning.
13 |        When attempting to merge an old config, it will convert the old config automatically.
14 | 
15 |     .. automethod:: clone
16 |     .. automethod:: freeze
17 |     .. automethod:: defrost
18 |     .. automethod:: is_frozen
19 |     .. automethod:: load_yaml_with_base
20 |     .. automethod:: merge_from_list
21 |     .. automethod:: merge_from_other_cfg
22 |     """
23 | 
24 |     def merge_from_dict(self, config_dict):
25 |         pass  # placeholder: merging from a plain dict is not implemented
26 |     
27 | node = CfgNode()
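
Usage note (illustrative, not part of the repository): the shared node behaves like any yacs/fvcore CfgNode, so keys can be attached attribute-style while unfrozen and then frozen to guard against typos. A minimal sketch with a hypothetical key and value:

    from utils.Config import CfgNode, node

    node.MODEL = CfgNode()
    node.MODEL.NAME = "focalt_unicl_lang"   # hypothetical value
    node.freeze()                           # later mutation raises an exception
    print(node.MODEL.NAME)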


--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_engineering import *
2 | from .dataset import *


--------------------------------------------------------------------------------
/utils/arguments.py:
--------------------------------------------------------------------------------
 1 | import yaml
 2 | import json
 3 | import argparse
 4 | import logging
 5 | 
 6 | logger = logging.getLogger(__name__)
 7 | 
 8 | 
 9 | def load_config_dict_to_opt(opt, config_dict):
10 |     """
11 |     Load the key, value pairs from config_dict to opt, overriding existing values in opt
12 |     if there is any.
13 |     """
14 |     if not isinstance(config_dict, dict):
15 |         raise TypeError("Config must be a Python dictionary")
16 |     for k, v in config_dict.items():
17 |         k_parts = k.split('.')
18 |         pointer = opt
19 |         for k_part in k_parts[:-1]:
20 |             if k_part not in pointer:
21 |                 pointer[k_part] = {}
22 |             pointer = pointer[k_part]
23 |             assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
24 |         ori_value = pointer.get(k_parts[-1])
25 |         pointer[k_parts[-1]] = v
26 |         if ori_value:
27 |             logger.warning(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")
28 | 
29 | 
30 | def load_opt_from_config_files(conf_files):
31 |     """
32 |     Load opt from the config files, settings in later files can override those in previous files.
33 | 
34 |     Args:
35 |         conf_files (list): a list of config file paths
36 | 
37 |     Returns:
38 |         dict: a dictionary of opt settings
39 |     """
40 |     opt = {}
41 |     for conf_file in conf_files:
42 |         with open(conf_file, encoding='utf-8') as f:
43 |             config_dict = yaml.safe_load(f)
44 | 
45 |         load_config_dict_to_opt(opt, config_dict)
46 | 
47 |     return opt
48 | 
49 | 
50 | def load_opt_command(args):
51 |     parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')
52 |     parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
53 |     parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).')
54 |     parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
55 |     parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
56 |     parser.add_argument('--overrides', help='key-value pairs used to override the config file from the command line', nargs=argparse.REMAINDER)
57 | 
58 |     cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
59 | 
60 |     opt = load_opt_from_config_files(cmdline_args.conf_files)
61 | 
62 |     if cmdline_args.config_overrides:
63 |         config_overrides_string = ' '.join(cmdline_args.config_overrides)
64 |         logger.warning(f"Command line config overrides: {config_overrides_string}")
65 |         config_dict = json.loads(config_overrides_string)
66 |         load_config_dict_to_opt(opt, config_dict)
67 | 
68 |     if cmdline_args.overrides:
69 |         assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments must be paired: key value"
70 |         keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]
71 |         vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]
72 |         vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals]  # map 'false'/'False' to '' so a later bool() cast yields False
73 | 
74 |         types = []
75 |         for key in keys:
76 |             key = key.split('.')
77 |             ele = opt.copy()
78 |             while len(key) > 0:
79 |                 ele = ele[key.pop(0)]
80 |             types.append(type(ele))
81 |         
82 |         config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)}
83 |         load_config_dict_to_opt(opt, config_dict)
84 | 
85 |     # combine cmdline_args into opt dictionary
86 |     for key, val in cmdline_args.__dict__.items():
87 |         if val is not None:
88 |             opt[key] = val
89 | 
90 |     return opt, cmdline_args
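
Usage note (illustrative, not part of the repository): load_opt_command also accepts an explicit argv-style list, which makes it easy to exercise the override mechanics outside of a real launch. A sketch with hypothetical override values:

    from utils.arguments import load_opt_command

    opt, args = load_opt_command([
        "evaluate",
        "--conf_files", "configs/seem/focalt_unicl_lang_v1.yaml",
        "--config_overrides", '{"SAVE_DIR": "/tmp/seem_eval", "SOLVER.BASE_LR": 0.00005}',
    ])
    print(opt["SAVE_DIR"], opt["SOLVER"]["BASE_LR"])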


--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
 1 | 
 2 | class Entity(object):
 3 |     def __init__(self, _id, _text, _mask, _interactive, _type, _start_idx, _end_idx, _image=None):
 4 |         self.id = _id
 5 |         self.text = _text
 6 |         self.mask = _mask
 7 |         self.interactive = _interactive
 8 |         self.type = _type
 9 |         self.start_idx = _start_idx
10 |         self.end_idx = _end_idx
11 | 
12 |         self.image = _image
13 | 
14 | def split_by_ordered_substrings(sentence, substrings):
15 |     results = []
16 |     substring_indices = []
17 | 
18 |     start_index = 0
19 |     for i, substring in enumerate(substrings):
20 |         # Find the start of the substring in the remaining part of the sentence
21 |         index = sentence[start_index:].find(substring)
22 | 
23 |         if index == -1:
24 |             continue
25 | 
26 |         # Append any text before the substring to the results, including spaces
27 |         if index > 0:
28 |             results.append(sentence[start_index:start_index+index])
29 |             substring_indices.append(None)  # No match in the `substrings` list for this segment
30 |         
31 |         # Append the substring to the results
32 |         results.append(substring)
33 |         substring_indices.append(i)  # Append the index from the `substrings` list
34 |         start_index += index + len(substring)
35 | 
36 |     # If there's any remaining part of the sentence after all substrings, append it to the results
37 |     if start_index < len(sentence):
38 |         results.append(sentence[start_index:])
39 |         substring_indices.append(None)  # No match in the `substrings` list for this segment
40 | 
41 |     return results, substring_indices
42 | 
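
Usage note (illustrative, not part of the repository): split_by_ordered_substrings keeps the filler text between matches and records which output segment corresponds to which entry of substrings (None marks filler). For example:

    from utils.dataset import split_by_ordered_substrings

    sentence = "a red car next to a tree"
    segments, indices = split_by_ordered_substrings(sentence, ["red car", "tree"])
    # segments -> ['a ', 'red car', ' next to a ', 'tree']
    # indices  -> [None, 0, None, 1]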


--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import time
  3 | import torch
  4 | import pickle
  5 | import subprocess
  6 | 
  7 | from mpi4py import MPI
  8 | import torch.distributed as dist
  9 | 
 10 | 
 11 | def apply_distributed(opt):
 12 |     if opt['rank'] == 0:
 13 |         hostname_cmd = ["hostname -I"]
 14 |         result = subprocess.check_output(hostname_cmd, shell=True)
 15 |         master_address = result.decode('utf-8').split()[0]
 16 |         master_port = opt['PORT']
 17 |     else:
 18 |         master_address = None
 19 |         master_port = None
 20 | 
 21 |     master_address = MPI.COMM_WORLD.bcast(master_address, root=0)
 22 |     master_port = MPI.COMM_WORLD.bcast(master_port, root=0)
 23 | 
 24 |     if torch.distributed.is_available() and opt['world_size'] > 1:
 25 |         init_method_url = 'tcp://{}:{}'.format(master_address, master_port)
 26 |         backend = 'nccl'
 27 |         world_size = opt['world_size']
 28 |         rank = opt['rank']
 29 |         torch.distributed.init_process_group(backend=backend,
 30 |                                              init_method=init_method_url,
 31 |                                              world_size=world_size,
 32 |                                              rank=rank)
 33 | 
 34 | def init_distributed(opt):
 35 |     opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()
 36 |     if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
 37 |         # application was started without MPI
 38 |         # default to single node with single process
 39 |         opt['env_info'] = 'no MPI'
 40 |         opt['world_size'] = 1
 41 |         opt['local_size'] = 1
 42 |         opt['rank'] = 0
 43 |         opt['local_rank'] = 0
 44 |         opt['master_address'] = '127.0.0.1'
 45 |         opt['master_port'] = '8673'
 46 |     else:
 47 |         # application was started with MPI
 48 |         # get MPI parameters
 49 |         opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])
 50 |         opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
 51 |         opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])
 52 |         opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
 53 | 
 54 |     # set up device
 55 |     if not opt['CUDA']:
 56 |         assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend'
 57 |         opt['device'] = torch.device("cpu")
 58 |     else:
 59 |         torch.cuda.set_device(opt['local_rank'])
 60 |         opt['device'] = torch.device("cuda", opt['local_rank'])
 61 | 
 62 |     apply_distributed(opt)
 63 |     return opt
 64 | 
 65 | def is_main_process():
 66 |     rank = 0
 67 |     if 'OMPI_COMM_WORLD_SIZE' in os.environ:
 68 |         rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
 69 | 
 70 |     return rank == 0
 71 | 
 72 | def get_world_size():
 73 |     if not dist.is_available():
 74 |         return 1
 75 |     if not dist.is_initialized():
 76 |         return 1
 77 |     return dist.get_world_size()
 78 | 
 79 | def get_rank():
 80 |     if not dist.is_available():
 81 |         return 0
 82 |     if not dist.is_initialized():
 83 |         return 0
 84 |     return dist.get_rank()
 85 | 
 86 | 
 87 | def synchronize():
 88 |     """
 89 |     Helper function to synchronize (barrier) among all processes when
 90 |     using distributed training
 91 |     """
 92 |     if not dist.is_available():
 93 |         return
 94 |     if not dist.is_initialized():
 95 |         return
 96 |     world_size = dist.get_world_size()
 97 |     rank = dist.get_rank()
 98 |     if world_size == 1:
 99 |         return
100 | 
101 |     def _send_and_wait(r):
102 |         if rank == r:
103 |             tensor = torch.tensor(0, device="cuda")
104 |         else:
105 |             tensor = torch.tensor(1, device="cuda")
106 |         dist.broadcast(tensor, r)
107 |         while tensor.item() == 1:
108 |             time.sleep(1)
109 | 
110 |     _send_and_wait(0)
111 |     # now sync on the main process
112 |     _send_and_wait(1)
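
Usage note (illustrative, not part of the repository): init_distributed fills in the MPI-derived fields on the shared opt dictionary and, for multi-process runs, sets up the NCCL process group; the helper functions below it can then be called from anywhere. A minimal sketch with a hypothetical port value:

    from utils.distributed import init_distributed, get_rank, get_world_size, synchronize

    opt = {"PORT": "29500"}   # rank 0 reads PORT; it becomes the master port for multi-process runs
    opt = init_distributed(opt)
    print(f"rank {get_rank()} / world size {get_world_size()} on {opt['device']}")
    synchronize()             # no-op unless a process group has been initialized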


--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Xueyan Zou (xueyan@cs.wisc.edu)
 6 | # --------------------------------------------------------
 7 | import math
 8 | 
 9 | 
10 | 
11 | class AverageMeter(object):
12 |     """Computes and stores the average and current value."""
13 |     def __init__(self):
14 |         self.reset()
15 | 
16 |     def reset(self):
17 |         self.val = 0
18 |         self.avg = 0
19 |         self.sum = 0
20 |         self.count = 0
21 | 
22 |     def update(self, val, n=1, decay=0):
23 |         self.val = val
24 |         if decay:
25 |             alpha = math.exp(-n / decay)  # exponential moving average with a time constant of ~decay updates
26 |             self.sum = alpha * self.sum + (1 - alpha) * val * n
27 |             self.count = alpha * self.count + (1 - alpha) * n
28 |         else:
29 |             self.sum += val * n
30 |             self.count += n
31 |         self.avg = self.sum / self.count
32 | 
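
Usage note (illustrative, not part of the repository): AverageMeter keeps a weighted running average, and a nonzero decay turns it into an exponential moving average. For example:

    from utils.misc import AverageMeter

    loss_meter = AverageMeter()
    loss_meter.update(1.0)        # avg = 1.0
    loss_meter.update(3.0)        # avg = (1.0 + 3.0) / 2 = 2.0
    loss_meter.update(2.0, n=2)   # avg = (1.0 + 3.0 + 2.0*2) / 4 = 2.0
    print(loss_meter.avg)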


--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import os
 3 | import time
 4 | import pickle
 5 | import torch
 6 | import torch.nn as nn
 7 | 
 8 | from utils.distributed import is_main_process
 9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | NORM_MODULES = [
14 |     torch.nn.BatchNorm1d,
15 |     torch.nn.BatchNorm2d,
16 |     torch.nn.BatchNorm3d,
17 |     torch.nn.SyncBatchNorm,
18 |     # NaiveSyncBatchNorm inherits from BatchNorm2d
19 |     torch.nn.GroupNorm,
20 |     torch.nn.InstanceNorm1d,
21 |     torch.nn.InstanceNorm2d,
22 |     torch.nn.InstanceNorm3d,
23 |     torch.nn.LayerNorm,
24 |     torch.nn.LocalResponseNorm,
25 | ]
26 | 
27 | def register_norm_module(cls):
28 |     NORM_MODULES.append(cls)
29 |     return cls
30 | 
31 | def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
32 |     model_keys = sorted(model_state_dict.keys())
33 |     ckpt_keys = sorted(ckpt_state_dict.keys())
34 |     result_dicts = {}
35 |     matched_log = []
36 |     unmatched_log = []
37 |     unloaded_log = []
38 |     for model_key in model_keys:
39 |         model_weight = model_state_dict[model_key]
40 |         if model_key in ckpt_keys:
41 |             ckpt_weight = ckpt_state_dict[model_key]
42 |             if model_weight.shape == ckpt_weight.shape:
43 |                 result_dicts[model_key] = ckpt_weight
44 |                 ckpt_keys.pop(ckpt_keys.index(model_key))
45 |                 matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
46 |             else:
47 |                 unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
48 |         else:
49 |             unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
50 |             
51 |     if is_main_process():
52 |         for info in matched_log:
53 |             logger.info(info)
54 |         for info in unloaded_log:
55 |             logger.warning(info)
56 |         for key in ckpt_keys:
57 |             logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
58 |         for info in unmatched_log:
59 |             logger.warning(info)
60 |     return result_dicts
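
Usage note (illustrative, not part of the repository): align_and_update_state_dicts returns only the checkpoint tensors whose names and shapes match the model, so it pairs naturally with load_state_dict(strict=False). A self-contained toy example:

    import torch
    import torch.nn as nn
    from utils.model import align_and_update_state_dicts

    model = nn.Linear(4, 2)
    ckpt = {"weight": torch.zeros(2, 4), "bias": torch.zeros(3)}   # bias shape mismatches on purpose
    filtered = align_and_update_state_dicts(model.state_dict(), ckpt)
    model.load_state_dict(filtered, strict=False)   # only "weight" is loaded; "bias" is logged as *UNMATCHED*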


--------------------------------------------------------------------------------
/utils/prompt_engineering.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import random
 3 | 
 4 | def get_prompt_templates():
 5 |     prompt_templates = [
 6 |         '{}.',
 7 |         'a photo of a {}.',
 8 |         'a bad photo of a {}.',
 9 |         'a photo of many {}.',
10 |         'a sculpture of a {}.',
11 |         'a photo of the hard to see {}.',
12 |         'a low resolution photo of the {}.',
13 |         'a rendering of a {}.',
14 |         'graffiti of a {}.',
15 |         'a bad photo of the {}.',
16 |         'a cropped photo of the {}.',
17 |         'a tattoo of a {}.',
18 |         'the embroidered {}.',
19 |         'a photo of a hard to see {}.',
20 |         'a bright photo of a {}.',
21 |         'a photo of a clean {}.',
22 |         'a photo of a dirty {}.',
23 |         'a dark photo of the {}.',
24 |         'a drawing of a {}.',
25 |         'a photo of my {}.',
26 |         'the plastic {}.',
27 |         'a photo of the cool {}.',
28 |         'a close-up photo of a {}.',
29 |         'a black and white photo of the {}.',
30 |         'a painting of the {}.',
31 |         'a painting of a {}.',
32 |         'a pixelated photo of the {}.',
33 |         'a sculpture of the {}.',
34 |         'a bright photo of the {}.',
35 |         'a cropped photo of a {}.',
36 |         'a plastic {}.',
37 |         'a photo of the dirty {}.',
38 |         'a jpeg corrupted photo of a {}.',
39 |         'a blurry photo of the {}.',
40 |         'a photo of the {}.',
41 |         'a good photo of the {}.',
42 |         'a rendering of the {}.',
43 |         'a {} in a video game.',
44 |         'a photo of one {}.',
45 |         'a doodle of a {}.',
46 |         'a close-up photo of the {}.',
47 |         'the origami {}.',
48 |         'the {} in a video game.',
49 |         'a sketch of a {}.',
50 |         'a doodle of the {}.',
51 |         'a origami {}.',
52 |         'a low resolution photo of a {}.',
53 |         'the toy {}.',
54 |         'a rendition of the {}.',
55 |         'a photo of the clean {}.',
56 |         'a photo of a large {}.',
57 |         'a rendition of a {}.',
58 |         'a photo of a nice {}.',
59 |         'a photo of a weird {}.',
60 |         'a blurry photo of a {}.',
61 |         'a cartoon {}.',
62 |         'art of a {}.',
63 |         'a sketch of the {}.',
64 |         'a embroidered {}.',
65 |         'a pixelated photo of a {}.',
66 |         'itap of the {}.',
67 |         'a jpeg corrupted photo of the {}.',
68 |         'a good photo of a {}.',
69 |         'a plushie {}.',
70 |         'a photo of the nice {}.',
71 |         'a photo of the small {}.',
72 |         'a photo of the weird {}.',
73 |         'the cartoon {}.',
74 |         'art of the {}.',
75 |         'a drawing of the {}.',
76 |         'a photo of the large {}.',
77 |         'a black and white photo of a {}.',
78 |         'the plushie {}.',
79 |         'a dark photo of a {}.',
80 |         'itap of a {}.',
81 |         'graffiti of the {}.',
82 |         'a toy {}.',
83 |         'itap of my {}.',
84 |         'a photo of a cool {}.',
85 |         'a photo of a small {}.',
86 |         'a tattoo of the {}.',
87 |     ]
88 |     return prompt_templates
89 | 
90 | def prompt_engineering(classnames, topk=1, suffix='.'):
91 |     prompt_templates = get_prompt_templates()
92 |     temp_idx = np.random.randint(min(len(prompt_templates), topk))
93 | 
94 |     if isinstance(classnames, list):
95 |         classname = random.choice(classnames)
96 |     else:
97 |         classname = classnames
98 | 
99 |     return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' '))
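
Usage note (illustrative, not part of the repository): prompt_engineering samples one of the first topk templates and strips commas and plus signs from the class name, so repeated calls yield varied prompts for the same class. For example:

    from utils.prompt_engineering import prompt_engineering

    print(prompt_engineering('zebra'))                    # topk=1 always uses '{}.' -> 'zebra.'
    print(prompt_engineering('traffic+light', topk=10))   # e.g. 'a photo of a traffic light.'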


--------------------------------------------------------------------------------