├── .gitignore ├── CODE_OF_CONDUCT.md ├── INSTALL.md ├── LICENSE ├── README.md ├── configs ├── ade20k-full-847-freq │ ├── Base-ADE20KFull-847.yaml │ ├── zegformer_R50_bs32_60k_vit16_ade20k-full-847-freq.yaml │ └── zegformer_R50_bs32_60k_vit16_ade20k-full-847-freq_gzss_eval.yaml ├── coco-stuff │ ├── Base-COCOStuff-171.yaml │ ├── zegformer_R101_bs32_60k_vit16_coco-stuff.yaml │ ├── zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval.yaml │ ├── zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval_847_classes.yaml │ ├── zegformer_R50_bs32_60k_vit16_coco-stuff.yaml │ └── zegformer_R50_bs32_60k_vit16_coco-stuff_gzss_eval.yaml └── pascal_voc │ ├── Base-VOC-20.yaml │ ├── zegformer_R101_bs32_10k_vit16_voc.yaml │ ├── zegformer_R101_bs32_10k_vit16_voc_gzss_eval.yaml │ ├── zegformer_R101_bs32_30k_vit16_voc.yaml │ ├── zegformer_R101_bs32_30k_vit16_voc_gzss_eval.yaml │ ├── zegformer_R50_bs32_10k_vit16_voc.yaml │ └── zegformer_R50_bs32_10k_vit16_voc_gzss_eval.yaml ├── datasets ├── README.md ├── ade20k-full-frequency-split │ ├── create_ade-frequency_json.py │ ├── prepare_ade20k_full_frequency_all_val.py │ ├── prepare_ade20k_full_frequency_seen.py │ └── prepare_ade20k_full_frequency_unseen_val.py ├── coco-stuff │ ├── create_cocostuff_class_names_json.py │ ├── create_cocostuff_seen_wordvecindexes_json.py │ ├── prepare_coco_stuff_sem_seg_seen.py │ ├── prepare_coco_stuff_sem_seg_unseen.py │ └── prepare_coco_stuff_sem_seg_val_all.py ├── coco │ └── coco_stuff │ │ ├── split │ │ ├── novel_cls.npy │ │ ├── seen_cls.npy │ │ └── val_cls.npy │ │ └── word_vectors │ │ ├── fasttext.pkl │ │ ├── glove.pkl │ │ └── word2vec.pkl ├── pascal │ ├── create_voc_class_names_json.py │ ├── prepare_pascal_voc_seen.py │ ├── prepare_pascal_voc_unseen_val.py │ └── prepare_pascal_voc_val_all.py └── voc12 │ └── split │ ├── novel_cls.npy │ ├── seen_cls.npy │ ├── test_list.npy │ └── train_list.npy ├── demo ├── demo.py ├── demo_visual_gt.py ├── demo_visual_gt_adefull.py ├── predictor.py └── visualizer.py ├── figures ├── adeinferenceCOCO.png └── fig1.png ├── mask_former ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ ├── dataset_mappers │ │ ├── __init__.py │ │ ├── detr_panoptic_dataset_mapper.py │ │ ├── mask_former_panoptic_dataset_mapper.py │ │ └── mask_former_semantic_dataset_mapper.py │ └── datasets │ │ ├── __init__.py │ │ ├── register_ade20k_full_zero_freq.py │ │ ├── register_coco_stuff.py │ │ └── register_pascal_voc.py ├── evaluation │ ├── __init__.py │ └── sem_seg_evaluation_gzero.py ├── mask_former_model.py ├── modeling │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ └── swin.py │ ├── criterion.py │ ├── heads │ │ ├── __init__.py │ │ ├── mask_former_head.py │ │ ├── per_pixel_baseline.py │ │ ├── pixel_decoder.py │ │ ├── zeg_former_head.py │ │ └── zeroshot_per_pixel_baseline.py │ ├── matcher.py │ └── transformer │ │ ├── __init__.py │ │ ├── position_encoding.py │ │ ├── transformer.py │ │ ├── transformer_predictor.py │ │ └── transformer_zeroshot_predictor.py ├── semantic_seg_zero.py ├── test_time_augmentation.py ├── third_party │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── imagenet_templates.py │ ├── model.py │ └── simple_tokenizer.py └── utils │ ├── __init__.py │ └── misc.py ├── plain_train_net.py ├── requirements.txt ├── tools ├── README.md ├── convert-pretrained-swin-model-to-d2.py ├── convert-torchvision-to-d2.py └── sem_seg_json2mat.py └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | # output dir 3 | output 4 | instant_test_output 5 | 
inference_test_output 6 | 7 | 8 | *.png 9 | *.json 10 | *.diff 11 | *.jpg 12 | !/projects/DensePose/doc/images/*.jpg 13 | 14 | # compilation and distribution 15 | __pycache__ 16 | _ext 17 | *.pyc 18 | *.pyd 19 | *.so 20 | *.dll 21 | *.egg-info/ 22 | build/ 23 | dist/ 24 | wheels/ 25 | 26 | # pytorch/python/numpy formats 27 | *.pth 28 | #*.pkl 29 | #*.npy 30 | *.ts 31 | model_ts*.txt 32 | 33 | # ipython/jupyter notebooks 34 | *.ipynb 35 | **/.ipynb_checkpoints/ 36 | 37 | # Editor temporaries 38 | *.swn 39 | *.swo 40 | *.swp 41 | *~ 42 | 43 | # editor settings 44 | .idea 45 | .vscode 46 | _darcs 47 | .zip 48 | # project dirs 49 | /detectron2/model_zoo/configs 50 | #/datasets/* 51 | #!/datasets/*.* 52 | #/projects/*/datasets 53 | /models 54 | /snippet 55 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Requirements 4 | - Linux or macOS with Python ≥ 3.6 5 | - PyTorch ≥ 1.7 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. 6 | Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note: please check that the 7 | PyTorch version matches the one required by Detectron2. 8 | - Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). 9 | - OpenCV is optional, but it is needed by the demo and visualization scripts. 10 | - `pip install -r requirements.txt` 11 | 12 | An example installation is shown below: 13 | 14 | ``` 15 | conda create -n zegformer python==3.7 16 | conda activate zegformer 17 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 18 | python -m pip install detectron2 -f \ 19 | https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html 20 | 21 | git clone https://github.com/dingjiansw101/ZegFormer.git 22 | cd ZegFormer 23 | pip install -r requirements.txt 24 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Decoupling Zero-Shot Semantic Segmentation 3 | This is the official code for [ZegFormer](https://arxiv.org/abs/2112.07910) (CVPR 2022). 4 | 5 | ZegFormer is the first framework that decouples zero-shot semantic segmentation into two sub-tasks: 1) class-agnostic segmentation and 2) segment-level zero-shot classification. 6 | 7 | [comment]: <> (![fig1](figures/fig1.png)) 8 | ### Visualization of semantic segmentation with open vocabularies 9 | ZegFormer is able to segment stuff and things with open vocabularies. The predicted classes can be more fine-grained 10 | than the COCO-Stuff annotations (see colored boxes below).
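The decoupled design can be summarized in a few lines of pseudo-code. The sketch below is only a conceptual illustration and is **not** the actual inference code of this repo: class-agnostic mask proposals are classified at the segment level by comparing segment embeddings with text embeddings of arbitrary class names (e.g., from a CLIP text encoder), and the per-segment class scores are then fused with the masks to obtain a semantic map. All tensors are random placeholders standing in for the backbone, transformer decoder, and text encoder.

```
import torch
import torch.nn.functional as F

def decoupled_zero_shot_inference(mask_logits, segment_embeds, text_embeds):
    """Toy sketch of the two-stage idea, not the repo's implementation.

    mask_logits:    (N, H, W) class-agnostic mask proposals
    segment_embeds: (N, D)    one embedding per proposed segment
    text_embeds:    (C, D)    one embedding per (possibly unseen) class name
    Returns an (H, W) map of per-pixel class indices.
    """
    # Stage 2: segment-level zero-shot classification via cosine similarity.
    segment_embeds = F.normalize(segment_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    class_probs = (segment_embeds @ text_embeds.t()).softmax(dim=-1)  # (N, C)

    # Fuse the segment-level class scores with the class-agnostic masks,
    # then take the best class per pixel.
    mask_probs = mask_logits.sigmoid()                                # (N, H, W)
    sem_seg = torch.einsum("nc,nhw->chw", class_probs, mask_probs)    # (C, H, W)
    return sem_seg.argmax(dim=0)                                      # (H, W)

# Random placeholders: 100 proposals, 256-dim embeddings, 171 class names.
masks, seg_emb, txt_emb = torch.randn(100, 64, 64), torch.randn(100, 256), torch.randn(171, 256)
print(decoupled_zero_shot_inference(masks, seg_emb, txt_emb).shape)  # torch.Size([64, 64])
```

Because the text embeddings are only consumed at classification time, the class vocabulary can be swapped at inference without retraining the segmentation part.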
11 | 12 | [comment]: <> (The unannotated vocabularies in COCO-Stuff can also be segmented by ZegFormer.)) 13 | ![visualization](figures/adeinferenceCOCO.png) 14 | 15 | [comment]: <> (### Benchmark Results) 16 | 17 | ### Data Preparation 18 | See [data preparation](datasets/README.md). 19 | 20 | ### Config files 21 | For each model, there are two kinds of config files. The file without the suffix "_gzss_eval" is used for training. The file with the suffix "_gzss_eval" 22 | is used for generalized zero-shot semantic segmentation evaluation. 23 | 24 | ### Inference Demo with Pre-trained Models 25 | Download the checkpoints of ZegFormer from https://drive.google.com/drive/u/0/folders/1qcIe2mE1VRU1apihsao4XvANJgU5lYgm 26 | ``` 27 | python demo/demo.py --config-file configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval.yaml \ 28 | --input input1.jpg input2.jpg \ 29 | [--other-options] 30 | --opts MODEL.WEIGHTS /path/to/zegformer_R101_bs32_60k_vit16_coco-stuff.pth 31 | ``` 32 | The configs are made for training, so we need to set `MODEL.WEIGHTS` to a model from the model zoo for evaluation. 33 | This command will run inference and show the visualizations in an OpenCV window. 34 | 35 | For details of the command-line arguments, see `demo.py -h` or look at its source code 36 | to understand its behavior. Some common arguments are: 37 | * To run __on your webcam__, replace `--input files` with `--webcam`. 38 | * To run __on a video__, replace `--input files` with `--video-input video.mp4`. 39 | * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`. 40 | * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`. 41 | 42 | #### Inference with more classnames 43 | 44 | In the example above, the model is trained with __156 classes__, and inference is performed with __171 classes__. 45 | 46 | If you want to run inference with more classes, try the config `zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval_847_classes.yaml`. 47 | 48 | [comment]: <> (You can also generate your customized json __TEST_CLASS_JSON with arbitrary class names__ by yourself.) 49 | 50 | 51 | ### Training & Evaluation in Command Line 52 | To train models with the R-101 backbone, download the pre-trained model 53 | [R-101.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-101.pkl), which is a converted copy of [MSRA's original ResNet-101](https://github.com/KaimingHe/deep-residual-networks) model. 54 | 55 | 56 | We provide a script, `train_net.py`, that is made to train all the configs provided in this repo. 57 | 58 | To train a model with "train_net.py", first 59 | set up the corresponding datasets following 60 | [datasets/README.md](./datasets/README.md), 61 | then run: 62 | ``` 63 | ./train_net.py --num-gpus 8 \ 64 | --config-file configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff.yaml 65 | ``` 66 | 67 | The configs are made for 8-GPU training. 68 | Since we use the ADAMW optimizer, it is not clear how to scale the learning rate with the batch size.
69 | To train on 1 GPU, you need to figure out a suitable learning rate and batch size yourself: 70 | ``` 71 | ./train_net.py \ 72 | --config-file configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff.yaml \ 73 | --num-gpus 1 SOLVER.IMS_PER_BATCH SET_TO_SOME_REASONABLE_VALUE SOLVER.BASE_LR SET_TO_SOME_REASONABLE_VALUE 74 | ``` 75 | 76 | To evaluate a model's performance, use 77 | ``` 78 | ./train_net.py \ 79 | --config-file configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval.yaml \ 80 | --eval-only MODEL.WEIGHTS /path/to/checkpoint_file 81 | ``` 82 | For more options, see `./train_net.py -h`. 83 | 84 | The pre-trained checkpoints of ZegFormer can be downloaded from https://drive.google.com/drive/folders/1qcIe2mE1VRU1apihsao4XvANJgU5lYgm?usp=sharing 85 | 86 | ## Disclaimer 87 | Although the reported results on PASCAL VOC are obtained with models trained for 10k iterations, the results at 10k are not stable. We recommend training models for more iterations. 88 | ## Acknowledgment 89 | This repo benefits from [CLIP](https://github.com/openai/CLIP) and [MaskFormer](https://github.com/facebookresearch/MaskFormer). Thanks for their wonderful work. 90 | 91 | ## Citation 92 | ``` 93 | @inproceedings{ding2021decoupling, 94 | title={Decoupling Zero-Shot Semantic Segmentation}, 95 | author={Ding, Jian and Xue, Nan and Xia, Gui-Song and Dai, Dengxin}, 96 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 97 | year={2022} 98 | } 99 | ``` 100 | 101 | If you have any problems using this code, please contact me (jian.ding@whu.edu.cn). 102 | -------------------------------------------------------------------------------- /configs/ade20k-full-847-freq/Base-ADE20KFull-847.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("ade20k_full_sem_seg_freq_train",) 18 | TEST: ("ade20k_full_sem_seg_freq_val",) 19 | SOLVER: 20 | IMS_PER_BATCH: 16 21 | BASE_LR: 0.0001 22 | MAX_ITER: 200000 23 | WARMUP_FACTOR: 1.0 24 | WARMUP_ITERS: 0 25 | WEIGHT_DECAY: 0.0001 26 | OPTIMIZER: "ADAMW" 27 | LR_SCHEDULER_NAME: "WarmupPolyLR" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"] 36 | MIN_SIZE_TRAIN_SAMPLING: "choice" 37 | MIN_SIZE_TEST: 512 38 | MAX_SIZE_TRAIN: 2048 39 | MAX_SIZE_TEST: 2048 40 | CROP: 41 | ENABLED: True 42 | TYPE: "absolute" 43 | SIZE: (512, 512) 44 | SINGLE_CATEGORY_MAX_AREA: 1.0 45 | COLOR_AUG_SSD: True 46 | SIZE_DIVISIBILITY: 512 # used in dataset mapper 47 | FORMAT: "RGB" 48 | DATASET_MAPPER_NAME: "mask_former_semantic" 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | DATALOADER: 52 | FILTER_EMPTY_ANNOTATIONS: True 53 | NUM_WORKERS: 4 54 | VERSION: 2 55 | -------------------------------------------------------------------------------- /configs/ade20k-full-847-freq/zegformer_R50_bs32_60k_vit16_ade20k-full-847-freq.yaml:
-------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20KFull-847.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "ZegFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 65535 8 | NUM_CLASSES: 572 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | TRAIN_CLASS_JSON: "datasets/ADE20K_2021_17_01/ADE20K_572_pure_class.json" 15 | TEST_CLASS_JSON: "datasets/ADE20K_2021_17_01/ADE20K_572_pure_class.json" 16 | CLIP_PRETRAINED: "ViT-B/16" 17 | MASK_FORMER: 18 | TRANSFORMER_IN_FEATURE: "res5" 19 | DEEP_SUPERVISION: True 20 | NO_OBJECT_WEIGHT: 0.1 21 | DICE_WEIGHT: 1.0 22 | MASK_WEIGHT: 20.0 23 | HIDDEN_DIM: 256 24 | NUM_OBJECT_QUERIES: 100 25 | NHEADS: 8 26 | DROPOUT: 0.1 27 | DIM_FEEDFORWARD: 2048 28 | ENC_LAYERS: 0 29 | DEC_LAYERS: 6 30 | PRE_NORM: False 31 | PROMPT_ENSEMBLE_TYPE: "imagenet_select" 32 | SOLVER: 33 | AMP: 34 | ENABLED: False 35 | IMS_PER_BATCH: 32 36 | # IMS_PER_BATCH: 8 37 | MAX_ITER: 60000 38 | DATALOADER: 39 | NUM_WORKERS: 8 40 | CUDNN_BENCHMARK: True 41 | -------------------------------------------------------------------------------- /configs/ade20k-full-847-freq/zegformer_R50_bs32_60k_vit16_ade20k-full-847-freq_gzss_eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20KFull-847.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "ZegFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 65535 8 | NUM_CLASSES: 847 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | TRAIN_CLASS_JSON: "datasets/ADE20K_2021_17_01/ADE20K_572_pure_class.json" 15 | TEST_CLASS_JSON: "datasets/ADE20K_2021_17_01/ADE20K_847_pure_class.json" 16 | CLIP_PRETRAINED: "ViT-B/16" 17 | CLIP_CLASSIFICATION: True 18 | MASK_FORMER: 19 | TRANSFORMER_IN_FEATURE: "res5" 20 | DEEP_SUPERVISION: True 21 | NO_OBJECT_WEIGHT: 0.1 22 | DICE_WEIGHT: 1.0 23 | MASK_WEIGHT: 20.0 24 | HIDDEN_DIM: 256 25 | NUM_OBJECT_QUERIES: 100 26 | NHEADS: 8 27 | DROPOUT: 0.1 28 | DIM_FEEDFORWARD: 2048 29 | ENC_LAYERS: 0 30 | DEC_LAYERS: 6 31 | PRE_NORM: False 32 | GZERO_CALIBRATE: 0.0 33 | # ENSEMBLING_ALL_CLS: True 34 | ENSEMBLING: True 35 | PROMPT_ENSEMBLE_TYPE: "imagenet_select" 36 | # PROMPT_ENSEMBLE_TYPE: "single" 37 | SOLVER: 38 | AMP: 39 | ENABLED: True 40 | IMS_PER_BATCH: 32 41 | # IMS_PER_BATCH: 8 42 | MAX_ITER: 60000 43 | DATASETS: 44 | TRAIN: ("ade20k_full_sem_seg_freq_train",) 45 | TEST: ("ade20k_full_sem_seg_freq_val_all",) 46 | DATALOADER: 47 | NUM_WORKERS: 8 48 | CUDNN_BENCHMARK: True 49 | -------------------------------------------------------------------------------- /configs/coco-stuff/Base-COCOStuff-171.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("coco_2017_train_stuff_seen_sem_seg",) 18 | TEST: ("coco_2017_test_stuff_seen_sem_seg",) 19 | 
SOLVER: 20 | IMS_PER_BATCH: 32 21 | BASE_LR: 0.0001 22 | MAX_ITER: 60000 23 | # MAX_ITER: 100000 24 | WARMUP_FACTOR: 1.0 25 | WARMUP_ITERS: 0 26 | WEIGHT_DECAY: 0.0001 27 | OPTIMIZER: "ADAMW" 28 | LR_SCHEDULER_NAME: "WarmupPolyLR" 29 | BACKBONE_MULTIPLIER: 0.1 30 | CLIP_GRADIENTS: 31 | ENABLED: True 32 | CLIP_TYPE: "full_model" 33 | CLIP_VALUE: 0.01 34 | NORM_TYPE: 2.0 35 | CHECKPOINT_PERIOD: 1000 # TODO: comment this 36 | INPUT: 37 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 16)]"] 38 | MIN_SIZE_TRAIN_SAMPLING: "choice" 39 | MIN_SIZE_TEST: 640 40 | MAX_SIZE_TRAIN: 2560 41 | MAX_SIZE_TEST: 2560 42 | CROP: 43 | ENABLED: True 44 | TYPE: "absolute" 45 | SIZE: (640, 640) 46 | SINGLE_CATEGORY_MAX_AREA: 1.0 47 | COLOR_AUG_SSD: True 48 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 49 | FORMAT: "RGB" 50 | DATASET_MAPPER_NAME: "mask_former_semantic" 51 | TEST: 52 | EVAL_PERIOD: 5000 53 | AUG: 54 | ENABLED: False 55 | MIN_SIZES: [320, 480, 640, 800, 960, 1120] 56 | MAX_SIZE: 4480 57 | FLIP: True 58 | DATALOADER: 59 | FILTER_EMPTY_ANNOTATIONS: True 60 | NUM_WORKERS: 8 61 | VERSION: 2 62 | CUDNN_BENCHMARK: True 63 | -------------------------------------------------------------------------------- /configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_60k_vit16_coco-stuff.yaml 2 | MODEL: 3 | # BACKBONE: 4 | # NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "R-101.pkl" 6 | META_ARCHITECTURE: "MaskFormer" 7 | RESNETS: 8 | DEPTH: 101 9 | STEM_TYPE: "basic" # not used 10 | STEM_OUT_CHANNELS: 64 11 | STRIDE_IN_1X1: False 12 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 13 | # NORM: "SyncBN" 14 | RES5_MULTI_GRID: [1, 1, 1] # not used 15 | -------------------------------------------------------------------------------- /configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R101_bs32_60k_vit16_coco-stuff.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NUM_CLASSES: 171 6 | CLIP_CLASSIFICATION: True 7 | TRAIN_CLASS_JSON: "datasets/coco/coco_stuff/split/seen_classnames.json" 8 | TEST_CLASS_JSON: "datasets/coco/coco_stuff/split/all_classnames.json" 9 | MASK_FORMER: 10 | GZERO_CALIBRATE: 0.1 11 | # ENSEMBLING: False 12 | ENSEMBLING: True 13 | PROMPT_ENSEMBLE_TYPE: "single" 14 | DATASETS: 15 | TEST: ("coco_2017_val_all_stuff_sem_seg",) -------------------------------------------------------------------------------- /configs/coco-stuff/zegformer_R101_bs32_60k_vit16_coco-stuff_gzss_eval_847_classes.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R101_bs32_60k_vit16_coco-stuff.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | IGNORE_VALUE: 65535 6 | NUM_CLASSES: 847 7 | CLIP_CLASSIFICATION: True 8 | TRAIN_CLASS_JSON: "datasets/coco/coco_stuff/split/seen_classnames.json" 9 | TEST_CLASS_JSON: "datasets/ADE20K_2021_17_01/ADE20K_847_pure_class.json" 10 | MASK_FORMER: 11 | GZERO_CALIBRATE: 0.0 12 | # ENSEMBLING: False 13 | # ENSEMBLING: True 14 | PROMPT_ENSEMBLE_TYPE: "imagenet_select" 15 | DATASETS: 16 | TEST: ("ade20k_full_sem_seg_freq_val_all",) -------------------------------------------------------------------------------- /configs/coco-stuff/zegformer_R50_bs32_60k_vit16_coco-stuff.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: Base-COCOStuff-171.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "ZegFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 156 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | TRAIN_CLASS_JSON: "datasets/coco/coco_stuff/split/seen_classnames.json" 15 | TEST_CLASS_JSON: "datasets/coco/coco_stuff/split/seen_classnames.json" 16 | CLIP_PRETRAINED: "ViT-B/16" 17 | # CLIP_CLASSIFICATION: True 18 | PROMPT_ENSEMBLE_TYPE: "imagenet_select" 19 | MASK_FORMER: 20 | TRANSFORMER_IN_FEATURE: "res5" 21 | DEEP_SUPERVISION: True 22 | NO_OBJECT_WEIGHT: 0.1 23 | DICE_WEIGHT: 1.0 24 | MASK_WEIGHT: 20.0 25 | HIDDEN_DIM: 256 26 | NUM_OBJECT_QUERIES: 100 27 | NHEADS: 8 28 | DROPOUT: 0.1 29 | DIM_FEEDFORWARD: 2048 30 | ENC_LAYERS: 0 31 | DEC_LAYERS: 6 32 | PRE_NORM: False 33 | SOLVER: 34 | IMS_PER_BATCH: 32 35 | # BASE_LR: 0.0001 36 | MAX_ITER: 60000 37 | AMP: 38 | ENABLED: True 39 | DATALOADER: 40 | FILTER_EMPTY_ANNOTATIONS: True 41 | # NUM_WORKERS: 4 42 | NUM_WORKERS: 8 43 | CUDNN_BENCHMARK: True 44 | -------------------------------------------------------------------------------- /configs/coco-stuff/zegformer_R50_bs32_60k_vit16_coco-stuff_gzss_eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_60k_vit16_coco-stuff.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NUM_CLASSES: 171 6 | CLIP_PRETRAINED: "ViT-B/16" 7 | CLIP_CLASSIFICATION: True 8 | TRAIN_CLASS_JSON: "datasets/coco/coco_stuff/split/seen_classnames.json" 9 | TEST_CLASS_JSON: "datasets/coco/coco_stuff/split/all_classnames.json" 10 | MASK_FORMER: 11 | GZERO_CALIBRATE: 0.1 12 | ENSEMBLING: True 13 | PROMPT_ENSEMBLE_TYPE: "imagenet_select" 14 | DATASETS: 15 | TEST: ("coco_2017_val_all_stuff_sem_seg",) -------------------------------------------------------------------------------- /configs/pascal_voc/Base-VOC-20.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("pascal_voc_train_seen_sem_seg",) 18 | TEST: ("pascal_voc_test_seen_sem_seg",) 19 | SOLVER: 20 | IMS_PER_BATCH: 32 21 | BASE_LR: 0.0001 22 | MAX_ITER: 10000 23 | # MAX_ITER: 100000 24 | WARMUP_FACTOR: 1.0 25 | WARMUP_ITERS: 0 26 | WEIGHT_DECAY: 0.0001 27 | OPTIMIZER: "ADAMW" 28 | LR_SCHEDULER_NAME: "WarmupPolyLR" 29 | BACKBONE_MULTIPLIER: 0.1 30 | CLIP_GRADIENTS: 31 | ENABLED: True 32 | CLIP_TYPE: "full_model" 33 | CLIP_VALUE: 0.01 34 | NORM_TYPE: 2.0 35 | INPUT: 36 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"] 37 | MIN_SIZE_TRAIN_SAMPLING: "choice" 38 | MIN_SIZE_TEST: 512 39 | MAX_SIZE_TRAIN: 2048 40 | MAX_SIZE_TEST: 2048 41 | CROP: 42 | ENABLED: True 43 | TYPE: "absolute" 44 | SIZE: (512, 512) 45 | SINGLE_CATEGORY_MAX_AREA: 1.0 46 | COLOR_AUG_SSD: True 47 | SIZE_DIVISIBILITY: 512 # used in 
dataset mapper 48 | FORMAT: "RGB" 49 | DATASET_MAPPER_NAME: "mask_former_semantic" 50 | TEST: 51 | EVAL_PERIOD: 10000 52 | AUG: 53 | ENABLED: False 54 | MIN_SIZES: [256, 384, 512, 640, 768, 896] 55 | MAX_SIZE: 3584 56 | FLIP: True 57 | DATALOADER: 58 | FILTER_EMPTY_ANNOTATIONS: True 59 | NUM_WORKERS: 4 60 | VERSION: 2 61 | CUDNN_BENCHMARK: True 62 | -------------------------------------------------------------------------------- /configs/pascal_voc/zegformer_R101_bs32_10k_vit16_voc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_10k_vit16_voc.yaml 2 | MODEL: 3 | # BACKBONE: 4 | # NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "R-101.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "basic" # not used 9 | STEM_OUT_CHANNELS: 64 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 1, 1] # not used 14 | 15 | #SEED: 28057409 16 | 17 | DATALOADER: 18 | FILTER_EMPTY_ANNOTATIONS: True 19 | NUM_WORKERS: 1 -------------------------------------------------------------------------------- /configs/pascal_voc/zegformer_R101_bs32_10k_vit16_voc_gzss_eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_10k_vit16_voc.yaml 2 | MODEL: 3 | # BACKBONE: 4 | # NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "R-101.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "basic" # not used 9 | STEM_OUT_CHANNELS: 64 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 1, 1] # not used 14 | SEM_SEG_HEAD: 15 | NUM_CLASSES: 20 16 | CLIP_CLASSIFICATION: True 17 | TRAIN_CLASS_JSON: "datasets/VOCZERO/seen_classnames.json" 18 | TEST_CLASS_JSON: "datasets/VOCZERO/all_classnames.json" 19 | MASK_FORMER: 20 | GZERO_CALIBRATE: 0.1 21 | # GZERO_CALIBRATE: 0.0 22 | ENSEMBLING: True 23 | DATASETS: 24 | # TRAIN: ("ade20k_full_sem_seg_train",) 25 | TEST: ("pascal_voc_val_all_sem_seg",) -------------------------------------------------------------------------------- /configs/pascal_voc/zegformer_R101_bs32_30k_vit16_voc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_10k_vit16_voc.yaml 2 | MODEL: 3 | # BACKBONE: 4 | # NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "R-101.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "basic" # not used 9 | STEM_OUT_CHANNELS: 64 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 1, 1] # not used 14 | DATALOADER: 15 | FILTER_EMPTY_ANNOTATIONS: True 16 | NUM_WORKERS: 4 17 | SOLVER: 18 | IMS_PER_BATCH: 32 19 | BASE_LR: 0.0001 20 | MAX_ITER: 30000 21 | TEST: 22 | EVAL_PERIOD: 30000 -------------------------------------------------------------------------------- /configs/pascal_voc/zegformer_R101_bs32_30k_vit16_voc_gzss_eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_10k_vit16_voc.yaml 2 | MODEL: 3 | # BACKBONE: 4 | # NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "R-101.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "basic" # not used 9 | STEM_OUT_CHANNELS: 64 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 1, 1] # not used 14 | SEM_SEG_HEAD: 15 | NUM_CLASSES: 20 16 | CLIP_CLASSIFICATION: True 17 | TRAIN_CLASS_JSON: 
"datasets/VOCZERO/seen_classnames.json" 18 | TEST_CLASS_JSON: "datasets/VOCZERO/all_classnames.json" 19 | MASK_FORMER: 20 | GZERO_CALIBRATE: 0.1 21 | # GZERO_CALIBRATE: 0.0 22 | ENSEMBLING: True 23 | 24 | DATASETS: 25 | # TRAIN: ("ade20k_full_sem_seg_train",) 26 | TEST: ("pascal_voc_val_all_sem_seg",) -------------------------------------------------------------------------------- /configs/pascal_voc/zegformer_R50_bs32_10k_vit16_voc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-VOC-20.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "ZegFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 15 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | TRAIN_CLASS_JSON: "datasets/VOCZERO/seen_classnames.json" 15 | TEST_CLASS_JSON: "datasets/VOCZERO/seen_classnames.json" 16 | CLIP_PRETRAINED: "ViT-B/16" 17 | MASK_FORMER: 18 | TRANSFORMER_IN_FEATURE: "res5" 19 | DEEP_SUPERVISION: True 20 | NO_OBJECT_WEIGHT: 0.1 21 | DICE_WEIGHT: 1.0 22 | MASK_WEIGHT: 20.0 23 | HIDDEN_DIM: 256 24 | NUM_OBJECT_QUERIES: 100 25 | NHEADS: 8 26 | DROPOUT: 0.1 27 | DIM_FEEDFORWARD: 2048 28 | ENC_LAYERS: 0 29 | DEC_LAYERS: 6 30 | PRE_NORM: False 31 | PROMPT_ENSEMBLE_TYPE: "single" 32 | #SOLVER: 33 | # AMP: 34 | # ENABLED: True 35 | #DATALOADER: 36 | # FILTER_EMPTY_ANNOTATIONS: True 37 | # NUM_WORKERS: 4 38 | #CUDNN_BENCHMARK: True -------------------------------------------------------------------------------- /configs/pascal_voc/zegformer_R50_bs32_10k_vit16_voc_gzss_eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: zegformer_R50_bs32_10k_vit16_voc.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NUM_CLASSES: 20 6 | CLIP_CLASSIFICATION: True 7 | TRAIN_CLASS_JSON: "datasets/VOCZERO/seen_classnames.json" 8 | TEST_CLASS_JSON: "datasets/VOCZERO/all_classnames.json" 9 | MASK_FORMER: 10 | # GZERO_CALIBRATE: 0.7 11 | # GZERO_CALIBRATE: 0.1 12 | # ENSEMBLING: True 13 | ENSEMBLING: False 14 | DATASETS: 15 | # TRAIN: ("ade20k_full_sem_seg_train",) 16 | TEST: ("pascal_voc_val_all_sem_seg",) 17 | DATALOADER: 18 | # FILTER_EMPTY_ANNOTATIONS: True 19 | NUM_WORKERS: 4 20 | CUDNN_BENCHMARK: True -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Prepare Datasets for ZegFormer 2 | 3 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) 4 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc). 5 | This document explains how to setup the builtin datasets so they can be used by the above APIs. 6 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, 7 | and how to add new datasets to them. 8 | 9 | ZegFormer has builtin support for a few datasets. 10 | The datasets are assumed to exist in a directory specified by the environment variable 11 | `DETECTRON2_DATASETS`. 12 | Under this directory, detectron2 will look for datasets in the structure described below, if needed. 
13 | ``` 14 | $DETECTRON2_DATASETS/ 15 | coco/ 16 | ADE20K_2021_17_01/ 17 | ``` 18 | 19 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. 20 | If left unset, the default is `./datasets` relative to your current working directory. 21 | 22 | ## Prepare data for [COCO-Stuff](https://github.com/nightrome/cocostuff): 23 | 24 | ### Expected data structure 25 | 26 | ``` 27 | coco/ 28 | coco_stuff/ 29 | annotations/ 30 | train2017/ 31 | 000000144874.png 32 | ... 33 | val2017/ 34 | 000000213035.png 35 | ... 36 | images/ 37 | train2017/ 38 | 000000189148.jpg 39 | ... 40 | val2017/ 41 | 000000213547.jpg 42 | ... 43 | word_vectors/ 44 | fasttext.pkl 45 | glove.pkl 46 | word2vec.pkl 47 | # below are generated by prepare_coco_stuff_sem_seg.py 48 | split/ 49 | seen_cls.npy 50 | val_cls.npy 51 | novel_cls.npy 52 | seen_classnames.json 53 | unseen_classnames.json 54 | all_classnames.json 55 | ... 56 | annotations_detectron2/ 57 | train2017/ 58 | val2017_unseen/ 59 | ``` 60 | Get the COCO (2017) images from https://cocodataset.org/ 61 | 62 | ```bash 63 | wget http://images.cocodataset.org/zips/train2017.zip 64 | wget http://images.cocodataset.org/zips/val2017.zip 65 | ``` 66 | 67 | Get the COCO-Stuff annotations from https://github.com/nightrome/cocostuff. 68 | ```bash 69 | wget http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip 70 | ``` 71 | Unzip `train2017.zip`, `val2017.zip`, and `stuffthingmaps_trainval2017.zip`, then put them in the correct locations listed above. 72 | 73 | [comment]: <> (Download the word vectors fasttext.pkl, glove.pkl, and word2vec.pkl from https://github.com/subhc/SPNet/tree/master/data/datasets/cocostuff/word_vectors (optional, for implement SPNet only.)) 74 | 75 | [comment]: <> (Download seen_cls.npy, val_cls.npy, novel_cls.npy from https://github.com/subhc/SPNet/tree/master/data/datasets/cocostuff/split) 76 | 77 | Split the classes into seen and unseen for training and testing. 78 | 79 | ``` 80 | python datasets/coco-stuff/create_cocostuff_class_names_json.py 81 | ``` 82 | Generate the labels for training and testing.
83 | 84 | ``` 85 | python datasets/coco-stuff/prepare_coco_stuff_sem_seg_seen.py 86 | python datasets/coco-stuff/prepare_coco_stuff_sem_seg_unseen.py 87 | python datasets/coco-stuff/prepare_coco_stuff_sem_seg_val_all.py 88 | ``` 89 | 90 | 91 | 92 | ## Prepare data for [ADE20k-Full](https://groups.csail.mit.edu/vision/datasets/ADE20K/): 93 | Download the data of ADE20k-Full from https://groups.csail.mit.edu/vision/datasets/ADE20K/request_data/ 94 | 95 | ### Expected data structure 96 | ``` 97 | ADE20K_2021_17_01/ 98 | images/ 99 | images_detectron2_freq/ 100 | annotations_detectron2_freq/ 101 | index_ade20k.pkl 102 | index_ade20k.mat 103 | objects.txt 104 | ADE20K_275_pure_class.json 105 | ADE20K_572_pure_class.json 106 | ADE20K_847_pure_class.json 107 | ``` 108 | The `ADE20K_275_pure_class.json`, `ADE20K_572_pure_class.json`, `ADE20K_847_pure_class.json`, `images_detectron2_freq`, and `annotations_detectron2_freq` are generated by the following scripts: 109 | 110 | ``` 111 | python datasets/ade20k-full-frequency-split/create_ade-frequency_json.py 112 | python datasets/ade20k-full-frequency-split/prepare_ade20k_full_frequency_all_val.py 113 | python datasets/ade20k-full-frequency-split/prepare_ade20k_full_frequency_seen.py 114 | python datasets/ade20k-full-frequency-split/prepare_ade20k_full_frequency_unseen_val.py 115 | 116 | ``` 117 | 118 | ## Prepare data for PASCAL VOC: 119 | We follow [CaGNet](https://github.com/bcmi/CaGNet-Zero-Shot-Semantic-Segmentation) to set up the training and testing data of PASCAL VOC. 120 | We also provide a copy on [Google Drive](https://drive.google.com/file/d/1RvtsdXC_CdeaONcDC3j7emxcwMMG019F/view?usp=sharing) for convenience. 121 | 122 | ### Expected data structure 123 | ``` 124 | VOCZERO/ 125 | images/ 126 | train/ 127 | 2011_003261.jpg 128 | ... 129 | val/ 130 | 2011_003145.jpg 131 | ... 132 | annotations/ 133 | train/ 134 | 2011_003255.png 135 | ... 136 | val/ 137 | 2011_003103.png 138 | ...
139 | all_classnames.json 140 | seen_classnames.json 141 | unseen_classnames.json 142 | annotations_detectron2/ 143 | train_seen 144 | 145 | ``` 146 | 147 | ``` 148 | python datasets/pascal/create_voc_class_names_json.py 149 | python datasets/pascal/prepare_pascal_voc_seen.py 150 | python datasets/pascal/prepare_pascal_voc_unseen_val.py 151 | python datasets/pascal/prepare_pascal_voc_val_all.py 152 | 153 | ``` 154 | -------------------------------------------------------------------------------- /datasets/coco/coco_stuff/split/novel_cls.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/coco/coco_stuff/split/novel_cls.npy -------------------------------------------------------------------------------- /datasets/coco/coco_stuff/split/seen_cls.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/coco/coco_stuff/split/seen_cls.npy -------------------------------------------------------------------------------- /datasets/coco/coco_stuff/split/val_cls.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/coco/coco_stuff/split/val_cls.npy -------------------------------------------------------------------------------- /datasets/pascal/create_voc_class_names_json.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | 4 | categories = [ 5 | {"name": "aeroplane", "id": 1, "trainId": 0}, 6 | {"name": "bicycle", "id": 2, "trainId": 1}, 7 | {"name": "bird", "id": 3, "trainId": 2}, 8 | {"name": "boat", "id": 4, "trainId": 3}, 9 | {"name": "bottle", "id": 5, "trainId": 4}, 10 | {"name": "bus", "id": 6, "trainId": 5}, 11 | {"name": "car", "id": 7, "trainId": 6}, 12 | {"name": "cat", "id": 8, "trainId": 7}, 13 | {"name": "chair", "id": 9, "trainId": 8}, 14 | {"name": "cow", "id": 10, "trainId": 9}, 15 | {"name": "diningtable", "id": 11, "trainId": 10}, 16 | {"name": "dog", "id": 12, "trainId": 11}, 17 | {"name": "horse", "id": 13, "trainId": 12}, 18 | {"name": "motorbike", "id": 14, "trainId": 13}, 19 | {"name": "person", "id": 15, "trainId": 14}, 20 | {"name": "potted plant", "id": 16, "trainId": 15}, 21 | {"name": "sheep", "id": 17, "trainId": 16}, 22 | {"name": "sofa", "id": 18, "trainId": 17}, 23 | {"name": "train", "id": 19, "trainId": 18}, 24 | {"name": "tvmonitor", "id": 20, "trainId": 19}] 25 | 26 | categories_seen = copy.deepcopy(categories[:15]) 27 | 28 | categories_unseen = copy.deepcopy(categories[15:]) 29 | for index, item in enumerate(categories_unseen): 30 | item["trainId"] = index 31 | 32 | with open(r'datasets/VOCZERO/all_classnames.json', 'w') as f_out: 33 | all_categories_json = [] 34 | for cat in categories: 35 | all_categories_json.append(cat["name"]) 36 | json.dump(all_categories_json, f_out) 37 | 38 | with open(r'datasets/VOCZERO/seen_classnames.json', 'w') as f_out: 39 | seen_categories_json = [] 40 | for cat in categories_seen: 41 | seen_categories_json.append(cat["name"]) 42 | json.dump(seen_categories_json, f_out) 43 | 44 | with open(r'datasets/VOCZERO/unseen_classnames.json', 'w') as f_out: 45 | unseen_categories_json = [] 46 | for cat in categories_unseen: 47 | 
unseen_categories_json.append(cat["name"]) 48 | json.dump(unseen_categories_json, f_out) -------------------------------------------------------------------------------- /datasets/pascal/prepare_pascal_voc_seen.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import tqdm 7 | from PIL import Image 8 | import os 9 | # This code is for pascal voc, only has 10582 images for training 10 | from shutil import copyfile 11 | categories = [ 12 | {"name": "aeroplane", "id": 1, "trainId": 0}, 13 | {"name": "bicycle", "id": 2, "trainId": 1}, 14 | {"name": "bird", "id": 3, "trainId": 2}, 15 | {"name": "boat", "id": 4, "trainId": 3}, 16 | {"name": "bottle", "id": 5, "trainId": 4}, 17 | {"name": "bus", "id": 6, "trainId": 5}, 18 | {"name": "car", "id": 7, "trainId": 6}, 19 | {"name": "cat", "id": 8, "trainId": 7}, 20 | {"name": "chair", "id": 9, "trainId": 8}, 21 | {"name": "cow", "id": 10, "trainId": 9}, 22 | {"name": "diningtable", "id": 11, "trainId": 10}, 23 | {"name": "dog", "id": 12, "trainId": 11}, 24 | {"name": "horse", "id": 13, "trainId": 12}, 25 | {"name": "motorbike", "id": 14, "trainId": 13}, 26 | {"name": "person", "id": 15, "trainId": 14}, 27 | {"name": "potted plant", "id": 16, "trainId": 15}, 28 | {"name": "sheep", "id": 17, "trainId": 16}, 29 | {"name": "sofa", "id": 18, "trainId": 17}, 30 | {"name": "train", "id": 19, "trainId": 18}, 31 | {"name": "tvmonitor", "id": 20, "trainId": 19}] 32 | 33 | categories_seen = copy.deepcopy(categories[:15]) 34 | 35 | categories_unseen = copy.deepcopy(categories[15:]) 36 | for index, item in enumerate(categories_unseen): 37 | item["trainId"] = index 38 | 39 | if __name__ == '__main__': 40 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "VOCZERO" 41 | 42 | id_map = {} 43 | for cat in categories_seen: 44 | id_map[cat["id"]] = cat["trainId"] 45 | 46 | # read train_list from the npy file of spnet 47 | train_list = np.load(r'datasets/voc12/split/train_list.npy') 48 | train_list_basename = [os.path.splitext(os.path.basename(item))[0] for item in train_list] 49 | 50 | val_list = np.load(r'datasets/voc12/split/test_list.npy') 51 | val_list_basename = [os.path.splitext(os.path.basename(item))[0] for item in val_list] 52 | 53 | for name in ["val", "train"]: 54 | annotation_dir = dataset_dir / "annotations" / name 55 | output_dir = dataset_dir / "annotations_detectron2" / f"{name}_seen" 56 | output_dir.mkdir(parents=True, exist_ok=True) 57 | 58 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 59 | basename = os.path.splitext(file.name)[0] 60 | if name == "train": 61 | if basename not in train_list_basename: 62 | continue 63 | 64 | if name == "val": 65 | if basename not in val_list_basename: 66 | continue 67 | 68 | output_file = output_dir / file.name 69 | # convert(file, output_file) 70 | lab = np.asarray(Image.open(file)) 71 | assert lab.dtype == np.uint8 72 | # img = img - 1 # 0 (ignore) becomes 255. 
others are shifted by 1 73 | 74 | output = np.zeros_like(lab, dtype=np.uint8) + 255 75 | for obj_id in np.unique(lab): 76 | if obj_id in id_map: 77 | output[lab == obj_id] = id_map[obj_id] 78 | 79 | Image.fromarray(output).save(output_file) -------------------------------------------------------------------------------- /datasets/pascal/prepare_pascal_voc_unseen_val.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import tqdm 7 | from PIL import Image 8 | import os 9 | 10 | from shutil import copyfile 11 | categories = [ 12 | {"name": "aeroplane", "id": 1, "trainId": 0}, 13 | {"name": "bicycle", "id": 2, "trainId": 1}, 14 | {"name": "bird", "id": 3, "trainId": 2}, 15 | {"name": "boat", "id": 4, "trainId": 3}, 16 | {"name": "bottle", "id": 5, "trainId": 4}, 17 | {"name": "bus", "id": 6, "trainId": 5}, 18 | {"name": "car", "id": 7, "trainId": 6}, 19 | {"name": "cat", "id": 8, "trainId": 7}, 20 | {"name": "chair", "id": 9, "trainId": 8}, 21 | {"name": "cow", "id": 10, "trainId": 9}, 22 | {"name": "diningtable", "id": 11, "trainId": 10}, 23 | {"name": "dog", "id": 12, "trainId": 11}, 24 | {"name": "horse", "id": 13, "trainId": 12}, 25 | {"name": "motorbike", "id": 14, "trainId": 13}, 26 | {"name": "person", "id": 15, "trainId": 14}, 27 | {"name": "potted plant", "id": 16, "trainId": 15}, 28 | {"name": "sheep", "id": 17, "trainId": 16}, 29 | {"name": "sofa", "id": 18, "trainId": 17}, 30 | {"name": "train", "id": 19, "trainId": 18}, 31 | {"name": "tvmonitor", "id": 20, "trainId": 19}] 32 | 33 | categories_seen = copy.deepcopy(categories[:15]) 34 | 35 | categories_unseen = copy.deepcopy(categories[15:]) 36 | for index, item in enumerate(categories_unseen): 37 | item["trainId"] = index 38 | 39 | if __name__ == '__main__': 40 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "VOCZERO" 41 | 42 | id_map = {} 43 | for cat in categories_unseen: 44 | id_map[cat["id"]] = cat["trainId"] 45 | 46 | # for name in ["val", "train"]: 47 | for name in ["val",]: 48 | 49 | annotation_dir = dataset_dir / "annotations" / name 50 | output_dir = dataset_dir / "annotations_detectron2" / f"{name}_unseen" 51 | output_dir.mkdir(parents=True, exist_ok=True) 52 | 53 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 54 | 55 | output_file = output_dir / file.name 56 | # convert(file, output_file) 57 | lab = np.asarray(Image.open(file)) 58 | assert lab.dtype == np.uint8 59 | # img = img - 1 # 0 (ignore) becomes 255. 
others are shifted by 1 60 | 61 | output = np.zeros_like(lab, dtype=np.uint8) + 255 62 | for obj_id in np.unique(lab): 63 | if obj_id in id_map: 64 | output[lab == obj_id] = id_map[obj_id] 65 | 66 | Image.fromarray(output).save(output_file) -------------------------------------------------------------------------------- /datasets/pascal/prepare_pascal_voc_val_all.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import tqdm 7 | from PIL import Image 8 | import os 9 | 10 | from shutil import copyfile 11 | categories = [ 12 | {"name": "aeroplane", "id": 1, "trainId": 0}, 13 | {"name": "bicycle", "id": 2, "trainId": 1}, 14 | {"name": "bird", "id": 3, "trainId": 2}, 15 | {"name": "boat", "id": 4, "trainId": 3}, 16 | {"name": "bottle", "id": 5, "trainId": 4}, 17 | {"name": "bus", "id": 6, "trainId": 5}, 18 | {"name": "car", "id": 7, "trainId": 6}, 19 | {"name": "cat", "id": 8, "trainId": 7}, 20 | {"name": "chair", "id": 9, "trainId": 8}, 21 | {"name": "cow", "id": 10, "trainId": 9}, 22 | {"name": "diningtable", "id": 11, "trainId": 10}, 23 | {"name": "dog", "id": 12, "trainId": 11}, 24 | {"name": "horse", "id": 13, "trainId": 12}, 25 | {"name": "motorbike", "id": 14, "trainId": 13}, 26 | {"name": "person", "id": 15, "trainId": 14}, 27 | {"name": "potted plant", "id": 16, "trainId": 15}, 28 | {"name": "sheep", "id": 17, "trainId": 16}, 29 | {"name": "sofa", "id": 18, "trainId": 17}, 30 | {"name": "train", "id": 19, "trainId": 18}, 31 | {"name": "tvmonitor", "id": 20, "trainId": 19}] 32 | 33 | categories_seen = copy.deepcopy(categories[:15]) 34 | 35 | categories_unseen = copy.deepcopy(categories[15:]) 36 | for index, item in enumerate(categories_unseen): 37 | item["trainId"] = index 38 | 39 | if __name__ == '__main__': 40 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "VOCZERO" 41 | 42 | id_map = {} 43 | for cat in categories: 44 | id_map[cat["id"]] = cat["trainId"] 45 | 46 | # for name in ["val", "train"]: 47 | for name in ["val",]: 48 | 49 | annotation_dir = dataset_dir / "annotations" / name 50 | output_dir = dataset_dir / "annotations_detectron2" / f"{name}_all" 51 | output_dir.mkdir(parents=True, exist_ok=True) 52 | 53 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 54 | 55 | output_file = output_dir / file.name 56 | # convert(file, output_file) 57 | lab = np.asarray(Image.open(file)) 58 | assert lab.dtype == np.uint8 59 | # img = img - 1 # 0 (ignore) becomes 255. 
others are shifted by 1 60 | 61 | output = np.zeros_like(lab, dtype=np.uint8) + 255 62 | for obj_id in np.unique(lab): 63 | if obj_id in id_map: 64 | output[lab == obj_id] = id_map[obj_id] 65 | 66 | Image.fromarray(output).save(output_file) -------------------------------------------------------------------------------- /datasets/voc12/split/novel_cls.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/voc12/split/novel_cls.npy -------------------------------------------------------------------------------- /datasets/voc12/split/seen_cls.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/voc12/split/seen_cls.npy -------------------------------------------------------------------------------- /datasets/voc12/split/test_list.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/voc12/split/test_list.npy -------------------------------------------------------------------------------- /datasets/voc12/split/train_list.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/datasets/voc12/split/train_list.npy -------------------------------------------------------------------------------- /demo/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py 3 | import argparse 4 | import glob 5 | import multiprocessing as mp 6 | import os 7 | 8 | # fmt: off 9 | import sys 10 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 11 | # fmt: on 12 | 13 | import tempfile 14 | import time 15 | import warnings 16 | 17 | import cv2 18 | import numpy as np 19 | import tqdm 20 | 21 | from detectron2.config import get_cfg 22 | from detectron2.data.detection_utils import read_image 23 | from detectron2.projects.deeplab import add_deeplab_config 24 | from detectron2.utils.logger import setup_logger 25 | 26 | from mask_former import add_mask_former_config 27 | from predictor import VisualizationDemo 28 | 29 | 30 | # constants 31 | WINDOW_NAME = "MaskFormer demo" 32 | 33 | 34 | def setup_cfg(args): 35 | # load config from file and command-line arguments 36 | cfg = get_cfg() 37 | add_deeplab_config(cfg) 38 | add_mask_former_config(cfg) 39 | cfg.merge_from_file(args.config_file) 40 | cfg.merge_from_list(args.opts) 41 | cfg.freeze() 42 | return cfg 43 | 44 | 45 | def get_parser(): 46 | parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") 47 | parser.add_argument( 48 | "--config-file", 49 | default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml", 50 | metavar="FILE", 51 | help="path to config file", 52 | ) 53 | parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") 54 | parser.add_argument("--video-input", help="Path to video file.") 55 | parser.add_argument( 56 | "--input", 57 | nargs="+", 58 | help="A list of space separated input images; " 59 | "or a single glob pattern such as 'directory/*.jpg'", 60 | ) 61 | parser.add_argument( 62 | "--output", 63 | help="A file or directory to save output visualizations. 
" 64 | "If not given, will show output in an OpenCV window.", 65 | ) 66 | 67 | parser.add_argument( 68 | "--confidence-threshold", 69 | type=float, 70 | default=0.5, 71 | help="Minimum score for instance predictions to be shown", 72 | ) 73 | parser.add_argument( 74 | "--opts", 75 | help="Modify config options using the command-line 'KEY VALUE' pairs", 76 | default=[], 77 | nargs=argparse.REMAINDER, 78 | ) 79 | return parser 80 | 81 | 82 | def test_opencv_video_format(codec, file_ext): 83 | with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: 84 | filename = os.path.join(dir, "test_file" + file_ext) 85 | writer = cv2.VideoWriter( 86 | filename=filename, 87 | fourcc=cv2.VideoWriter_fourcc(*codec), 88 | fps=float(30), 89 | frameSize=(10, 10), 90 | isColor=True, 91 | ) 92 | [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)] 93 | writer.release() 94 | if os.path.isfile(filename): 95 | return True 96 | return False 97 | 98 | 99 | if __name__ == "__main__": 100 | mp.set_start_method("spawn", force=True) 101 | args = get_parser().parse_args() 102 | setup_logger(name="fvcore") 103 | logger = setup_logger() 104 | logger.info("Arguments: " + str(args)) 105 | 106 | cfg = setup_cfg(args) 107 | 108 | demo = VisualizationDemo(cfg) 109 | 110 | if args.input: 111 | if len(args.input) == 1: 112 | args.input = glob.glob(os.path.expanduser(args.input[0])) 113 | assert args.input, "The input path(s) was not found" 114 | for path in tqdm.tqdm(args.input, disable=not args.output): 115 | # use PIL, to be consistent with evaluation 116 | img = read_image(path, format="BGR") 117 | start_time = time.time() 118 | predictions, visualized_output = demo.run_on_image(img) 119 | logger.info( 120 | "{}: {} in {:.2f}s".format( 121 | path, 122 | "detected {} instances".format(len(predictions["instances"])) 123 | if "instances" in predictions 124 | else "finished", 125 | time.time() - start_time, 126 | ) 127 | ) 128 | 129 | if args.output: 130 | if os.path.isdir(args.output): 131 | assert os.path.isdir(args.output), args.output 132 | out_filename = os.path.join(args.output, os.path.basename(path)) 133 | else: 134 | assert len(args.input) == 1, "Please specify a directory with args.output" 135 | out_filename = args.output 136 | visualized_output.save(out_filename) 137 | else: 138 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 139 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) 140 | if cv2.waitKey(0) == 27: 141 | break # esc to quit 142 | elif args.webcam: 143 | assert args.input is None, "Cannot have both --input and --webcam!" 144 | assert args.output is None, "output not yet supported with --webcam!" 
145 | cam = cv2.VideoCapture(0) 146 | for vis in tqdm.tqdm(demo.run_on_video(cam)): 147 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 148 | cv2.imshow(WINDOW_NAME, vis) 149 | if cv2.waitKey(1) == 27: 150 | break # esc to quit 151 | cam.release() 152 | cv2.destroyAllWindows() 153 | elif args.video_input: 154 | video = cv2.VideoCapture(args.video_input) 155 | width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 156 | height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 157 | frames_per_second = video.get(cv2.CAP_PROP_FPS) 158 | num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 159 | basename = os.path.basename(args.video_input) 160 | codec, file_ext = ( 161 | ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") 162 | ) 163 | if codec == ".mp4v": 164 | warnings.warn("x264 codec not available, switching to mp4v") 165 | if args.output: 166 | if os.path.isdir(args.output): 167 | output_fname = os.path.join(args.output, basename) 168 | output_fname = os.path.splitext(output_fname)[0] + file_ext 169 | else: 170 | output_fname = args.output 171 | assert not os.path.isfile(output_fname), output_fname 172 | output_file = cv2.VideoWriter( 173 | filename=output_fname, 174 | # some installation of opencv may not support x264 (due to its license), 175 | # you can try other format (e.g. MPEG) 176 | fourcc=cv2.VideoWriter_fourcc(*codec), 177 | fps=float(frames_per_second), 178 | frameSize=(width, height), 179 | isColor=True, 180 | ) 181 | assert os.path.isfile(args.video_input) 182 | for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): 183 | if args.output: 184 | output_file.write(vis_frame) 185 | else: 186 | cv2.namedWindow(basename, cv2.WINDOW_NORMAL) 187 | cv2.imshow(basename, vis_frame) 188 | if cv2.waitKey(1) == 27: 189 | break # esc to quit 190 | video.release() 191 | if args.output: 192 | output_file.release() 193 | else: 194 | cv2.destroyAllWindows() 195 | -------------------------------------------------------------------------------- /demo/demo_visual_gt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py 3 | import argparse 4 | import glob 5 | import multiprocessing as mp 6 | import os 7 | 8 | # fmt: off 9 | import sys 10 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 11 | # fmt: on 12 | 13 | import tempfile 14 | import time 15 | import warnings 16 | 17 | import cv2 18 | import numpy as np 19 | import tqdm 20 | 21 | from detectron2.config import get_cfg 22 | from detectron2.data.detection_utils import read_image 23 | from detectron2.projects.deeplab import add_deeplab_config 24 | from detectron2.utils.logger import setup_logger 25 | 26 | from mask_former import add_mask_former_config 27 | # from predictor import VisualizationDemo 28 | from visualizer import VisualizationGt 29 | from PIL import Image 30 | 31 | # constants 32 | WINDOW_NAME = "MaskFormer demo" 33 | 34 | 35 | def setup_cfg(args): 36 | # load config from file and command-line arguments 37 | cfg = get_cfg() 38 | add_deeplab_config(cfg) 39 | add_mask_former_config(cfg) 40 | cfg.merge_from_file(args.config_file) 41 | cfg.merge_from_list(args.opts) 42 | cfg.freeze() 43 | return cfg 44 | 45 | 46 | def get_parser(): 47 | parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") 48 | parser.add_argument( 49 | "--config-file", 50 | default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml", 51 | metavar="FILE", 52 | help="path to config file", 53 | ) 54 | parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") 55 | parser.add_argument("--video-input", help="Path to video file.") 56 | parser.add_argument( 57 | "--input", 58 | nargs="+", 59 | help="A list of space separated input images; " 60 | "or a single glob pattern such as 'directory/*.jpg'", 61 | ) 62 | # parser.add_argument( 63 | # "--gt", 64 | # nargs="+", 65 | # help="A list of space seperated ground truth images;" 66 | # "or a single glob pattern such as 'directory/*.png'" 67 | # ) 68 | parser.add_argument( 69 | "--gt", 70 | # type="str", 71 | help="ground truth path of segmentation" 72 | ) 73 | parser.add_argument( 74 | "--output", 75 | help="A file or directory to save output visualizations. 
" 76 | "If not given, will show output in an OpenCV window.", 77 | ) 78 | 79 | parser.add_argument( 80 | "--confidence-threshold", 81 | type=float, 82 | default=0.5, 83 | help="Minimum score for instance predictions to be shown", 84 | ) 85 | parser.add_argument( 86 | "--opts", 87 | help="Modify config options using the command-line 'KEY VALUE' pairs", 88 | default=[], 89 | nargs=argparse.REMAINDER, 90 | ) 91 | return parser 92 | 93 | 94 | def test_opencv_video_format(codec, file_ext): 95 | with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: 96 | filename = os.path.join(dir, "test_file" + file_ext) 97 | writer = cv2.VideoWriter( 98 | filename=filename, 99 | fourcc=cv2.VideoWriter_fourcc(*codec), 100 | fps=float(30), 101 | frameSize=(10, 10), 102 | isColor=True, 103 | ) 104 | [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)] 105 | writer.release() 106 | if os.path.isfile(filename): 107 | return True 108 | return False 109 | 110 | 111 | if __name__ == "__main__": 112 | mp.set_start_method("spawn", force=True) 113 | args = get_parser().parse_args() 114 | setup_logger(name="fvcore") 115 | logger = setup_logger() 116 | logger.info("Arguments: " + str(args)) 117 | 118 | cfg = setup_cfg(args) 119 | 120 | demo = VisualizationGt(cfg) 121 | gt_path = args.gt 122 | if args.input: 123 | if len(args.input) == 1: 124 | args.input = glob.glob(os.path.expanduser(args.input[0])) 125 | assert args.input, "The input path(s) was not found" 126 | for path in tqdm.tqdm(args.input, disable=not args.output): 127 | # use PIL, to be consistent with evaluation 128 | img = read_image(path, format="BGR") 129 | start_time = time.time() 130 | predictions = {} 131 | gt_file = os.path.join(gt_path, os.path.splitext(os.path.basename(path))[0] + '.png') 132 | # import pdb; pdb.set_trace() 133 | predictions['sem_seg'] = np.asarray(Image.open(gt_file)) 134 | predictions, visualized_output = demo.run_on_image(img, predictions) 135 | logger.info( 136 | "{}: {} in {:.2f}s".format( 137 | path, 138 | "detected {} instances".format(len(predictions["instances"])) 139 | if "instances" in predictions 140 | else "finished", 141 | time.time() - start_time, 142 | ) 143 | ) 144 | 145 | if args.output: 146 | if os.path.isdir(args.output): 147 | assert os.path.isdir(args.output), args.output 148 | out_filename = os.path.join(args.output, os.path.basename(path)) 149 | else: 150 | assert len(args.input) == 1, "Please specify a directory with args.output" 151 | out_filename = args.output 152 | visualized_output.save(out_filename) 153 | else: 154 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 155 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) 156 | if cv2.waitKey(0) == 27: 157 | break # esc to quit 158 | elif args.webcam: 159 | assert args.input is None, "Cannot have both --input and --webcam!" 160 | assert args.output is None, "output not yet supported with --webcam!" 
161 | cam = cv2.VideoCapture(0) 162 | for vis in tqdm.tqdm(demo.run_on_video(cam)): 163 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 164 | cv2.imshow(WINDOW_NAME, vis) 165 | if cv2.waitKey(1) == 27: 166 | break # esc to quit 167 | cam.release() 168 | cv2.destroyAllWindows() 169 | elif args.video_input: 170 | video = cv2.VideoCapture(args.video_input) 171 | width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 172 | height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 173 | frames_per_second = video.get(cv2.CAP_PROP_FPS) 174 | num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 175 | basename = os.path.basename(args.video_input) 176 | codec, file_ext = ( 177 | ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") 178 | ) 179 | if codec == ".mp4v": 180 | warnings.warn("x264 codec not available, switching to mp4v") 181 | if args.output: 182 | if os.path.isdir(args.output): 183 | output_fname = os.path.join(args.output, basename) 184 | output_fname = os.path.splitext(output_fname)[0] + file_ext 185 | else: 186 | output_fname = args.output 187 | assert not os.path.isfile(output_fname), output_fname 188 | output_file = cv2.VideoWriter( 189 | filename=output_fname, 190 | # some installation of opencv may not support x264 (due to its license), 191 | # you can try other format (e.g. MPEG) 192 | fourcc=cv2.VideoWriter_fourcc(*codec), 193 | fps=float(frames_per_second), 194 | frameSize=(width, height), 195 | isColor=True, 196 | ) 197 | assert os.path.isfile(args.video_input) 198 | for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): 199 | if args.output: 200 | output_file.write(vis_frame) 201 | else: 202 | cv2.namedWindow(basename, cv2.WINDOW_NORMAL) 203 | cv2.imshow(basename, vis_frame) 204 | if cv2.waitKey(1) == 27: 205 | break # esc to quit 206 | video.release() 207 | if args.output: 208 | output_file.release() 209 | else: 210 | cv2.destroyAllWindows() 211 | -------------------------------------------------------------------------------- /demo/demo_visual_gt_adefull.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
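# ADE20K-Full variant of demo_visual_gt.py: the only functional difference is that the
# ground-truth label maps are read as .tif files (see the gt_file construction below),
# matching the ADE20K-Full annotation format. Usage sketch with placeholder paths:
#
#   python demo/demo_visual_gt_adefull.py \
#     --config-file configs/ade20k-full-847-freq/zegformer_R50_bs32_60k_vit16_ade20k-full-847-freq_gzss_eval.yaml \
#     --input /path/to/ade20k_full_images/*.jpg \
#     --gt /path/to/ade20k_full_annotations \
#     --output output/gt_vis_adefull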
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py 3 | import argparse 4 | import glob 5 | import multiprocessing as mp 6 | import os 7 | 8 | # fmt: off 9 | import sys 10 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 11 | # fmt: on 12 | 13 | import tempfile 14 | import time 15 | import warnings 16 | 17 | import cv2 18 | import numpy as np 19 | import tqdm 20 | 21 | from detectron2.config import get_cfg 22 | from detectron2.data.detection_utils import read_image 23 | from detectron2.projects.deeplab import add_deeplab_config 24 | from detectron2.utils.logger import setup_logger 25 | 26 | from mask_former import add_mask_former_config 27 | # from predictor import VisualizationDemo 28 | from visualizer import VisualizationGt 29 | from PIL import Image 30 | 31 | # constants 32 | WINDOW_NAME = "MaskFormer demo" 33 | 34 | 35 | def setup_cfg(args): 36 | # load config from file and command-line arguments 37 | cfg = get_cfg() 38 | add_deeplab_config(cfg) 39 | add_mask_former_config(cfg) 40 | cfg.merge_from_file(args.config_file) 41 | cfg.merge_from_list(args.opts) 42 | cfg.freeze() 43 | return cfg 44 | 45 | 46 | def get_parser(): 47 | parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") 48 | parser.add_argument( 49 | "--config-file", 50 | default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml", 51 | metavar="FILE", 52 | help="path to config file", 53 | ) 54 | parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") 55 | parser.add_argument("--video-input", help="Path to video file.") 56 | parser.add_argument( 57 | "--input", 58 | nargs="+", 59 | help="A list of space separated input images; " 60 | "or a single glob pattern such as 'directory/*.jpg'", 61 | ) 62 | # parser.add_argument( 63 | # "--gt", 64 | # nargs="+", 65 | # help="A list of space seperated ground truth images;" 66 | # "or a single glob pattern such as 'directory/*.png'" 67 | # ) 68 | parser.add_argument( 69 | "--gt", 70 | # type="str", 71 | help="ground truth path of segmentation" 72 | ) 73 | parser.add_argument( 74 | "--output", 75 | help="A file or directory to save output visualizations. 
" 76 | "If not given, will show output in an OpenCV window.", 77 | ) 78 | 79 | parser.add_argument( 80 | "--confidence-threshold", 81 | type=float, 82 | default=0.5, 83 | help="Minimum score for instance predictions to be shown", 84 | ) 85 | parser.add_argument( 86 | "--opts", 87 | help="Modify config options using the command-line 'KEY VALUE' pairs", 88 | default=[], 89 | nargs=argparse.REMAINDER, 90 | ) 91 | return parser 92 | 93 | 94 | def test_opencv_video_format(codec, file_ext): 95 | with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: 96 | filename = os.path.join(dir, "test_file" + file_ext) 97 | writer = cv2.VideoWriter( 98 | filename=filename, 99 | fourcc=cv2.VideoWriter_fourcc(*codec), 100 | fps=float(30), 101 | frameSize=(10, 10), 102 | isColor=True, 103 | ) 104 | [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)] 105 | writer.release() 106 | if os.path.isfile(filename): 107 | return True 108 | return False 109 | 110 | 111 | if __name__ == "__main__": 112 | mp.set_start_method("spawn", force=True) 113 | args = get_parser().parse_args() 114 | setup_logger(name="fvcore") 115 | logger = setup_logger() 116 | logger.info("Arguments: " + str(args)) 117 | 118 | cfg = setup_cfg(args) 119 | 120 | demo = VisualizationGt(cfg) 121 | gt_path = args.gt 122 | if args.input: 123 | if len(args.input) == 1: 124 | args.input = glob.glob(os.path.expanduser(args.input[0])) 125 | assert args.input, "The input path(s) was not found" 126 | for path in tqdm.tqdm(args.input, disable=not args.output): 127 | # use PIL, to be consistent with evaluation 128 | img = read_image(path, format="BGR") 129 | start_time = time.time() 130 | predictions = {} 131 | gt_file = os.path.join(gt_path, os.path.splitext(os.path.basename(path))[0] + '.tif') 132 | # import pdb; pdb.set_trace() 133 | predictions['sem_seg'] = np.asarray(Image.open(gt_file)) 134 | predictions, visualized_output = demo.run_on_image(img, predictions) 135 | logger.info( 136 | "{}: {} in {:.2f}s".format( 137 | path, 138 | "detected {} instances".format(len(predictions["instances"])) 139 | if "instances" in predictions 140 | else "finished", 141 | time.time() - start_time, 142 | ) 143 | ) 144 | 145 | if args.output: 146 | if os.path.isdir(args.output): 147 | assert os.path.isdir(args.output), args.output 148 | out_filename = os.path.join(args.output, os.path.basename(path)) 149 | else: 150 | assert len(args.input) == 1, "Please specify a directory with args.output" 151 | out_filename = args.output 152 | visualized_output.save(out_filename) 153 | else: 154 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 155 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) 156 | if cv2.waitKey(0) == 27: 157 | break # esc to quit 158 | elif args.webcam: 159 | assert args.input is None, "Cannot have both --input and --webcam!" 160 | assert args.output is None, "output not yet supported with --webcam!" 
161 | cam = cv2.VideoCapture(0) 162 | for vis in tqdm.tqdm(demo.run_on_video(cam)): 163 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 164 | cv2.imshow(WINDOW_NAME, vis) 165 | if cv2.waitKey(1) == 27: 166 | break # esc to quit 167 | cam.release() 168 | cv2.destroyAllWindows() 169 | elif args.video_input: 170 | video = cv2.VideoCapture(args.video_input) 171 | width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 172 | height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 173 | frames_per_second = video.get(cv2.CAP_PROP_FPS) 174 | num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 175 | basename = os.path.basename(args.video_input) 176 | codec, file_ext = ( 177 | ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") 178 | ) 179 | if codec == ".mp4v": 180 | warnings.warn("x264 codec not available, switching to mp4v") 181 | if args.output: 182 | if os.path.isdir(args.output): 183 | output_fname = os.path.join(args.output, basename) 184 | output_fname = os.path.splitext(output_fname)[0] + file_ext 185 | else: 186 | output_fname = args.output 187 | assert not os.path.isfile(output_fname), output_fname 188 | output_file = cv2.VideoWriter( 189 | filename=output_fname, 190 | # some installation of opencv may not support x264 (due to its license), 191 | # you can try other format (e.g. MPEG) 192 | fourcc=cv2.VideoWriter_fourcc(*codec), 193 | fps=float(frames_per_second), 194 | frameSize=(width, height), 195 | isColor=True, 196 | ) 197 | assert os.path.isfile(args.video_input) 198 | for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): 199 | if args.output: 200 | output_file.write(vis_frame) 201 | else: 202 | cv2.namedWindow(basename, cv2.WINDOW_NORMAL) 203 | cv2.imshow(basename, vis_frame) 204 | if cv2.waitKey(1) == 27: 205 | break # esc to quit 206 | video.release() 207 | if args.output: 208 | output_file.release() 209 | else: 210 | cv2.destroyAllWindows() 211 | -------------------------------------------------------------------------------- /demo/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py 3 | import atexit 4 | import bisect 5 | import multiprocessing as mp 6 | from collections import deque 7 | 8 | import cv2 9 | import torch 10 | 11 | from detectron2.data import MetadataCatalog 12 | from detectron2.engine.defaults import DefaultPredictor 13 | from detectron2.utils.video_visualizer import VideoVisualizer 14 | from detectron2.utils.visualizer import ColorMode, Visualizer 15 | 16 | 17 | class VisualizationDemo(object): 18 | def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): 19 | """ 20 | Args: 21 | cfg (CfgNode): 22 | instance_mode (ColorMode): 23 | parallel (bool): whether to run the model in different processes from visualization. 24 | Useful since the visualization logic can be slow. 25 | """ 26 | self.metadata = MetadataCatalog.get( 27 | cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" 28 | ) 29 | self.cpu_device = torch.device("cpu") 30 | self.instance_mode = instance_mode 31 | 32 | self.parallel = parallel 33 | if parallel: 34 | num_gpu = torch.cuda.device_count() 35 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) 36 | else: 37 | self.predictor = DefaultPredictor(cfg) 38 | 39 | def run_on_image(self, image): 40 | """ 41 | Args: 42 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 
43 | This is the format used by OpenCV. 44 | Returns: 45 | predictions (dict): the output of the model. 46 | vis_output (VisImage): the visualized image output. 47 | """ 48 | vis_output = None 49 | predictions = self.predictor(image) 50 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 51 | image = image[:, :, ::-1] 52 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) 53 | if "panoptic_seg" in predictions: 54 | panoptic_seg, segments_info = predictions["panoptic_seg"] 55 | vis_output = visualizer.draw_panoptic_seg_predictions( 56 | panoptic_seg.to(self.cpu_device), segments_info 57 | ) 58 | else: 59 | if "sem_seg" in predictions: 60 | vis_output = visualizer.draw_sem_seg( 61 | predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 62 | ) 63 | if "instances" in predictions: 64 | instances = predictions["instances"].to(self.cpu_device) 65 | vis_output = visualizer.draw_instance_predictions(predictions=instances) 66 | 67 | return predictions, vis_output 68 | 69 | def _frame_from_video(self, video): 70 | while video.isOpened(): 71 | success, frame = video.read() 72 | if success: 73 | yield frame 74 | else: 75 | break 76 | 77 | def run_on_video(self, video): 78 | """ 79 | Visualizes predictions on frames of the input video. 80 | Args: 81 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be 82 | either a webcam or a video file. 83 | Yields: 84 | ndarray: BGR visualizations of each video frame. 85 | """ 86 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) 87 | 88 | def process_predictions(frame, predictions): 89 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 90 | if "panoptic_seg" in predictions: 91 | panoptic_seg, segments_info = predictions["panoptic_seg"] 92 | vis_frame = video_visualizer.draw_panoptic_seg_predictions( 93 | frame, panoptic_seg.to(self.cpu_device), segments_info 94 | ) 95 | elif "instances" in predictions: 96 | predictions = predictions["instances"].to(self.cpu_device) 97 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) 98 | elif "sem_seg" in predictions: 99 | vis_frame = video_visualizer.draw_sem_seg( 100 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 101 | ) 102 | 103 | # Converts Matplotlib RGB format to OpenCV BGR format 104 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) 105 | return vis_frame 106 | 107 | frame_gen = self._frame_from_video(video) 108 | if self.parallel: 109 | buffer_size = self.predictor.default_buffer_size 110 | 111 | frame_data = deque() 112 | 113 | for cnt, frame in enumerate(frame_gen): 114 | frame_data.append(frame) 115 | self.predictor.put(frame) 116 | 117 | if cnt >= buffer_size: 118 | frame = frame_data.popleft() 119 | predictions = self.predictor.get() 120 | yield process_predictions(frame, predictions) 121 | 122 | while len(frame_data): 123 | frame = frame_data.popleft() 124 | predictions = self.predictor.get() 125 | yield process_predictions(frame, predictions) 126 | else: 127 | for frame in frame_gen: 128 | yield process_predictions(frame, self.predictor(frame)) 129 | 130 | 131 | class AsyncPredictor: 132 | """ 133 | A predictor that runs the model asynchronously, possibly on >1 GPUs. 134 | Because rendering the visualization takes considerably amount of time, 135 | this helps improve throughput a little bit when rendering videos. 
136 | """ 137 | 138 | class _StopToken: 139 | pass 140 | 141 | class _PredictWorker(mp.Process): 142 | def __init__(self, cfg, task_queue, result_queue): 143 | self.cfg = cfg 144 | self.task_queue = task_queue 145 | self.result_queue = result_queue 146 | super().__init__() 147 | 148 | def run(self): 149 | predictor = DefaultPredictor(self.cfg) 150 | 151 | while True: 152 | task = self.task_queue.get() 153 | if isinstance(task, AsyncPredictor._StopToken): 154 | break 155 | idx, data = task 156 | result = predictor(data) 157 | self.result_queue.put((idx, result)) 158 | 159 | def __init__(self, cfg, num_gpus: int = 1): 160 | """ 161 | Args: 162 | cfg (CfgNode): 163 | num_gpus (int): if 0, will run on CPU 164 | """ 165 | num_workers = max(num_gpus, 1) 166 | self.task_queue = mp.Queue(maxsize=num_workers * 3) 167 | self.result_queue = mp.Queue(maxsize=num_workers * 3) 168 | self.procs = [] 169 | for gpuid in range(max(num_gpus, 1)): 170 | cfg = cfg.clone() 171 | cfg.defrost() 172 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" 173 | self.procs.append( 174 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) 175 | ) 176 | 177 | self.put_idx = 0 178 | self.get_idx = 0 179 | self.result_rank = [] 180 | self.result_data = [] 181 | 182 | for p in self.procs: 183 | p.start() 184 | atexit.register(self.shutdown) 185 | 186 | def put(self, image): 187 | self.put_idx += 1 188 | self.task_queue.put((self.put_idx, image)) 189 | 190 | def get(self): 191 | self.get_idx += 1 # the index needed for this request 192 | if len(self.result_rank) and self.result_rank[0] == self.get_idx: 193 | res = self.result_data[0] 194 | del self.result_data[0], self.result_rank[0] 195 | return res 196 | 197 | while True: 198 | # make sure the results are returned in the correct order 199 | idx, res = self.result_queue.get() 200 | if idx == self.get_idx: 201 | return res 202 | insert = bisect.bisect(self.result_rank, idx) 203 | self.result_rank.insert(insert, idx) 204 | self.result_data.insert(insert, res) 205 | 206 | def __len__(self): 207 | return self.put_idx - self.get_idx 208 | 209 | def __call__(self, image): 210 | self.put(image) 211 | return self.get() 212 | 213 | def shutdown(self): 214 | for _ in self.procs: 215 | self.task_queue.put(AsyncPredictor._StopToken()) 216 | 217 | @property 218 | def default_buffer_size(self): 219 | return len(self.procs) * 5 220 | -------------------------------------------------------------------------------- /demo/visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py 3 | import atexit 4 | import bisect 5 | import multiprocessing as mp 6 | from collections import deque 7 | 8 | import cv2 9 | import torch 10 | 11 | from detectron2.data import MetadataCatalog 12 | from detectron2.engine.defaults import DefaultPredictor 13 | from detectron2.utils.video_visualizer import VideoVisualizer 14 | from detectron2.utils.visualizer import ColorMode, Visualizer 15 | 16 | 17 | class VisualizationGt(object): 18 | def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): 19 | """ 20 | Args: 21 | cfg (CfgNode): 22 | instance_mode (ColorMode): 23 | parallel (bool): whether to run the model in different processes from visualization. 24 | Useful since the visualization logic can be slow. 
25 | """ 26 | self.metadata = MetadataCatalog.get( 27 | cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" 28 | ) 29 | self.cpu_device = torch.device("cpu") 30 | self.instance_mode = instance_mode 31 | 32 | self.parallel = parallel 33 | if parallel: 34 | num_gpu = torch.cuda.device_count() 35 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) 36 | else: 37 | self.predictor = DefaultPredictor(cfg) 38 | 39 | def run_on_image(self, image, predictions): 40 | """ 41 | Args: 42 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 43 | This is the format used by OpenCV. 44 | Returns: 45 | predictions (dict): the output of the model. 46 | vis_output (VisImage): the visualized image output. 47 | """ 48 | vis_output = None 49 | # predictions = self.predictor(image) 50 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 51 | image = image[:, :, ::-1] 52 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) 53 | if "panoptic_seg" in predictions: 54 | panoptic_seg, segments_info = predictions["panoptic_seg"] 55 | vis_output = visualizer.draw_panoptic_seg_predictions( 56 | panoptic_seg.to(self.cpu_device), segments_info 57 | ) 58 | else: 59 | if "sem_seg" in predictions: 60 | vis_output = visualizer.draw_sem_seg( 61 | predictions["sem_seg"] 62 | ) 63 | if "instances" in predictions: 64 | instances = predictions["instances"].to(self.cpu_device) 65 | vis_output = visualizer.draw_instance_predictions(predictions=instances) 66 | 67 | return predictions, vis_output 68 | 69 | def _frame_from_video(self, video): 70 | while video.isOpened(): 71 | success, frame = video.read() 72 | if success: 73 | yield frame 74 | else: 75 | break 76 | 77 | def run_on_video(self, video): 78 | """ 79 | Visualizes predictions on frames of the input video. 80 | Args: 81 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be 82 | either a webcam or a video file. 83 | Yields: 84 | ndarray: BGR visualizations of each video frame. 
85 | """ 86 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) 87 | 88 | def process_predictions(frame, predictions): 89 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 90 | if "panoptic_seg" in predictions: 91 | panoptic_seg, segments_info = predictions["panoptic_seg"] 92 | vis_frame = video_visualizer.draw_panoptic_seg_predictions( 93 | frame, panoptic_seg.to(self.cpu_device), segments_info 94 | ) 95 | elif "instances" in predictions: 96 | predictions = predictions["instances"].to(self.cpu_device) 97 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) 98 | elif "sem_seg" in predictions: 99 | vis_frame = video_visualizer.draw_sem_seg( 100 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 101 | ) 102 | 103 | # Converts Matplotlib RGB format to OpenCV BGR format 104 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) 105 | return vis_frame 106 | 107 | frame_gen = self._frame_from_video(video) 108 | if self.parallel: 109 | buffer_size = self.predictor.default_buffer_size 110 | 111 | frame_data = deque() 112 | 113 | for cnt, frame in enumerate(frame_gen): 114 | frame_data.append(frame) 115 | self.predictor.put(frame) 116 | 117 | if cnt >= buffer_size: 118 | frame = frame_data.popleft() 119 | predictions = self.predictor.get() 120 | yield process_predictions(frame, predictions) 121 | 122 | while len(frame_data): 123 | frame = frame_data.popleft() 124 | predictions = self.predictor.get() 125 | yield process_predictions(frame, predictions) 126 | else: 127 | for frame in frame_gen: 128 | yield process_predictions(frame, self.predictor(frame)) 129 | 130 | 131 | class AsyncPredictor: 132 | """ 133 | A predictor that runs the model asynchronously, possibly on >1 GPUs. 134 | Because rendering the visualization takes considerably amount of time, 135 | this helps improve throughput a little bit when rendering videos. 
136 | """ 137 | 138 | class _StopToken: 139 | pass 140 | 141 | class _PredictWorker(mp.Process): 142 | def __init__(self, cfg, task_queue, result_queue): 143 | self.cfg = cfg 144 | self.task_queue = task_queue 145 | self.result_queue = result_queue 146 | super().__init__() 147 | 148 | def run(self): 149 | predictor = DefaultPredictor(self.cfg) 150 | 151 | while True: 152 | task = self.task_queue.get() 153 | if isinstance(task, AsyncPredictor._StopToken): 154 | break 155 | idx, data = task 156 | result = predictor(data) 157 | self.result_queue.put((idx, result)) 158 | 159 | def __init__(self, cfg, num_gpus: int = 1): 160 | """ 161 | Args: 162 | cfg (CfgNode): 163 | num_gpus (int): if 0, will run on CPU 164 | """ 165 | num_workers = max(num_gpus, 1) 166 | self.task_queue = mp.Queue(maxsize=num_workers * 3) 167 | self.result_queue = mp.Queue(maxsize=num_workers * 3) 168 | self.procs = [] 169 | for gpuid in range(max(num_gpus, 1)): 170 | cfg = cfg.clone() 171 | cfg.defrost() 172 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" 173 | self.procs.append( 174 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) 175 | ) 176 | 177 | self.put_idx = 0 178 | self.get_idx = 0 179 | self.result_rank = [] 180 | self.result_data = [] 181 | 182 | for p in self.procs: 183 | p.start() 184 | atexit.register(self.shutdown) 185 | 186 | def put(self, image): 187 | self.put_idx += 1 188 | self.task_queue.put((self.put_idx, image)) 189 | 190 | def get(self): 191 | self.get_idx += 1 # the index needed for this request 192 | if len(self.result_rank) and self.result_rank[0] == self.get_idx: 193 | res = self.result_data[0] 194 | del self.result_data[0], self.result_rank[0] 195 | return res 196 | 197 | while True: 198 | # make sure the results are returned in the correct order 199 | idx, res = self.result_queue.get() 200 | if idx == self.get_idx: 201 | return res 202 | insert = bisect.bisect(self.result_rank, idx) 203 | self.result_rank.insert(insert, idx) 204 | self.result_data.insert(insert, res) 205 | 206 | def __len__(self): 207 | return self.put_idx - self.get_idx 208 | 209 | def __call__(self, image): 210 | self.put(image) 211 | return self.get() 212 | 213 | def shutdown(self): 214 | for _ in self.procs: 215 | self.task_queue.put(AsyncPredictor._StopToken()) 216 | 217 | @property 218 | def default_buffer_size(self): 219 | return len(self.procs) * 5 220 | -------------------------------------------------------------------------------- /figures/adeinferenceCOCO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/figures/adeinferenceCOCO.png -------------------------------------------------------------------------------- /figures/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/figures/fig1.png -------------------------------------------------------------------------------- /mask_former/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . import data # register all new datasets 3 | from . 
import modeling 4 | 5 | # config 6 | from .config import add_mask_former_config 7 | 8 | # dataset loading 9 | from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper 10 | from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import ( 11 | MaskFormerPanopticDatasetMapper, 12 | ) 13 | from .data.dataset_mappers.mask_former_semantic_dataset_mapper import ( 14 | MaskFormerSemanticDatasetMapper, 15 | ) 16 | 17 | # models 18 | from .mask_former_model import MaskFormer 19 | from .test_time_augmentation import SemanticSegmentorWithTTA 20 | from .semantic_seg_zero import SemanticSegmentorGzero 21 | -------------------------------------------------------------------------------- /mask_former/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | from detectron2.config import CfgNode as CN 4 | 5 | 6 | def add_mask_former_config(cfg): 7 | """ 8 | Add config for MASK_FORMER. 9 | """ 10 | # data config 11 | # select the dataset mapper 12 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" 13 | 14 | cfg.DATASETS.VAL_ALL = ("coco_2017_val_all_stuff_sem_seg",) 15 | 16 | # Color augmentation 17 | cfg.INPUT.COLOR_AUG_SSD = False 18 | # We retry random cropping until no single category in semantic segmentation GT occupies more 19 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 20 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 21 | # Pad image and segmentation GT in dataset mapper. 22 | cfg.INPUT.SIZE_DIVISIBILITY = -1 23 | 24 | # solver config 25 | # weight decay on embedding 26 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 27 | # optimizer 28 | cfg.SOLVER.OPTIMIZER = "ADAMW" 29 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 30 | 31 | # mask_former model config 32 | cfg.MODEL.MASK_FORMER = CN() 33 | 34 | # loss 35 | cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True 36 | # TODO: maybe the no object weight need to be adjusted 37 | cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 38 | cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 39 | cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 40 | 41 | # transformer config 42 | cfg.MODEL.MASK_FORMER.NHEADS = 8 43 | cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 44 | cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 45 | cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 46 | cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 47 | cfg.MODEL.MASK_FORMER.PRE_NORM = False 48 | 49 | cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 50 | cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 51 | 52 | cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" 53 | cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False 54 | 55 | # for mask pool 56 | cfg.MODEL.MASK_FORMER.NUM_CLS_CONV = 0 57 | cfg.MODEL.MASK_FORMER.FUSION_WAY = 'add' 58 | cfg.MODEL.MASK_FORMER.SEGMENTS_MASK_THRESHOLD = 0.5 59 | # MASK_POOL_FROM in ["x", "mask_features"] 60 | cfg.MODEL.MASK_FORMER.MASK_POOL_FROM = "mask_features" 61 | cfg.MODEL.MASK_FORMER.CLS_DEC_LAYERS = 6 62 | cfg.MODEL.MASK_FORMER.CLS_PRE_NORM = False 63 | 64 | # mask_former inference config 65 | cfg.MODEL.MASK_FORMER.TEST = CN() 66 | cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False 67 | cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 68 | cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 69 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False 70 | 71 | # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. 
ResNet) 72 | # you can use this config to override 73 | cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 74 | 75 | cfg.MODEL.MASK_FORMER.GZERO_CALIBRATE = -1.0 76 | cfg.MODEL.MASK_FORMER.ENSEMBLING = False 77 | cfg.MODEL.MASK_FORMER.ENSEMBLING_ALL_CLS = False 78 | cfg.MODEL.MASK_FORMER.GZERO_CALIBRATE_BEFORE = -1.0 79 | 80 | # pixel decoder config 81 | cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 82 | # adding transformer in pixel decoder 83 | cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 84 | # pixel decoder 85 | cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" 86 | # gzero calibrate 87 | cfg.MODEL.SEM_SEG_HEAD.GZERO_CALIBRATE = -1.0 88 | 89 | # swin transformer backbone 90 | cfg.MODEL.SWIN = CN() 91 | cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 92 | cfg.MODEL.SWIN.PATCH_SIZE = 4 93 | cfg.MODEL.SWIN.EMBED_DIM = 96 94 | cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 95 | cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 96 | cfg.MODEL.SWIN.WINDOW_SIZE = 7 97 | cfg.MODEL.SWIN.MLP_RATIO = 4.0 98 | cfg.MODEL.SWIN.QKV_BIAS = True 99 | cfg.MODEL.SWIN.QK_SCALE = None 100 | cfg.MODEL.SWIN.DROP_RATE = 0.0 101 | cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 102 | cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 103 | cfg.MODEL.SWIN.APE = False 104 | cfg.MODEL.SWIN.PATCH_NORM = True 105 | cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] 106 | 107 | # zero shot config 108 | cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json" 109 | cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json" 110 | cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_INDEXES = "datasets/coco/coco_stuff/split/seen_indexes.json" 111 | cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_INDEXES = "datasets/coco/coco_stuff/split/unseen_indexes.json" 112 | 113 | 114 | # cfg.MODEL.MASK_FORMER.TEST.CLIP_CLASSIFICATION = False 115 | cfg.MODEL.SEM_SEG_HEAD.CLIP_CLASSIFICATION = False 116 | cfg.MODEL.SEM_SEG_HEAD.DENSECLIP = False 117 | cfg.MODEL.SEM_SEG_HEAD.MASKATTENTIONPOOL = False 118 | cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED = "ViT-B/32" 119 | cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED_IMG = "ViT-B/32" 120 | 121 | cfg.MODEL.PROMPT_ENSEMBLE = False 122 | cfg.MODEL.PROMPT_ENSEMBLE_TYPE = "single" 123 | 124 | cfg.MODEL.CLIP_PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615] 125 | cfg.MODEL.CLIP_PIXEL_STD = [68.5005327, 66.6321579, 70.3231630] 126 | # three styles for clip classification, crop, mask, cropmask 127 | cfg.MODEL.CLIP_CLS_STYLE = "cropmask" 128 | 129 | # cfg.MODEL.MASK_FORMER.NUM_FEATURE_LEVELS = 4 130 | # cfg.MODEL.MASK_FORMER.DEC_N_POINTS = 4 131 | # cfg.MODEL.MASK_FORMER.ENC_N_POINTS = 4 132 | # cfg.MODEL.MASK_FORMER.TWO_STAGE = False 133 | cfg.MODEL.MASK_FORMER.CLUSTER_QUERIES = False 134 | cfg.MODEL.MASK_FORMER.USE_SEMANTIC_QUERY = False 135 | cfg.MODEL.MASK_FORMER.GRAD_SEMANTIC_QUERY = False 136 | cfg.MODEL.MASK_FORMER.SEMANTIC_QUERY_MULTIPLIER = 0.1 137 | cfg.MODEL.MASK_FORMER.INIT_QUERY_BY_SEMANTIC = False 138 | cfg.MODEL.MASK_FORMER.INIT_QUERY_BY_SEMANTIC_AS_ZERO = False 139 | cfg.MODEL.SEM_SEG_HEAD.WORDVEC = False 140 | cfg.MODEL.SEM_SEG_HEAD.TEMPERATURE = 0.01 141 | -------------------------------------------------------------------------------- /mask_former/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . 
import datasets 3 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py 3 | import copy 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from detectron2.config import configurable 10 | from detectron2.data import detection_utils as utils 11 | from detectron2.data import transforms as T 12 | from detectron2.data.transforms import TransformGen 13 | from detectron2.structures import BitMasks, Instances 14 | 15 | __all__ = ["DETRPanopticDatasetMapper"] 16 | 17 | 18 | def build_transform_gen(cfg, is_train): 19 | """ 20 | Create a list of :class:`TransformGen` from config. 21 | Returns: 22 | list[TransformGen] 23 | """ 24 | if is_train: 25 | min_size = cfg.INPUT.MIN_SIZE_TRAIN 26 | max_size = cfg.INPUT.MAX_SIZE_TRAIN 27 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 28 | else: 29 | min_size = cfg.INPUT.MIN_SIZE_TEST 30 | max_size = cfg.INPUT.MAX_SIZE_TEST 31 | sample_style = "choice" 32 | if sample_style == "range": 33 | assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format( 34 | len(min_size) 35 | ) 36 | 37 | logger = logging.getLogger(__name__) 38 | tfm_gens = [] 39 | if is_train: 40 | tfm_gens.append(T.RandomFlip()) 41 | tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 42 | if is_train: 43 | logger.info("TransformGens used in training: " + str(tfm_gens)) 44 | return tfm_gens 45 | 46 | 47 | # This is specifically designed for the COCO dataset. 48 | class DETRPanopticDatasetMapper: 49 | """ 50 | A callable which takes a dataset dict in Detectron2 Dataset format, 51 | and map it into a format used by MaskFormer. 52 | 53 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 54 | 55 | The callable currently does the following: 56 | 57 | 1. Read the image from "file_name" 58 | 2. Applies geometric transforms to the image and annotation 59 | 3. Find and applies suitable cropping to the image and annotation 60 | 4. Prepare image and annotation to Tensors 61 | """ 62 | 63 | @configurable 64 | def __init__( 65 | self, 66 | is_train=True, 67 | *, 68 | crop_gen, 69 | tfm_gens, 70 | image_format, 71 | ): 72 | """ 73 | NOTE: this interface is experimental. 74 | Args: 75 | is_train: for training or inference 76 | augmentations: a list of augmentations or deterministic transforms to apply 77 | crop_gen: crop augmentation 78 | tfm_gens: data augmentation 79 | image_format: an image format supported by :func:`detection_utils.read_image`. 
80 | """ 81 | self.crop_gen = crop_gen 82 | self.tfm_gens = tfm_gens 83 | logging.getLogger(__name__).info( 84 | "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format( 85 | str(self.tfm_gens), str(self.crop_gen) 86 | ) 87 | ) 88 | 89 | self.img_format = image_format 90 | self.is_train = is_train 91 | 92 | @classmethod 93 | def from_config(cls, cfg, is_train=True): 94 | # Build augmentation 95 | if cfg.INPUT.CROP.ENABLED and is_train: 96 | crop_gen = [ 97 | T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), 98 | T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), 99 | ] 100 | else: 101 | crop_gen = None 102 | 103 | tfm_gens = build_transform_gen(cfg, is_train) 104 | 105 | ret = { 106 | "is_train": is_train, 107 | "crop_gen": crop_gen, 108 | "tfm_gens": tfm_gens, 109 | "image_format": cfg.INPUT.FORMAT, 110 | } 111 | return ret 112 | 113 | def __call__(self, dataset_dict): 114 | """ 115 | Args: 116 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 117 | 118 | Returns: 119 | dict: a format that builtin models in detectron2 accept 120 | """ 121 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 122 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 123 | utils.check_image_size(dataset_dict, image) 124 | 125 | if self.crop_gen is None: 126 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 127 | else: 128 | if np.random.rand() > 0.5: 129 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 130 | else: 131 | image, transforms = T.apply_transform_gens( 132 | self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image 133 | ) 134 | 135 | image_shape = image.shape[:2] # h, w 136 | 137 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 138 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 139 | # Therefore it's important to use torch.Tensor. 140 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 141 | 142 | if not self.is_train: 143 | # USER: Modify this if you want to keep them for some reason. 
144 | dataset_dict.pop("annotations", None) 145 | return dataset_dict 146 | 147 | if "pan_seg_file_name" in dataset_dict: 148 | pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") 149 | segments_info = dataset_dict["segments_info"] 150 | 151 | # apply the same transformation to panoptic segmentation 152 | pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) 153 | 154 | from panopticapi.utils import rgb2id 155 | 156 | pan_seg_gt = rgb2id(pan_seg_gt) 157 | 158 | instances = Instances(image_shape) 159 | classes = [] 160 | masks = [] 161 | for segment_info in segments_info: 162 | class_id = segment_info["category_id"] 163 | if not segment_info["iscrowd"]: 164 | classes.append(class_id) 165 | masks.append(pan_seg_gt == segment_info["id"]) 166 | 167 | classes = np.array(classes) 168 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 169 | if len(masks) == 0: 170 | # Some image does not have annotation (all ignored) 171 | instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) 172 | else: 173 | masks = BitMasks( 174 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 175 | ) 176 | instances.gt_masks = masks.tensor 177 | 178 | dataset_dict["instances"] = instances 179 | 180 | return dataset_dict 181 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.data import detection_utils as utils 11 | from detectron2.data import transforms as T 12 | from detectron2.structures import BitMasks, Instances 13 | 14 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper 15 | 16 | __all__ = ["MaskFormerPanopticDatasetMapper"] 17 | 18 | 19 | class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper): 20 | """ 21 | A callable which takes a dataset dict in Detectron2 Dataset format, 22 | and map it into a format used by MaskFormer for panoptic segmentation. 23 | 24 | The callable currently does the following: 25 | 26 | 1. Read the image from "file_name" 27 | 2. Applies geometric transforms to the image and annotation 28 | 3. Find and applies suitable cropping to the image and annotation 29 | 4. Prepare image and annotation to Tensors 30 | """ 31 | 32 | @configurable 33 | def __init__( 34 | self, 35 | is_train=True, 36 | *, 37 | augmentations, 38 | image_format, 39 | ignore_label, 40 | size_divisibility, 41 | ): 42 | """ 43 | NOTE: this interface is experimental. 44 | Args: 45 | is_train: for training or inference 46 | augmentations: a list of augmentations or deterministic transforms to apply 47 | image_format: an image format supported by :func:`detection_utils.read_image`. 48 | ignore_label: the label that is ignored to evaluation 49 | size_divisibility: pad image size to be divisible by this value 50 | """ 51 | super().__init__( 52 | is_train, 53 | augmentations=augmentations, 54 | image_format=image_format, 55 | ignore_label=ignore_label, 56 | size_divisibility=size_divisibility, 57 | ) 58 | 59 | def __call__(self, dataset_dict): 60 | """ 61 | Args: 62 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
63 | 64 | Returns: 65 | dict: a format that builtin models in detectron2 accept 66 | """ 67 | assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!" 68 | 69 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 70 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 71 | utils.check_image_size(dataset_dict, image) 72 | 73 | # semantic segmentation 74 | if "sem_seg_file_name" in dataset_dict: 75 | # PyTorch transformation not implemented for uint16, so converting it to double first 76 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double") 77 | else: 78 | sem_seg_gt = None 79 | 80 | # panoptic segmentation 81 | if "pan_seg_file_name" in dataset_dict: 82 | pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") 83 | segments_info = dataset_dict["segments_info"] 84 | else: 85 | pan_seg_gt = None 86 | segments_info = None 87 | 88 | if pan_seg_gt is None: 89 | raise ValueError( 90 | "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format( 91 | dataset_dict["file_name"] 92 | ) 93 | ) 94 | 95 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 96 | aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) 97 | image = aug_input.image 98 | if sem_seg_gt is not None: 99 | sem_seg_gt = aug_input.sem_seg 100 | 101 | # apply the same transformation to panoptic segmentation 102 | pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) 103 | 104 | from panopticapi.utils import rgb2id 105 | 106 | pan_seg_gt = rgb2id(pan_seg_gt) 107 | 108 | # Pad image and segmentation label here! 109 | image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 110 | if sem_seg_gt is not None: 111 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 112 | pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long")) 113 | 114 | if self.size_divisibility > 0: 115 | image_size = (image.shape[-2], image.shape[-1]) 116 | padding_size = [ 117 | 0, 118 | self.size_divisibility - image_size[1], 119 | 0, 120 | self.size_divisibility - image_size[0], 121 | ] 122 | image = F.pad(image, padding_size, value=128).contiguous() 123 | if sem_seg_gt is not None: 124 | sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() 125 | pan_seg_gt = F.pad( 126 | pan_seg_gt, padding_size, value=0 127 | ).contiguous() # 0 is the VOID panoptic label 128 | 129 | image_shape = (image.shape[-2], image.shape[-1]) # h, w 130 | 131 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 132 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 133 | # Therefore it's important to use torch.Tensor. 
134 | dataset_dict["image"] = image 135 | if sem_seg_gt is not None: 136 | dataset_dict["sem_seg"] = sem_seg_gt.long() 137 | 138 | if "annotations" in dataset_dict: 139 | raise ValueError("Pemantic segmentation dataset should not have 'annotations'.") 140 | 141 | # Prepare per-category binary masks 142 | pan_seg_gt = pan_seg_gt.numpy() 143 | instances = Instances(image_shape) 144 | classes = [] 145 | masks = [] 146 | for segment_info in segments_info: 147 | class_id = segment_info["category_id"] 148 | if not segment_info["iscrowd"]: 149 | classes.append(class_id) 150 | masks.append(pan_seg_gt == segment_info["id"]) 151 | 152 | classes = np.array(classes) 153 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 154 | if len(masks) == 0: 155 | # Some image does not have annotation (all ignored) 156 | instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) 157 | else: 158 | masks = BitMasks( 159 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 160 | ) 161 | instances.gt_masks = masks.tensor 162 | 163 | dataset_dict["instances"] = instances 164 | 165 | return dataset_dict 166 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.data import MetadataCatalog 11 | from detectron2.data import detection_utils as utils 12 | from detectron2.data import transforms as T 13 | from detectron2.projects.point_rend import ColorAugSSDTransform 14 | from detectron2.structures import BitMasks, Instances 15 | 16 | __all__ = ["MaskFormerSemanticDatasetMapper"] 17 | 18 | 19 | class MaskFormerSemanticDatasetMapper: 20 | """ 21 | A callable which takes a dataset dict in Detectron2 Dataset format, 22 | and map it into a format used by MaskFormer for semantic segmentation. 23 | 24 | The callable currently does the following: 25 | 26 | 1. Read the image from "file_name" 27 | 2. Applies geometric transforms to the image and annotation 28 | 3. Find and applies suitable cropping to the image and annotation 29 | 4. Prepare image and annotation to Tensors 30 | """ 31 | 32 | @configurable 33 | def __init__( 34 | self, 35 | is_train=True, 36 | *, 37 | augmentations, 38 | image_format, 39 | ignore_label, 40 | size_divisibility, 41 | ): 42 | """ 43 | NOTE: this interface is experimental. 44 | Args: 45 | is_train: for training or inference 46 | augmentations: a list of augmentations or deterministic transforms to apply 47 | image_format: an image format supported by :func:`detection_utils.read_image`. 
48 | ignore_label: the label that is ignored to evaluation 49 | size_divisibility: pad image size to be divisible by this value 50 | """ 51 | self.is_train = is_train 52 | self.tfm_gens = augmentations 53 | self.img_format = image_format 54 | self.ignore_label = ignore_label 55 | self.size_divisibility = size_divisibility 56 | 57 | logger = logging.getLogger(__name__) 58 | mode = "training" if is_train else "inference" 59 | logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}") 60 | 61 | @classmethod 62 | def from_config(cls, cfg, is_train=True): 63 | # Build augmentation 64 | augs = [ 65 | T.ResizeShortestEdge( 66 | cfg.INPUT.MIN_SIZE_TRAIN, 67 | cfg.INPUT.MAX_SIZE_TRAIN, 68 | cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING, 69 | ) 70 | ] 71 | if cfg.INPUT.CROP.ENABLED: 72 | augs.append( 73 | T.RandomCrop_CategoryAreaConstraint( 74 | cfg.INPUT.CROP.TYPE, 75 | cfg.INPUT.CROP.SIZE, 76 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA, 77 | cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 78 | ) 79 | ) 80 | if cfg.INPUT.COLOR_AUG_SSD: 81 | augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT)) 82 | augs.append(T.RandomFlip()) 83 | 84 | # Assume always applies to the training set. 85 | dataset_names = cfg.DATASETS.TRAIN 86 | meta = MetadataCatalog.get(dataset_names[0]) 87 | ignore_label = meta.ignore_label 88 | 89 | ret = { 90 | "is_train": is_train, 91 | "augmentations": augs, 92 | "image_format": cfg.INPUT.FORMAT, 93 | "ignore_label": ignore_label, 94 | "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY, 95 | } 96 | return ret 97 | 98 | def __call__(self, dataset_dict): 99 | """ 100 | Args: 101 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 102 | 103 | Returns: 104 | dict: a format that builtin models in detectron2 accept 105 | """ 106 | assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!" 107 | 108 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 109 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 110 | utils.check_image_size(dataset_dict, image) 111 | 112 | if "sem_seg_file_name" in dataset_dict: 113 | # PyTorch transformation not implemented for uint16, so converting it to double first 114 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double") 115 | else: 116 | sem_seg_gt = None 117 | 118 | if sem_seg_gt is None: 119 | raise ValueError( 120 | "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( 121 | dataset_dict["file_name"] 122 | ) 123 | ) 124 | 125 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 126 | aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) 127 | image = aug_input.image 128 | sem_seg_gt = aug_input.sem_seg 129 | 130 | # Pad image and segmentation label here! 
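        # F.pad expects pad sizes for the last dimensions first, ordered as
        # (last-dim left, last-dim right, 2nd-to-last left, 2nd-to-last right), so
        # padding_size = [0, pad_w, 0, pad_h] below pads only the right and bottom edges.
        # Worked example with illustrative numbers (not taken from any shipped config):
        #   image of shape (3, 427, 512) with size_divisibility = 512
        #   -> padding_size = [0, 512 - 512, 0, 512 - 427] = [0, 0, 0, 85]
        #   -> padded image shape (3, 512, 512); the label map is padded the same way with
        #      value ignore_label, so padded pixels never enter the training targets.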
131 | image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 132 | if sem_seg_gt is not None: 133 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 134 | # import ipdb; ipdb.set_trace() 135 | if self.size_divisibility > 0: 136 | image_size = (image.shape[-2], image.shape[-1]) 137 | # The ori_size is not the real original size, but size before padding 138 | dataset_dict['ori_size'] = image_size 139 | padding_size = [ 140 | 0, 141 | self.size_divisibility - image_size[1], # w: (left, right) 142 | 0, 143 | self.size_divisibility - image_size[0], # h: 0,(top, bottom) 144 | ] 145 | image = F.pad(image, padding_size, value=128).contiguous() 146 | if sem_seg_gt is not None: 147 | sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() 148 | 149 | image_shape = (image.shape[-2], image.shape[-1]) # h, w 150 | 151 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 152 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 153 | # Therefore it's important to use torch.Tensor. 154 | dataset_dict["image"] = image 155 | # print('#########################################################################################') 156 | if sem_seg_gt is not None: 157 | dataset_dict["sem_seg"] = sem_seg_gt.long() 158 | 159 | if "annotations" in dataset_dict: 160 | raise ValueError("Semantic segmentation dataset should not have 'annotations'.") 161 | 162 | # Prepare per-category binary masks 163 | if sem_seg_gt is not None: 164 | sem_seg_gt = sem_seg_gt.numpy() 165 | instances = Instances(image_shape) 166 | classes = np.unique(sem_seg_gt) 167 | # remove ignored region 168 | classes = classes[classes != self.ignore_label] 169 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 170 | 171 | masks = [] 172 | for class_id in classes: 173 | masks.append(sem_seg_gt == class_id) 174 | 175 | if len(masks) == 0: 176 | # Some image does not have annotation (all ignored) 177 | instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])) 178 | else: 179 | masks = BitMasks( 180 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 181 | ) 182 | instances.gt_masks = masks.tensor 183 | 184 | dataset_dict["instances"] = instances 185 | 186 | return dataset_dict 187 | -------------------------------------------------------------------------------- /mask_former/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . 
import ( 3 | register_coco_stuff, 4 | register_ade20k_full_zero_freq, 5 | register_pascal_voc, 6 | ) 7 | -------------------------------------------------------------------------------- /mask_former/data/datasets/register_pascal_voc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from detectron2.data import DatasetCatalog, MetadataCatalog 5 | from detectron2.data.datasets import load_sem_seg 6 | import copy 7 | 8 | categories = [ 9 | {"name": "aeroplane", "id": 1, "trainId": 0}, 10 | {"name": "bicycle", "id": 2, "trainId": 1}, 11 | {"name": "bird", "id": 3, "trainId": 2}, 12 | {"name": "boat", "id": 4, "trainId": 3}, 13 | {"name": "bottle", "id": 5, "trainId": 4}, 14 | {"name": "bus", "id": 6, "trainId": 5}, 15 | {"name": "car", "id": 7, "trainId": 6}, 16 | {"name": "cat", "id": 8, "trainId": 7}, 17 | {"name": "chair", "id": 9, "trainId": 8}, 18 | {"name": "cow", "id": 10, "trainId": 9}, 19 | {"name": "diningtable", "id": 11, "trainId": 10}, 20 | {"name": "dog", "id": 12, "trainId": 11}, 21 | {"name": "horse", "id": 13, "trainId": 12}, 22 | {"name": "motorbike", "id": 14, "trainId": 13}, 23 | {"name": "person", "id": 15, "trainId": 14}, 24 | {"name": "potted plant", "id": 16, "trainId": 15}, 25 | {"name": "sheep", "id": 17, "trainId": 16}, 26 | {"name": "sofa", "id": 18, "trainId": 17}, 27 | {"name": "train", "id": 19, "trainId": 18}, 28 | {"name": "tvmonitor", "id": 20, "trainId": 19}] 29 | 30 | categories_seen = copy.deepcopy(categories[:15]) 31 | 32 | categories_unseen = copy.deepcopy(categories[15:]) 33 | for index, item in enumerate(categories_unseen): 34 | item["trainId"] = index 35 | 36 | def _get_pascal_voc_seen_meta(): 37 | # Id 0 is reserved for ignore_label, we change ignore_label for 0 38 | # to 255 in our pre-processing. 
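    # Concretely, categories_seen holds the first 15 VOC classes (aeroplane .. person),
    # so the mapping built below is simply
    #   stuff_dataset_id_to_contiguous_id = {1: 0, 2: 1, ..., 15: 14}
    # with stuff_classes = ["aeroplane", "bicycle", ..., "person"].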
39 | stuff_ids = [k["id"] for k in categories_seen] 40 | assert len(stuff_ids) == 15, len(stuff_ids) 41 | 42 | # For semantic segmentation, this mapping maps from contiguous stuff id 43 | # (in [0, 91], used in models) to ids in the dataset (used for processing results) 44 | stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} 45 | stuff_classes = [k["name"] for k in categories_seen] 46 | 47 | ret = { 48 | "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, 49 | "stuff_classes": stuff_classes, 50 | } 51 | return ret 52 | 53 | # def register_pascal_voc_seen(root): 54 | # root = os.path.join(root, "VOCZERO") 55 | # meta = _get_pascal_voc_seen_meta() 56 | # for name, image_dirname, sem_seg_dirname in [ 57 | # ("train", "images/train", "annotations_detectron2/train_seen"), 58 | # ("test", "images/val", "annotations_detectron2/val_seen"), 59 | # ]: 60 | # image_dir = os.path.join(root, image_dirname) 61 | # gt_dir = os.path.join(root, sem_seg_dirname) 62 | # name = f"pascal_voc_{name}_seen_sem_seg" 63 | # DatasetCatalog.register( 64 | # name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") 65 | # ) 66 | # MetadataCatalog.get(name).set( 67 | # image_root=image_dir, 68 | # sem_seg_root=gt_dir, 69 | # evaluator_type="sem_seg", 70 | # ignore_label=255, 71 | # **meta, 72 | # ) 73 | 74 | def register_pascal_voc_seen(root): 75 | root = os.path.join(root, "VOCZERO") 76 | meta = _get_pascal_voc_seen_meta() 77 | for name, image_dirname, sem_seg_dirname in [ 78 | ("train", "images/train", "annotations_detectron2/train_seen"), 79 | ("test", "images/val", "annotations_detectron2/val_seen"), 80 | ]: 81 | image_dir = os.path.join(root, image_dirname) 82 | gt_dir = os.path.join(root, sem_seg_dirname) 83 | name = f"pascal_voc_{name}_seen_sem_seg" 84 | DatasetCatalog.register( 85 | name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") 86 | ) 87 | # import ipdb; ipdb.set_trace() 88 | MetadataCatalog.get(name).set( 89 | image_root=image_dir, 90 | sem_seg_root=gt_dir, 91 | evaluator_type="sem_seg", 92 | ignore_label=255, 93 | **meta, 94 | ) 95 | 96 | 97 | def _get_pascal_voc_val_unseen_meta(): 98 | # Id 0 is reserved for ignore_label, we change ignore_label for 0 99 | # to 255 in our pre-processing. 
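    # Concretely, categories_unseen holds the last 5 VOC classes
    # ("potted plant", "sheep", "sofa", "train", "tvmonitor"), so the mapping built below is
    #   stuff_dataset_id_to_contiguous_id = {16: 0, 17: 1, 18: 2, 19: 3, 20: 4}.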
100 | stuff_ids = [k["id"] for k in categories_unseen] 101 | assert len(stuff_ids) == 5, len(stuff_ids) 102 | 103 | # For semantic segmentation, this mapping maps from contiguous stuff id 104 | # (in [0, 91], used in models) to ids in the dataset (used for processing results) 105 | stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} 106 | stuff_classes = [k["name"] for k in categories_unseen] 107 | 108 | ret = { 109 | "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, 110 | "stuff_classes": stuff_classes, 111 | } 112 | return ret 113 | 114 | def register_coco_stuff_val_unseen(root): 115 | root = os.path.join(root, "VOCZERO") 116 | meta = _get_pascal_voc_val_unseen_meta() 117 | 118 | name = 'val_unseen' 119 | image_dirname = "images/val" 120 | sem_seg_dirname = "annotations_detectron2/val_unseen" 121 | image_dir = os.path.join(root, image_dirname) 122 | gt_dir = os.path.join(root, sem_seg_dirname) 123 | name = f"pascal_voc_{name}_sem_seg" 124 | DatasetCatalog.register( 125 | name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") 126 | ) 127 | val_extra_classes = [k["name"] for k in categories_unseen] 128 | MetadataCatalog.get(name).set( 129 | val_extra_classes=val_extra_classes, 130 | image_root=image_dir, 131 | sem_seg_root=gt_dir, 132 | evaluator_type="sem_seg", 133 | ignore_label=255, 134 | **meta, 135 | ) 136 | 137 | def _get_pascal_voc_stuff_meta(): 138 | # Id 0 is reserved for ignore_label, we change ignore_label for 0 139 | # to 255 in our pre-processing. 140 | stuff_ids = [k["id"] for k in categories] 141 | assert len(stuff_ids) == 20, len(stuff_ids) 142 | 143 | # For semantic segmentation, this mapping maps from contiguous stuff id 144 | # (in [0, 91], used in models) to ids in the dataset (used for processing results) 145 | stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} 146 | stuff_classes = [k["name"] for k in categories] 147 | 148 | ret = { 149 | "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, 150 | "stuff_classes": stuff_classes, 151 | } 152 | return ret 153 | 154 | def register_voc_stuff_val_all(root): 155 | root = os.path.join(root, "VOCZERO") 156 | meta = _get_pascal_voc_stuff_meta() 157 | name = 'val_all' 158 | image_dirname = "images/val" 159 | sem_seg_dirname = "annotations_detectron2/val_all" 160 | image_dir = os.path.join(root, image_dirname) 161 | gt_dir = os.path.join(root, sem_seg_dirname) 162 | name = f"pascal_voc_{name}_sem_seg" 163 | DatasetCatalog.register( 164 | name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") 165 | ) 166 | 167 | val_extra_classes = [k["name"] for k in categories_unseen] 168 | MetadataCatalog.get(name).set( 169 | val_extra_classes=val_extra_classes, 170 | image_root=image_dir, 171 | sem_seg_root=gt_dir, 172 | # evaluator_type="sem_seg", 173 | evaluator_type="sem_seg_gzero", 174 | ignore_label=255, 175 | **meta, 176 | ) 177 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 178 | register_pascal_voc_seen(_root) 179 | register_coco_stuff_val_unseen(_root) 180 | register_voc_stuff_val_all(_root) 181 | -------------------------------------------------------------------------------- /mask_former/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/mask_former/evaluation/__init__.py 
-------------------------------------------------------------------------------- /mask_former/evaluation/sem_seg_evaluation_gzero.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import itertools 3 | import json 4 | import logging 5 | import numpy as np 6 | import os 7 | from collections import OrderedDict 8 | import PIL.Image as Image 9 | import pycocotools.mask as mask_util 10 | import torch 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.utils.comm import all_gather, is_main_process, synchronize 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from .evaluator import DatasetEvaluator 17 | 18 | 19 | class SemSegGzeroEvaluator(DatasetEvaluator): 20 | """ 21 | Evaluate semantic segmentation metrics. 22 | """ 23 | 24 | def __init__( 25 | self, dataset_name, distributed, output_dir=None, *, num_classes=None, ignore_label=None 26 | ): 27 | """ 28 | Args: 29 | dataset_name (str): name of the dataset to be evaluated. 30 | distributed (True): if True, will collect results from all ranks for evaluation. 31 | Otherwise, will evaluate the results in the current process. 32 | output_dir (str): an output directory to dump results. 33 | num_classes, ignore_label: deprecated argument 34 | """ 35 | self._logger = logging.getLogger(__name__) 36 | if num_classes is not None: 37 | self._logger.warn( 38 | "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata." 39 | ) 40 | if ignore_label is not None: 41 | self._logger.warn( 42 | "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata." 43 | ) 44 | self._dataset_name = dataset_name 45 | self._distributed = distributed 46 | self._output_dir = output_dir 47 | 48 | self._cpu_device = torch.device("cpu") 49 | 50 | self.input_file_to_gt_file = { 51 | dataset_record["file_name"]: dataset_record["sem_seg_file_name"] 52 | for dataset_record in DatasetCatalog.get(dataset_name) 53 | } 54 | 55 | meta = MetadataCatalog.get(dataset_name) 56 | # Dict that maps contiguous training ids to COCO category ids 57 | try: 58 | c2d = meta.stuff_dataset_id_to_contiguous_id 59 | self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()} 60 | except AttributeError: 61 | self._contiguous_id_to_dataset_id = None 62 | self._class_names = meta.stuff_classes 63 | self._val_extra_classes = meta.val_extra_classes 64 | self._num_classes = len(meta.stuff_classes) 65 | if num_classes is not None: 66 | assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}" 67 | self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label 68 | 69 | def reset(self): 70 | self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64) 71 | self._predictions = [] 72 | 73 | def process(self, inputs, outputs): 74 | """ 75 | Args: 76 | inputs: the inputs to a model. 77 | It is a list of dicts. Each dict corresponds to an image and 78 | contains keys like "height", "width", "file_name". 79 | outputs: the outputs of a model. It is either list of semantic segmentation predictions 80 | (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic 81 | segmentation prediction in the same format. 
82 | """ 83 | for input, output in zip(inputs, outputs): 84 | output = output["sem_seg"].argmax(dim=0).to(self._cpu_device) 85 | pred = np.array(output, dtype=np.int) 86 | with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f: 87 | gt = np.array(Image.open(f), dtype=np.int) 88 | 89 | gt[gt == self._ignore_label] = self._num_classes 90 | 91 | self._conf_matrix += np.bincount( 92 | (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1), 93 | minlength=self._conf_matrix.size, 94 | ).reshape(self._conf_matrix.shape) 95 | 96 | self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"])) 97 | 98 | def evaluate(self): 99 | """ 100 | Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval): 101 | 102 | * Mean intersection-over-union averaged across classes (mIoU) 103 | * Frequency Weighted IoU (fwIoU) 104 | * Mean pixel accuracy averaged across classes (mACC) 105 | * Pixel Accuracy (pACC) 106 | """ 107 | if self._distributed: 108 | synchronize() 109 | conf_matrix_list = all_gather(self._conf_matrix) 110 | self._predictions = all_gather(self._predictions) 111 | self._predictions = list(itertools.chain(*self._predictions)) 112 | if not is_main_process(): 113 | return 114 | 115 | self._conf_matrix = np.zeros_like(self._conf_matrix) 116 | for conf_matrix in conf_matrix_list: 117 | self._conf_matrix += conf_matrix 118 | 119 | if self._output_dir: 120 | PathManager.mkdirs(self._output_dir) 121 | file_path = os.path.join(self._output_dir, "sem_seg_predictions.json") 122 | with PathManager.open(file_path, "w") as f: 123 | f.write(json.dumps(self._predictions)) 124 | 125 | acc = np.full(self._num_classes, np.nan, dtype=np.float) 126 | iou = np.full(self._num_classes, np.nan, dtype=np.float) 127 | tp = self._conf_matrix.diagonal()[:-1].astype(np.float) 128 | pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float) 129 | class_weights = pos_gt / np.sum(pos_gt) 130 | pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float) 131 | acc_valid = pos_gt > 0 132 | acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid] 133 | iou_valid = (pos_gt + pos_pred) > 0 134 | union = pos_gt + pos_pred - tp 135 | iou[acc_valid] = tp[acc_valid] / union[acc_valid] 136 | macc = np.sum(acc[acc_valid]) / np.sum(acc_valid) 137 | miou = np.sum(iou[acc_valid]) / np.sum(iou_valid) 138 | fiou = np.sum(iou[acc_valid] * class_weights[acc_valid]) 139 | pacc = np.sum(tp) / np.sum(pos_gt) 140 | 141 | 142 | seen_IoU = 0 143 | unseen_IoU = 0 144 | seen_acc = 0 145 | unseen_acc = 0 146 | res = {} 147 | res["mIoU"] = 100 * miou 148 | res["fwIoU"] = 100 * fiou 149 | for i, name in enumerate(self._class_names): 150 | res["IoU-{}".format(name)] = 100 * iou[i] 151 | if name in self._val_extra_classes: 152 | unseen_IoU = unseen_IoU + 100 * iou[i] 153 | else: 154 | seen_IoU = seen_IoU + 100 * iou[i] 155 | unseen_IoU = unseen_IoU / len(self._val_extra_classes) 156 | seen_IoU = seen_IoU / (self._num_classes - len(self._val_extra_classes)) 157 | res["mACC"] = 100 * macc 158 | res["pACC"] = 100 * pacc 159 | for i, name in enumerate(self._class_names): 160 | res["ACC-{}".format(name)] = 100 * acc[i] 161 | if name in self._val_extra_classes: 162 | unseen_acc = unseen_acc + 100 * iou[i] 163 | else: 164 | seen_acc = seen_acc + 100 * iou[i] 165 | unseen_acc = unseen_acc / len(self._val_extra_classes) 166 | seen_acc = seen_acc / (self._num_classes - len(self._val_extra_classes)) 167 | res["unseen_IoU"] = unseen_IoU 168 | res["seen_IoU"] = seen_IoU 169 | 
res["unseen_acc"] = unseen_acc 170 | res["seen_acc"] = seen_acc 171 | if self._output_dir: 172 | file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth") 173 | with PathManager.open(file_path, "wb") as f: 174 | torch.save(res, f) 175 | results = OrderedDict({"sem_seg": res}) 176 | self._logger.info(results) 177 | return results 178 | 179 | def encode_json_sem_seg(self, sem_seg, input_file_name): 180 | """ 181 | Convert semantic segmentation to COCO stuff format with segments encoded as RLEs. 182 | See http://cocodataset.org/#format-results 183 | """ 184 | json_list = [] 185 | for label in np.unique(sem_seg): 186 | if self._contiguous_id_to_dataset_id is not None: 187 | assert ( 188 | label in self._contiguous_id_to_dataset_id 189 | ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name) 190 | dataset_id = self._contiguous_id_to_dataset_id[label] 191 | else: 192 | dataset_id = int(label) 193 | mask = (sem_seg == label).astype(np.uint8) 194 | mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0] 195 | mask_rle["counts"] = mask_rle["counts"].decode("utf-8") 196 | json_list.append( 197 | {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle} 198 | ) 199 | return json_list 200 | -------------------------------------------------------------------------------- /mask_former/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .backbone.swin import D2SwinTransformer 3 | from .heads.mask_former_head import MaskFormerHead 4 | from .heads.zeg_former_head import ZegFormerHead 5 | from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead 6 | from .heads.pixel_decoder import BasePixelDecoder 7 | from .heads.zeroshot_per_pixel_baseline import ZeroshotPerPixelBaselineHead -------------------------------------------------------------------------------- /mask_former/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/modeling/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | """ 4 | MaskFormer criterion. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from detectron2.utils.comm import get_world_size 11 | 12 | from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list 13 | 14 | 15 | def dice_loss(inputs, targets, num_masks): 16 | """ 17 | Compute the DICE loss, similar to generalized IOU for masks 18 | Args: 19 | inputs: A float tensor of arbitrary shape. 20 | The predictions for each example. 21 | targets: A float tensor with the same shape as inputs. Stores the binary 22 | classification label for each element in inputs 23 | (0 for the negative class and 1 for the positive class). 
24 | """ 25 | inputs = inputs.sigmoid() 26 | inputs = inputs.flatten(1) 27 | numerator = 2 * (inputs * targets).sum(-1) 28 | denominator = inputs.sum(-1) + targets.sum(-1) 29 | loss = 1 - (numerator + 1) / (denominator + 1) 30 | return loss.sum() / num_masks 31 | 32 | 33 | def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2): 34 | """ 35 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 36 | Args: 37 | inputs: A float tensor of arbitrary shape. 38 | The predictions for each example. 39 | targets: A float tensor with the same shape as inputs. Stores the binary 40 | classification label for each element in inputs 41 | (0 for the negative class and 1 for the positive class). 42 | alpha: (optional) Weighting factor in range (0,1) to balance 43 | positive vs negative examples. Default = -1 (no weighting). 44 | gamma: Exponent of the modulating factor (1 - p_t) to 45 | balance easy vs hard examples. 46 | Returns: 47 | Loss tensor 48 | """ 49 | prob = inputs.sigmoid() 50 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 51 | p_t = prob * targets + (1 - prob) * (1 - targets) 52 | loss = ce_loss * ((1 - p_t) ** gamma) 53 | 54 | if alpha >= 0: 55 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 56 | loss = alpha_t * loss 57 | 58 | return loss.mean(1).sum() / num_masks 59 | 60 | 61 | class SetCriterion(nn.Module): 62 | """This class computes the loss for DETR. 63 | The process happens in two steps: 64 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 65 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 66 | """ 67 | 68 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): 69 | """Create the criterion. 70 | Parameters: 71 | num_classes: number of object categories, omitting the special no-object category 72 | matcher: module able to compute a matching between targets and proposals 73 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 74 | eos_coef: relative classification weight applied to the no-object category 75 | losses: list of all the losses to be applied. See get_loss for list of available losses. 76 | """ 77 | super().__init__() 78 | self.num_classes = num_classes 79 | self.matcher = matcher 80 | self.weight_dict = weight_dict 81 | self.eos_coef = eos_coef 82 | self.losses = losses 83 | empty_weight = torch.ones(self.num_classes + 1) 84 | empty_weight[-1] = self.eos_coef 85 | self.register_buffer("empty_weight", empty_weight) 86 | 87 | def loss_labels(self, outputs, targets, indices, num_masks): 88 | """Classification loss (NLL) 89 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 90 | """ 91 | assert "pred_logits" in outputs 92 | src_logits = outputs["pred_logits"] 93 | 94 | idx = self._get_src_permutation_idx(indices) 95 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 96 | target_classes = torch.full( 97 | src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device 98 | ) 99 | target_classes[idx] = target_classes_o 100 | 101 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) 102 | losses = {"loss_ce": loss_ce} 103 | return losses 104 | 105 | def loss_masks(self, outputs, targets, indices, num_masks): 106 | """Compute the losses related to the masks: the focal loss and the dice loss. 
107 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 108 | """ 109 | assert "pred_masks" in outputs 110 | 111 | src_idx = self._get_src_permutation_idx(indices) 112 | tgt_idx = self._get_tgt_permutation_idx(indices) 113 | src_masks = outputs["pred_masks"] 114 | src_masks = src_masks[src_idx] 115 | masks = [t["masks"] for t in targets] 116 | # TODO use valid to mask invalid areas due to padding in loss 117 | target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() 118 | target_masks = target_masks.to(src_masks) 119 | target_masks = target_masks[tgt_idx] 120 | 121 | # upsample predictions to the target size 122 | src_masks = F.interpolate( 123 | src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False 124 | ) 125 | src_masks = src_masks[:, 0].flatten(1) 126 | 127 | target_masks = target_masks.flatten(1) 128 | target_masks = target_masks.view(src_masks.shape) 129 | losses = { 130 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks), 131 | "loss_dice": dice_loss(src_masks, target_masks, num_masks), 132 | } 133 | return losses 134 | 135 | def _get_src_permutation_idx(self, indices): 136 | # permute predictions following indices 137 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 138 | src_idx = torch.cat([src for (src, _) in indices]) 139 | return batch_idx, src_idx 140 | 141 | def _get_tgt_permutation_idx(self, indices): 142 | # permute targets following indices 143 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 144 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 145 | return batch_idx, tgt_idx 146 | 147 | def get_loss(self, loss, outputs, targets, indices, num_masks): 148 | loss_map = {"labels": self.loss_labels, "masks": self.loss_masks} 149 | assert loss in loss_map, f"do you really want to compute {loss} loss?" 150 | return loss_map[loss](outputs, targets, indices, num_masks) 151 | 152 | def forward(self, outputs, targets): 153 | """This performs the loss computation. 154 | Parameters: 155 | outputs: dict of tensors, see the output specification of the model for the format 156 | targets: list of dicts, such that len(targets) == batch_size. 157 | The expected keys in each dict depends on the losses applied, see each loss' doc 158 | """ 159 | # outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} 160 | # TODO: check the refactor 161 | outputs_without_aux = {k: v for k, v in outputs.items() if k in ["pred_logits", "pred_masks"]} 162 | 163 | # Retrieve the matching between the outputs of the last layer and the targets 164 | indices = self.matcher(outputs_without_aux, targets) 165 | 166 | # Compute the average number of target boxes accross all nodes, for normalization purposes 167 | num_masks = sum(len(t["labels"]) for t in targets) 168 | num_masks = torch.as_tensor( 169 | [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device 170 | ) 171 | if is_dist_avail_and_initialized(): 172 | torch.distributed.all_reduce(num_masks) 173 | num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() 174 | 175 | # Compute all the requested losses 176 | losses = {} 177 | for loss in self.losses: 178 | losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) 179 | 180 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
181 | if "aux_outputs" in outputs: 182 | for i, aux_outputs in enumerate(outputs["aux_outputs"]): 183 | indices = self.matcher(aux_outputs, targets) 184 | for loss in self.losses: 185 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks) 186 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 187 | losses.update(l_dict) 188 | 189 | return losses 190 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/mask_former_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | from copy import deepcopy 4 | from typing import Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import fvcore.nn.weight_init as weight_init 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from detectron2.config import configurable 11 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 12 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 13 | 14 | from ..transformer.transformer_predictor import TransformerPredictor 15 | from .pixel_decoder import build_pixel_decoder 16 | 17 | 18 | @SEM_SEG_HEADS_REGISTRY.register() 19 | class MaskFormerHead(nn.Module): 20 | 21 | _version = 2 22 | 23 | def _load_from_state_dict( 24 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 25 | ): 26 | version = local_metadata.get("version", None) 27 | if version is None or version < 2: 28 | # Do not warn if train from scratch 29 | scratch = True 30 | logger = logging.getLogger(__name__) 31 | for k in list(state_dict.keys()): 32 | newk = k 33 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 34 | newk = k.replace(prefix, prefix + "pixel_decoder.") 35 | # logger.debug(f"{k} ==> {newk}") 36 | if newk != k: 37 | state_dict[newk] = state_dict[k] 38 | del state_dict[k] 39 | scratch = False 40 | 41 | if not scratch: 42 | logger.warning( 43 | f"Weight format of {self.__class__.__name__} have changed! " 44 | "Please upgrade your models. Applying automatic conversion now ..." 45 | ) 46 | 47 | @configurable 48 | def __init__( 49 | self, 50 | input_shape: Dict[str, ShapeSpec], 51 | *, 52 | num_classes: int, 53 | pixel_decoder: nn.Module, 54 | loss_weight: float = 1.0, 55 | ignore_value: int = -1, 56 | # extra parameters 57 | transformer_predictor: nn.Module, 58 | transformer_in_feature: str, 59 | ): 60 | """ 61 | NOTE: this interface is experimental. 62 | Args: 63 | input_shape: shapes (channels and stride) of the input features 64 | num_classes: number of classes to predict 65 | pixel_decoder: the pixel decoder module 66 | loss_weight: loss weight 67 | ignore_value: category id to be ignored during training. 
68 | transformer_predictor: the transformer decoder that makes prediction 69 | transformer_in_feature: input feature name to the transformer_predictor 70 | """ 71 | super().__init__() 72 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 73 | self.in_features = [k for k, v in input_shape] 74 | feature_strides = [v.stride for k, v in input_shape] 75 | feature_channels = [v.channels for k, v in input_shape] 76 | 77 | self.ignore_value = ignore_value 78 | self.common_stride = 4 79 | self.loss_weight = loss_weight 80 | 81 | self.pixel_decoder = pixel_decoder 82 | self.predictor = transformer_predictor 83 | self.transformer_in_feature = transformer_in_feature 84 | 85 | self.num_classes = num_classes 86 | 87 | @classmethod 88 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 89 | return { 90 | "input_shape": { 91 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 92 | }, 93 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 94 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 95 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 96 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 97 | "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, 98 | "transformer_predictor": TransformerPredictor( 99 | cfg, 100 | cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 101 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" 102 | else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, 103 | mask_classification=True, 104 | ), 105 | } 106 | 107 | def forward(self, features, images_tensor=None, ori_sizes=None): 108 | return self.layers(features) 109 | 110 | def layers(self, features): 111 | mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) 112 | if self.transformer_in_feature == "transformer_encoder": 113 | assert ( 114 | transformer_encoder_features is not None 115 | ), "Please use the TransformerEncoderPixelDecoder." 116 | predictions = self.predictor(transformer_encoder_features, mask_features) 117 | else: 118 | predictions = self.predictor(features[self.transformer_in_feature], mask_features) 119 | return predictions 120 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/per_pixel_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import logging 3 | from typing import Callable, Dict, List, Optional, Tuple, Union 4 | 5 | import fvcore.nn.weight_init as weight_init 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 11 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 12 | 13 | from ..transformer.transformer_predictor import TransformerPredictor 14 | from .pixel_decoder import build_pixel_decoder 15 | 16 | 17 | @SEM_SEG_HEADS_REGISTRY.register() 18 | class PerPixelBaselineHead(nn.Module): 19 | 20 | _version = 2 21 | 22 | def _load_from_state_dict( 23 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 24 | ): 25 | version = local_metadata.get("version", None) 26 | if version is None or version < 2: 27 | logger = logging.getLogger(__name__) 28 | # Do not warn if train from scratch 29 | scratch = True 30 | logger = logging.getLogger(__name__) 31 | for k in list(state_dict.keys()): 32 | newk = k 33 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 34 | newk = k.replace(prefix, prefix + "pixel_decoder.") 35 | # logger.warning(f"{k} ==> {newk}") 36 | if newk != k: 37 | state_dict[newk] = state_dict[k] 38 | del state_dict[k] 39 | scratch = False 40 | 41 | if not scratch: 42 | logger.warning( 43 | f"Weight format of {self.__class__.__name__} have changed! " 44 | "Please upgrade your models. Applying automatic conversion now ..." 45 | ) 46 | 47 | @configurable 48 | def __init__( 49 | self, 50 | input_shape: Dict[str, ShapeSpec], 51 | *, 52 | num_classes: int, 53 | pixel_decoder: nn.Module, 54 | loss_weight: float = 1.0, 55 | ignore_value: int = -1, 56 | ): 57 | """ 58 | NOTE: this interface is experimental. 59 | Args: 60 | input_shape: shapes (channels and stride) of the input features 61 | num_classes: number of classes to predict 62 | pixel_decoder: the pixel decoder module 63 | loss_weight: loss weight 64 | ignore_value: category id to be ignored during training. 
65 | """ 66 | super().__init__() 67 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 68 | self.in_features = [k for k, v in input_shape] 69 | feature_strides = [v.stride for k, v in input_shape] 70 | feature_channels = [v.channels for k, v in input_shape] 71 | 72 | self.ignore_value = ignore_value 73 | self.common_stride = 4 74 | self.loss_weight = loss_weight 75 | 76 | self.pixel_decoder = pixel_decoder 77 | self.predictor = Conv2d( 78 | self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0 79 | ) 80 | weight_init.c2_msra_fill(self.predictor) 81 | 82 | @classmethod 83 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 84 | return { 85 | "input_shape": { 86 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 87 | }, 88 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 89 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 90 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 91 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 92 | } 93 | 94 | def forward(self, features, targets=None): 95 | """ 96 | Returns: 97 | In training, returns (None, dict of losses) 98 | In inference, returns (CxHxW logits, {}) 99 | """ 100 | x = self.layers(features) 101 | if self.training: 102 | return None, self.losses(x, targets) 103 | else: 104 | x = F.interpolate( 105 | x, scale_factor=self.common_stride, mode="bilinear", align_corners=False 106 | ) 107 | return x, {} 108 | 109 | def layers(self, features): 110 | x, _ = self.pixel_decoder.forward_features(features) 111 | x = self.predictor(x) 112 | return x 113 | 114 | def losses(self, predictions, targets): 115 | predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 116 | predictions = F.interpolate( 117 | predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False 118 | ) 119 | loss = F.cross_entropy( 120 | predictions, targets, reduction="mean", ignore_index=self.ignore_value 121 | ) 122 | losses = {"loss_sem_seg": loss * self.loss_weight} 123 | return losses 124 | 125 | 126 | @SEM_SEG_HEADS_REGISTRY.register() 127 | class PerPixelBaselinePlusHead(PerPixelBaselineHead): 128 | def _load_from_state_dict( 129 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 130 | ): 131 | version = local_metadata.get("version", None) 132 | if version is None or version < 2: 133 | # Do not warn if train from scratch 134 | scratch = True 135 | logger = logging.getLogger(__name__) 136 | for k in list(state_dict.keys()): 137 | newk = k 138 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 139 | newk = k.replace(prefix, prefix + "pixel_decoder.") 140 | logger.debug(f"{k} ==> {newk}") 141 | if newk != k: 142 | state_dict[newk] = state_dict[k] 143 | del state_dict[k] 144 | scratch = False 145 | 146 | if not scratch: 147 | logger.warning( 148 | f"Weight format of {self.__class__.__name__} have changed! " 149 | "Please upgrade your models. Applying automatic conversion now ..." 150 | ) 151 | 152 | @configurable 153 | def __init__( 154 | self, 155 | input_shape: Dict[str, ShapeSpec], 156 | *, 157 | # extra parameters 158 | transformer_predictor: nn.Module, 159 | transformer_in_feature: str, 160 | deep_supervision: bool, 161 | # inherit parameters 162 | num_classes: int, 163 | pixel_decoder: nn.Module, 164 | loss_weight: float = 1.0, 165 | ignore_value: int = -1, 166 | ): 167 | """ 168 | NOTE: this interface is experimental. 
169 | Args: 170 | input_shape: shapes (channels and stride) of the input features 171 | transformer_predictor: the transformer decoder that makes prediction 172 | transformer_in_feature: input feature name to the transformer_predictor 173 | deep_supervision: whether or not to add supervision to the output of 174 | every transformer decoder layer 175 | num_classes: number of classes to predict 176 | pixel_decoder: the pixel decoder module 177 | loss_weight: loss weight 178 | ignore_value: category id to be ignored during training. 179 | """ 180 | super().__init__( 181 | input_shape, 182 | num_classes=num_classes, 183 | pixel_decoder=pixel_decoder, 184 | loss_weight=loss_weight, 185 | ignore_value=ignore_value, 186 | ) 187 | 188 | del self.predictor 189 | 190 | self.predictor = transformer_predictor 191 | self.transformer_in_feature = transformer_in_feature 192 | self.deep_supervision = deep_supervision 193 | 194 | @classmethod 195 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 196 | ret = super().from_config(cfg, input_shape) 197 | ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE 198 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder": 199 | in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 200 | else: 201 | in_channels = input_shape[ret["transformer_in_feature"]].channels 202 | ret["transformer_predictor"] = TransformerPredictor( 203 | cfg, in_channels, mask_classification=False 204 | ) 205 | ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 206 | return ret 207 | 208 | def forward(self, features, targets=None): 209 | """ 210 | Returns: 211 | In training, returns (None, dict of losses) 212 | In inference, returns (CxHxW logits, {}) 213 | """ 214 | x, aux_outputs = self.layers(features) 215 | if self.training: 216 | if self.deep_supervision: 217 | losses = self.losses(x, targets) 218 | for i, aux_output in enumerate(aux_outputs): 219 | losses["loss_sem_seg" + f"_{i}"] = self.losses( 220 | aux_output["pred_masks"], targets 221 | )["loss_sem_seg"] 222 | return None, losses 223 | else: 224 | return None, self.losses(x, targets) 225 | else: 226 | x = F.interpolate( 227 | x, scale_factor=self.common_stride, mode="bilinear", align_corners=False 228 | ) 229 | return x, {} 230 | 231 | def layers(self, features): 232 | mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) 233 | if self.transformer_in_feature == "transformer_encoder": 234 | assert ( 235 | transformer_encoder_features is not None 236 | ), "Please use the TransformerEncoderPixelDecoder." 237 | predictions = self.predictor(transformer_encoder_features, mask_features) 238 | else: 239 | predictions = self.predictor(features[self.transformer_in_feature], mask_features) 240 | if self.deep_supervision: 241 | return predictions["pred_masks"], predictions["aux_outputs"] 242 | else: 243 | return predictions["pred_masks"], None 244 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/zeg_former_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
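# Orientation note (a hedged sketch, not code from this file): ZegFormerHead below
# registers itself in Detectron2's SEM_SEG_HEADS_REGISTRY, and
# detectron2.modeling.build_sem_seg_head() instantiates whatever class
# MODEL.SEM_SEG_HEAD.NAME points at, so a config selects this head with a YAML
# fragment along the lines of the following (the concrete values live in the
# configs/ folder and are not reproduced here):
#
#   MODEL:
#     SEM_SEG_HEAD:
#       NAME: "ZegFormerHead"
#
# Relative to MaskFormerHead, the structural change is that the predictor is a
# TransformerZeroshotPredictor, and forward()/layers() additionally pass the image
# tensor and original sizes through to it.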
2 | import logging 3 | from copy import deepcopy 4 | from typing import Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import fvcore.nn.weight_init as weight_init 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from detectron2.config import configurable 11 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 12 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 13 | 14 | # from ..transformer.transformer_predictor import TransformerPredictor 15 | from ..transformer.transformer_zeroshot_predictor import TransformerZeroshotPredictor 16 | from .pixel_decoder import build_pixel_decoder 17 | 18 | 19 | @SEM_SEG_HEADS_REGISTRY.register() 20 | class ZegFormerHead(nn.Module): 21 | 22 | @configurable 23 | def __init__( 24 | self, 25 | input_shape: Dict[str, ShapeSpec], 26 | *, 27 | num_classes: int, 28 | pixel_decoder: nn.Module, 29 | loss_weight: float = 1.0, 30 | ignore_value: int = -1, 31 | # extra parameters 32 | transformer_predictor: nn.Module, 33 | transformer_in_feature: str, 34 | ): 35 | """ 36 | NOTE: this interface is experimental. 37 | Args: 38 | input_shape: shapes (channels and stride) of the input features 39 | num_classes: number of classes to predict 40 | pixel_decoder: the pixel decoder module 41 | loss_weight: loss weight 42 | ignore_value: category id to be ignored during training. 43 | transformer_predictor: the transformer decoder that makes prediction 44 | transformer_in_feature: input feature name to the transformer_predictor 45 | """ 46 | super().__init__() 47 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 48 | self.in_features = [k for k, v in input_shape] 49 | # feature_strides = [v.stride for k, v in input_shape] 50 | # feature_channels = [v.channels for k, v in input_shape] 51 | 52 | self.ignore_value = ignore_value 53 | self.common_stride = 4 54 | self.loss_weight = loss_weight 55 | 56 | self.pixel_decoder = pixel_decoder 57 | self.predictor = transformer_predictor 58 | self.transformer_in_feature = transformer_in_feature 59 | 60 | self.num_classes = num_classes 61 | 62 | @classmethod 63 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 64 | # import pdb; pdb.set_trace() 65 | return { 66 | "input_shape": { 67 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 68 | }, 69 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 70 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 71 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 72 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 73 | "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, 74 | # TODO: check if other places related to TransformerPredictor need modification 75 | # "transformer_predictor": TransformerPredictor( 76 | "transformer_predictor": TransformerZeroshotPredictor( 77 | cfg, 78 | cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 79 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" 80 | else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, 81 | mask_classification=True, 82 | ), 83 | } 84 | 85 | def forward(self, features, images_tensor=None, ori_sizes=None): 86 | 87 | return self.layers(features, images_tensor, ori_sizes) 88 | 89 | def layers(self, features, images_tensor=None, ori_sizes=None): 90 | mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) 91 | if self.transformer_in_feature == "transformer_encoder": 92 | assert ( 93 | transformer_encoder_features is not None 94 | ), "Please use the 
TransformerEncoderPixelDecoder." 95 | predictions = self.predictor(transformer_encoder_features, mask_features, images_tensor, ori_sizes) 96 | else: 97 | predictions = self.predictor(features[self.transformer_in_feature], mask_features, images_tensor, ori_sizes) 98 | return predictions 99 | -------------------------------------------------------------------------------- /mask_former/modeling/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py 3 | """ 4 | Modules to compute the matching cost and solve the corresponding LSAP. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from scipy.optimize import linear_sum_assignment 9 | from torch import nn 10 | 11 | 12 | def batch_dice_loss(inputs, targets): 13 | """ 14 | Compute the DICE loss, similar to generalized IOU for masks 15 | Args: 16 | inputs: A float tensor of arbitrary shape. 17 | The predictions for each example. 18 | targets: A float tensor with the same shape as inputs. Stores the binary 19 | classification label for each element in inputs 20 | (0 for the negative class and 1 for the positive class). 21 | """ 22 | inputs = inputs.sigmoid() 23 | inputs = inputs.flatten(1) 24 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) 25 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] 26 | loss = 1 - (numerator + 1) / (denominator + 1) 27 | return loss 28 | 29 | 30 | def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2): 31 | """ 32 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 33 | Args: 34 | inputs: A float tensor of arbitrary shape. 35 | The predictions for each example. 36 | targets: A float tensor with the same shape as inputs. Stores the binary 37 | classification label for each element in inputs 38 | (0 for the negative class and 1 for the positive class). 39 | alpha: (optional) Weighting factor in range (0,1) to balance 40 | positive vs negative examples. Default = -1 (no weighting). 41 | gamma: Exponent of the modulating factor (1 - p_t) to 42 | balance easy vs hard examples. 43 | Returns: 44 | Loss tensor 45 | """ 46 | hw = inputs.shape[1] 47 | 48 | prob = inputs.sigmoid() 49 | focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits( 50 | inputs, torch.ones_like(inputs), reduction="none" 51 | ) 52 | focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits( 53 | inputs, torch.zeros_like(inputs), reduction="none" 54 | ) 55 | if alpha >= 0: 56 | focal_pos = focal_pos * alpha 57 | focal_neg = focal_neg * (1 - alpha) 58 | 59 | loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum( 60 | "nc,mc->nm", focal_neg, (1 - targets) 61 | ) 62 | 63 | return loss / hw 64 | 65 | 66 | class HungarianMatcher(nn.Module): 67 | """This class computes an assignment between the targets and the predictions of the network 68 | 69 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 70 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 71 | while the others are un-matched (and thus treated as non-objects). 
72 | """ 73 | 74 | def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1): 75 | """Creates the matcher 76 | 77 | Params: 78 | cost_class: This is the relative weight of the classification error in the matching cost 79 | cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost 80 | cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost 81 | """ 82 | super().__init__() 83 | self.cost_class = cost_class 84 | self.cost_mask = cost_mask 85 | self.cost_dice = cost_dice 86 | assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" 87 | 88 | @torch.no_grad() 89 | def memory_efficient_forward(self, outputs, targets): 90 | """More memory-friendly matching""" 91 | bs, num_queries = outputs["pred_logits"].shape[:2] 92 | 93 | # Work out the mask padding size 94 | masks = [v["masks"] for v in targets] 95 | h_max = max([m.shape[1] for m in masks]) 96 | w_max = max([m.shape[2] for m in masks]) 97 | 98 | indices = [] 99 | 100 | # Iterate through batch size 101 | for b in range(bs): 102 | 103 | out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes] 104 | out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] 105 | 106 | tgt_ids = targets[b]["labels"] 107 | # gt masks are already padded when preparing target 108 | tgt_mask = targets[b]["masks"].to(out_mask) 109 | 110 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 111 | # but approximate it in 1 - proba[target class]. 112 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 113 | cost_class = -out_prob[:, tgt_ids] 114 | 115 | # Downsample gt masks to save memory 116 | tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest") 117 | 118 | # Flatten spatial dimension 119 | out_mask = out_mask.flatten(1) # [batch_size * num_queries, H*W] 120 | tgt_mask = tgt_mask[:, 0].flatten(1) # [num_total_targets, H*W] 121 | 122 | # Compute the focal loss between masks 123 | cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask) 124 | 125 | # Compute the dice loss betwen masks 126 | cost_dice = batch_dice_loss(out_mask, tgt_mask) 127 | 128 | # Final cost matrix 129 | C = ( 130 | self.cost_mask * cost_mask 131 | + self.cost_class * cost_class 132 | + self.cost_dice * cost_dice 133 | ) 134 | C = C.reshape(num_queries, -1).cpu() 135 | 136 | indices.append(linear_sum_assignment(C)) 137 | return [ 138 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) 139 | for i, j in indices 140 | ] 141 | 142 | @torch.no_grad() 143 | def forward(self, outputs, targets): 144 | """Performs the matching 145 | 146 | Params: 147 | outputs: This is a dict that contains at least these entries: 148 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 149 | "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks 150 | 151 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 152 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 153 | objects in the target) containing the class labels 154 | "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks 155 | 156 | Returns: 157 | A list of size batch_size, containing tuples of (index_i, index_j) where: 158 | - index_i is the indices of the selected predictions (in 
order) 159 | - index_j is the indices of the corresponding selected targets (in order) 160 | For each batch element, it holds: 161 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 162 | """ 163 | return self.memory_efficient_forward(outputs, targets) 164 | 165 | def __repr__(self): 166 | head = "Matcher " + self.__class__.__name__ 167 | body = [ 168 | "cost_class: {}".format(self.cost_class), 169 | "cost_mask: {}".format(self.cost_mask), 170 | "cost_dice: {}".format(self.cost_dice), 171 | ] 172 | _repr_indent = 4 173 | lines = [head] + [" " * _repr_indent + line for line in body] 174 | return "\n".join(lines) 175 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/transformer_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
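# Orientation note (a summary sketch, not code from this file): the predictor defined
# below returns a dict whose keys and shapes can be read off forward():
#   "pred_logits": [batch, num_queries, num_classes + 1] -- last decoder layer; the
#                  extra channel is the no-object class weighted by eos_coef in SetCriterion
#   "pred_masks":  [batch, num_queries, H_mask, W_mask]  -- einsum of the per-query mask
#                  embeddings with the pixel decoder's mask_features
#   "aux_outputs": a list with one {"pred_logits", "pred_masks"} dict per intermediate
#                  decoder layer (pred_logits only when mask classification is on),
#                  present only when deep supervision is enabled
# These are exactly the keys consumed by HungarianMatcher and SetCriterion above.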
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | # Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py 4 | import fvcore.nn.weight_init as weight_init 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.layers import Conv2d 11 | 12 | from .position_encoding import PositionEmbeddingSine 13 | from .transformer import Transformer 14 | 15 | 16 | class TransformerPredictor(nn.Module): 17 | @configurable 18 | def __init__( 19 | self, 20 | in_channels, 21 | mask_classification=True, 22 | *, 23 | num_classes: int, 24 | hidden_dim: int, 25 | num_queries: int, 26 | nheads: int, 27 | dropout: float, 28 | dim_feedforward: int, 29 | enc_layers: int, 30 | dec_layers: int, 31 | pre_norm: bool, 32 | deep_supervision: bool, 33 | mask_dim: int, 34 | enforce_input_project: bool, 35 | ): 36 | """ 37 | NOTE: this interface is experimental. 38 | Args: 39 | in_channels: channels of the input features 40 | mask_classification: whether to add mask classifier or not 41 | num_classes: number of classes 42 | hidden_dim: Transformer feature dimension 43 | num_queries: number of queries 44 | nheads: number of heads 45 | dropout: dropout in Transformer 46 | dim_feedforward: feature dimension in feedforward network 47 | enc_layers: number of Transformer encoder layers 48 | dec_layers: number of Transformer decoder layers 49 | pre_norm: whether to use pre-LayerNorm or not 50 | deep_supervision: whether to add supervision to every decoder layers 51 | mask_dim: mask feature dimension 52 | enforce_input_project: add input project 1x1 conv even if input 53 | channels and hidden dim is identical 54 | """ 55 | super().__init__() 56 | 57 | self.mask_classification = mask_classification 58 | 59 | # positional encoding 60 | N_steps = hidden_dim // 2 61 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 62 | 63 | transformer = Transformer( 64 | d_model=hidden_dim, 65 | dropout=dropout, 66 | nhead=nheads, 67 | dim_feedforward=dim_feedforward, 68 | num_encoder_layers=enc_layers, 69 | num_decoder_layers=dec_layers, 70 | normalize_before=pre_norm, 71 | return_intermediate_dec=deep_supervision, 72 | ) 73 | 74 | self.num_queries = num_queries 75 | self.transformer = transformer 76 | hidden_dim = transformer.d_model 77 | 78 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 79 | 80 | if in_channels != hidden_dim or enforce_input_project: 81 | self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) 82 | weight_init.c2_xavier_fill(self.input_proj) 83 | else: 84 | self.input_proj = nn.Sequential() 85 | self.aux_loss = deep_supervision 86 | 87 | # output FFNs 88 | # TODO: change it with clip feature 89 | if self.mask_classification: 90 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 91 | self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) 92 | 93 | @classmethod 94 | def from_config(cls, cfg, in_channels, mask_classification): 95 | ret = {} 96 | ret["in_channels"] = in_channels 97 | ret["mask_classification"] = mask_classification 98 | 99 | ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES 100 | ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM 101 | ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES 102 | # Transformer parameters: 103 | ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS 104 | ret["dropout"] = 
cfg.MODEL.MASK_FORMER.DROPOUT 105 | ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD 106 | ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS 107 | ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS 108 | ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM 109 | ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 110 | ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ 111 | 112 | ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM 113 | 114 | return ret 115 | 116 | def forward(self, x, mask_features): 117 | pos = self.pe_layer(x) 118 | 119 | src = x 120 | mask = None 121 | hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos) 122 | 123 | if self.mask_classification: 124 | outputs_class = self.class_embed(hs) 125 | out = {"pred_logits": outputs_class[-1]} 126 | else: 127 | out = {} 128 | 129 | if self.aux_loss: 130 | # [l, bs, queries, embed] 131 | mask_embed = self.mask_embed(hs) 132 | outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features) 133 | out["pred_masks"] = outputs_seg_masks[-1] 134 | out["aux_outputs"] = self._set_aux_loss( 135 | outputs_class if self.mask_classification else None, outputs_seg_masks 136 | ) 137 | else: 138 | # FIXME h_boxes takes the last one computed, keep this in mind 139 | # [bs, queries, embed] 140 | mask_embed = self.mask_embed(hs[-1]) 141 | outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) 142 | out["pred_masks"] = outputs_seg_masks 143 | return out 144 | 145 | @torch.jit.unused 146 | def _set_aux_loss(self, outputs_class, outputs_seg_masks): 147 | # this is a workaround to make torchscript happy, as torchscript 148 | # doesn't support dictionary with non-homogeneous values, such 149 | # as a dict having both a Tensor and a list. 150 | if self.mask_classification: 151 | return [ 152 | {"pred_logits": a, "pred_masks": b} 153 | for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) 154 | ] 155 | else: 156 | return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] 157 | 158 | 159 | class MLP(nn.Module): 160 | """Very simple multi-layer perceptron (also called FFN)""" 161 | 162 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | 170 | def forward(self, x): 171 | for i, layer in enumerate(self.layers): 172 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 173 | return x 174 | -------------------------------------------------------------------------------- /mask_former/semantic_seg_zero.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
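# Orientation note (an illustrative sketch, not code from this file):
# SemanticSegmentorGzero below differs from the stock SemanticSegmentor mainly at
# inference time: logits are turned into per-pixel softmax probabilities, and the
# channels of *seen* classes are then reduced by MODEL.SEM_SEG_HEAD.GZERO_CALIBRATE
# before the evaluator's argmax, which is meant to counteract the bias toward seen
# classes in generalized zero-shot evaluation. Numeric illustration with a made-up
# calibration value of 0.1 (not a recommended setting):
#   at one pixel, softmax gives seen "cat" = 0.48 and unseen "sheep" = 0.42;
#   after subtracting 0.1 from the seen channel, "cat" = 0.38 < 0.42, so the pixel
#   is labeled "sheep".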
2 | import numpy as np 3 | from typing import Callable, Dict, List, Optional, Union, Tuple 4 | import fvcore.nn.weight_init as weight_init 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 11 | from detectron2.structures import ImageList 12 | from detectron2.utils.registry import Registry 13 | 14 | from detectron2.modeling.postprocessing import sem_seg_postprocess 15 | from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head 16 | from detectron2.modeling import SemanticSegmentor 17 | from detectron2.modeling.backbone import Backbone 18 | from detectron2.data import MetadataCatalog 19 | # from ..backbone import build_backbone 20 | # from ..postprocessing import sem_seg_postprocess 21 | # from .build import META_ARCH_REGISTRY 22 | 23 | # __all__ = ["SemanticSegmentor", "SEM_SEG_HEADS_REGISTRY", "SemSegFPNHead", "build_sem_seg_head"] 24 | # 25 | # 26 | # SEM_SEG_HEADS_REGISTRY = Registry("SEM_SEG_HEADS") 27 | # SEM_SEG_HEADS_REGISTRY.__doc__ = """ 28 | # Registry for semantic segmentation heads, which make semantic segmentation predictions 29 | # from feature maps. 30 | # """ 31 | 32 | 33 | @META_ARCH_REGISTRY.register() 34 | class SemanticSegmentorGzero(SemanticSegmentor): 35 | """ 36 | Main class for semantic segmentation architectures. 37 | """ 38 | @configurable 39 | def __init__( 40 | self, 41 | *, 42 | backbone: Backbone, 43 | sem_seg_head: nn.Module, 44 | pixel_mean: Tuple[float], 45 | pixel_std: Tuple[float], 46 | gzero_calibrate: float, 47 | metadata, 48 | ): 49 | super().__init__( 50 | backbone=backbone, 51 | sem_seg_head=sem_seg_head, 52 | pixel_mean=pixel_mean, 53 | pixel_std=pixel_std, 54 | ) 55 | self.gzero_calibrate = gzero_calibrate 56 | self.metadata = metadata 57 | @classmethod 58 | def from_config(cls, cfg): 59 | backbone = build_backbone(cfg) 60 | sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) 61 | return { 62 | "backbone": backbone, 63 | "sem_seg_head": sem_seg_head, 64 | "pixel_mean": cfg.MODEL.PIXEL_MEAN, 65 | "pixel_std": cfg.MODEL.PIXEL_STD, 66 | "gzero_calibrate": cfg.MODEL.SEM_SEG_HEAD.GZERO_CALIBRATE, 67 | "metadata": MetadataCatalog.get(cfg.DATASETS.TEST[0]) 68 | } 69 | def forward(self, batched_inputs): 70 | """ 71 | Args: 72 | batched_inputs: a list, batched outputs of :class:`DatasetMapper`. 73 | Each item in the list contains the inputs for one image. 74 | 75 | For now, each item in the list is a dict that contains: 76 | 77 | * "image": Tensor, image in (C, H, W) format. 78 | * "sem_seg": semantic segmentation ground truth 79 | * Other information that's included in the original dicts, such as: 80 | "height", "width" (int): the output resolution of the model (may be different 81 | from input resolution), used in inference. 82 | 83 | 84 | Returns: 85 | list[dict]: 86 | Each dict is the output for one input image. 87 | The dict contains one key "sem_seg" whose value is a 88 | Tensor that represents the 89 | per-pixel segmentation prediced by the head. 90 | The prediction has shape KxHxW that represents the logits of 91 | each class for each pixel. 
92 | """ 93 | images = [x["image"].to(self.device) for x in batched_inputs] 94 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 95 | images = ImageList.from_tensors(images, self.backbone.size_divisibility) 96 | 97 | features = self.backbone(images.tensor) 98 | 99 | if "sem_seg" in batched_inputs[0]: 100 | targets = [x["sem_seg"].to(self.device) for x in batched_inputs] 101 | targets = ImageList.from_tensors( 102 | targets, self.backbone.size_divisibility, self.sem_seg_head.ignore_value 103 | ).tensor 104 | else: 105 | targets = None 106 | results, losses = self.sem_seg_head(features, targets) 107 | # Note: results are logits, not probabilities 108 | if self.training: 109 | return losses 110 | # import pdb; pdb.set_trace() 111 | processed_results = [] 112 | for result, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): 113 | height = input_per_image.get("height") 114 | width = input_per_image.get("width") 115 | # softmax 116 | # shape of result: [num_classes, H, W] (171 for COCO-Stuff) 117 | r = F.softmax(result, dim=0) 118 | # import pdb; pdb.set_trace() 119 | # r = sem_seg_postprocess(result, image_size, height, width) 120 | r = sem_seg_postprocess(r, image_size, height, width) 121 | # import ipdb; ipdb.set_trace() 122 | 123 | # gzero calibration: subtract a constant from the probabilities of the seen classes 124 | if self.gzero_calibrate > 0: 125 | # seen_classnames = self.sem_seg_head.class_texts 126 | # num_seen_classnames = len(seen_classnames) 127 | # r[:num_seen_classnames, :, :] = r[:num_seen_classnames, :, :] - self.gzero_calibrate 128 | val_extra_classes = self.metadata.val_extra_classes 129 | seen_indexes = [] 130 | for cls in self.metadata.stuff_classes: 131 | if cls not in val_extra_classes: 132 | seen_indexes.append(self.metadata.stuff_classes.index(cls)) 133 | r[seen_indexes, :, :] = r[seen_indexes, :, :] - self.gzero_calibrate 134 | processed_results.append({"sem_seg": r}) 135 | # Logits alone would suffice for per-pixel semantic segmentation inference, which is why the original code did not convert them to probabilities. 136 | 137 | # import pdb; pdb.set_trace() 138 | return processed_results -------------------------------------------------------------------------------- /mask_former/test_time_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | from itertools import count 4 | 5 | import numpy as np 6 | import torch 7 | from fvcore.transforms import HFlipTransform 8 | from torch import nn 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from detectron2.data.detection_utils import read_image 12 | from detectron2.modeling import DatasetMapperTTA 13 | 14 | __all__ = [ 15 | "SemanticSegmentorWithTTA", 16 | ] 17 | 18 | 19 | class SemanticSegmentorWithTTA(nn.Module): 20 | """ 21 | A SemanticSegmentor with test-time augmentation enabled. 22 | Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. 23 | """ 24 | 25 | def __init__(self, cfg, model, tta_mapper=None, batch_size=1): 26 | """ 27 | Args: 28 | cfg (CfgNode): 29 | model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. 30 | tta_mapper (callable): takes a dataset dict and returns a list of 31 | augmented versions of the dataset dict. Defaults to 32 | `DatasetMapperTTA(cfg)`. 33 | batch_size (int): batch the augmented images into this batch size for inference.
34 | """ 35 | super().__init__() 36 | if isinstance(model, DistributedDataParallel): 37 | model = model.module 38 | self.cfg = cfg.clone() 39 | 40 | self.model = model 41 | 42 | if tta_mapper is None: 43 | tta_mapper = DatasetMapperTTA(cfg) 44 | self.tta_mapper = tta_mapper 45 | self.batch_size = batch_size 46 | 47 | def _batch_inference(self, batched_inputs): 48 | """ 49 | Execute inference on a list of inputs, 50 | using batch size = self.batch_size, instead of the length of the list. 51 | Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward` 52 | """ 53 | outputs = [] 54 | inputs = [] 55 | for idx, input in zip(count(), batched_inputs): 56 | inputs.append(input) 57 | if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: 58 | with torch.no_grad(): 59 | outputs.extend(self.model(inputs)) 60 | inputs = [] 61 | return outputs 62 | 63 | def __call__(self, batched_inputs): 64 | """ 65 | Same input/output format as :meth:`SemanticSegmentor.forward` 66 | """ 67 | 68 | def _maybe_read_image(dataset_dict): 69 | ret = copy.copy(dataset_dict) 70 | if "image" not in ret: 71 | image = read_image(ret.pop("file_name"), self.model.input_format) 72 | image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW 73 | ret["image"] = image 74 | if "height" not in ret and "width" not in ret: 75 | ret["height"] = image.shape[1] 76 | ret["width"] = image.shape[2] 77 | return ret 78 | 79 | return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] 80 | 81 | def _inference_one_image(self, input): 82 | """ 83 | Args: 84 | input (dict): one dataset dict with "image" field being a CHW tensor 85 | Returns: 86 | dict: one output dict 87 | """ 88 | augmented_inputs, tfms = self._get_augmented_inputs(input) 89 | # 1: forward with all augmented images 90 | outputs = self._batch_inference(augmented_inputs) 91 | # Delete now useless variables to avoid being out of memory 92 | del augmented_inputs 93 | # 2: merge the results 94 | # handle flip specially 95 | new_outputs = [] 96 | for output, tfm in zip(outputs, tfms): 97 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 98 | new_outputs.append(output.pop("sem_seg").flip(dims=[2])) 99 | else: 100 | new_outputs.append(output.pop("sem_seg")) 101 | del outputs 102 | # to avoid OOM with torch.stack 103 | final_predictions = new_outputs[0] 104 | for i in range(1, len(new_outputs)): 105 | final_predictions += new_outputs[i] 106 | final_predictions = final_predictions / len(new_outputs) 107 | del new_outputs 108 | return {"sem_seg": final_predictions} 109 | 110 | def _get_augmented_inputs(self, input): 111 | augmented_inputs = self.tta_mapper(input) 112 | tfms = [x.pop("transforms") for x in augmented_inputs] 113 | return augmented_inputs, tfms 114 | -------------------------------------------------------------------------------- /mask_former/third_party/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingjiansw101/ZegFormer/8cb9e8fbddfbee80b21168726d4f76f571def761/mask_former/third_party/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /mask_former/third_party/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import 
Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | __all__ = ["available_models", "load", "tokenize"] 16 | _tokenizer = _Tokenizer() 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 23 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 24 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 25 | } 26 | 27 | 28 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 29 | os.makedirs(root, exist_ok=True) 30 | filename = os.path.basename(url) 31 | 32 | expected_sha256 = url.split("/")[-2] 33 | download_target = os.path.join(root, filename) 34 | 35 | if os.path.exists(download_target) and not os.path.isfile(download_target): 36 | raise RuntimeError(f"{download_target} exists and is not a regular file") 37 | 38 | if os.path.isfile(download_target): 39 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 40 | return download_target 41 | else: 42 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 43 | 44 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 45 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: 46 | while True: 47 | buffer = source.read(8192) 48 | if not buffer: 49 | break 50 | 51 | output.write(buffer) 52 | loop.update(len(buffer)) 53 | 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 55 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 56 | 57 | return download_target 58 | 59 | 60 | def available_models(): 61 | return list(_MODELS.keys()) 62 | 63 | 64 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 65 | if name not in _MODELS: 66 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 67 | 68 | model_path = _download(_MODELS[name]) 69 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 70 | n_px = model.input_resolution.item() 71 | 72 | transform = Compose([ 73 | Resize(n_px, interpolation=Image.BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | if not jit: 81 | model = build_model(model.state_dict()).to(device) 82 | return model, transform 83 | 84 | # patch the device names 85 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 86 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 87 | 88 | def 
patch_device(module): 89 | graphs = [module.graph] if hasattr(module, "graph") else [] 90 | if hasattr(module, "forward1"): 91 | graphs.append(module.forward1.graph) 92 | 93 | for graph in graphs: 94 | for node in graph.findAllNodes("prim::Constant"): 95 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 96 | node.copyAttributes(device_node) 97 | 98 | model.apply(patch_device) 99 | patch_device(model.encode_image) 100 | patch_device(model.encode_text) 101 | 102 | # patch dtype to float32 on CPU 103 | if device == "cpu": 104 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 105 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 106 | float_node = float_input.node() 107 | 108 | def patch_float(module): 109 | graphs = [module.graph] if hasattr(module, "graph") else [] 110 | if hasattr(module, "forward1"): 111 | graphs.append(module.forward1.graph) 112 | 113 | for graph in graphs: 114 | for node in graph.findAllNodes("aten::to"): 115 | inputs = list(node.inputs()) 116 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 117 | if inputs[i].node()["value"] == 5: 118 | inputs[i].node().copyAttributes(float_node) 119 | 120 | model.apply(patch_float) 121 | patch_float(model.encode_image) 122 | patch_float(model.encode_text) 123 | 124 | model.float() 125 | 126 | return model, transform 127 | 128 | 129 | def load_custom(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, n_px=224): 130 | if name not in _MODELS: 131 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 132 | 133 | model_path = _download(_MODELS[name]) 134 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 135 | # n_px = model.input_resolution.item() 136 | 137 | transform = Compose([ 138 | Resize(n_px, interpolation=Image.BICUBIC), 139 | CenterCrop(n_px), 140 | lambda image: image.convert("RGB"), 141 | ToTensor(), 142 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 143 | ]) 144 | 145 | if not jit: 146 | model = build_model(model.state_dict()).to(device) 147 | return model, transform 148 | 149 | # patch the device names 150 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 151 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 152 | 153 | def patch_device(module): 154 | graphs = [module.graph] if hasattr(module, "graph") else [] 155 | if hasattr(module, "forward1"): 156 | graphs.append(module.forward1.graph) 157 | 158 | for graph in graphs: 159 | for node in graph.findAllNodes("prim::Constant"): 160 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 161 | node.copyAttributes(device_node) 162 | 163 | model.apply(patch_device) 164 | patch_device(model.encode_image) 165 | patch_device(model.encode_text) 166 | 167 | # patch dtype to float32 on CPU 168 | if device == "cpu": 169 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 170 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 171 | float_node = float_input.node() 172 | 173 | def patch_float(module): 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | if hasattr(module, "forward1"): 176 | graphs.append(module.forward1.graph) 177 | 178 | for graph in graphs: 179 | for node in graph.findAllNodes("aten::to"): 180 | 
inputs = list(node.inputs()) 181 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 182 | if inputs[i].node()["value"] == 5: 183 | inputs[i].node().copyAttributes(float_node) 184 | 185 | model.apply(patch_float) 186 | patch_float(model.encode_image) 187 | patch_float(model.encode_text) 188 | 189 | model.float() 190 | 191 | return model, transform 192 | 193 | def tokenize(texts: Union[str, List[str]], context_length: int = 77): 194 | if isinstance(texts, str): 195 | texts = [texts] 196 | 197 | sot_token = _tokenizer.encoder["<|startoftext|>"] 198 | eot_token = _tokenizer.encoder["<|endoftext|>"] 199 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 200 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 201 | 202 | for i, tokens in enumerate(all_tokens): 203 | if len(tokens) > context_length: 204 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 205 | result[i, :len(tokens)] = torch.tensor(tokens) 206 | 207 | return result 208 | -------------------------------------------------------------------------------- /mask_former/third_party/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | It also avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /mask_former/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
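The CLIP wrapper and BPE tokenizer above are the pieces ZegFormer builds on to embed class names as text features for segment-level zero-shot classification. A minimal sketch of that flow (class names and prompt wording are illustrative; it assumes the repository root is on `PYTHONPATH` so `mask_former.third_party` is importable, and that the CLIP checkpoint can be downloaded):

```python
import torch
from mask_former.third_party import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# jit=False rebuilds the model from the downloaded state dict instead of running the TorchScript archive
model, _preprocess = clip.load("ViT-B/16", device=device, jit=False)
model.float()  # the non-jit path keeps fp16 weights; cast up so the sketch also runs on CPU

class_names = ["grass", "sky", "traffic light"]  # illustrative labels
tokens = clip.tokenize([f"a photo of a {c}" for c in class_names]).to(device)  # shape [3, 77]

with torch.no_grad():
    text_features = model.encode_text(tokens)  # [3, 512] for ViT-B/16
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # unit-normalize for cosine similarity
```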
2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | from typing import List, Optional 9 | 10 | import torch 11 | import torch.distributed as dist 12 | import torchvision 13 | from torch import Tensor 14 | 15 | 16 | def _max_by_axis(the_list): 17 | # type: (List[List[int]]) -> List[int] 18 | maxes = the_list[0] 19 | for sublist in the_list[1:]: 20 | for index, item in enumerate(sublist): 21 | maxes[index] = max(maxes[index], item) 22 | return maxes 23 | 24 | 25 | class NestedTensor(object): 26 | def __init__(self, tensors, mask: Optional[Tensor]): 27 | self.tensors = tensors 28 | self.mask = mask 29 | 30 | def to(self, device): 31 | # type: (Device) -> NestedTensor # noqa 32 | cast_tensor = self.tensors.to(device) 33 | mask = self.mask 34 | if mask is not None: 35 | assert mask is not None 36 | cast_mask = mask.to(device) 37 | else: 38 | cast_mask = None 39 | return NestedTensor(cast_tensor, cast_mask) 40 | 41 | def decompose(self): 42 | return self.tensors, self.mask 43 | 44 | def __repr__(self): 45 | return str(self.tensors) 46 | 47 | 48 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 49 | # TODO make this more general 50 | if tensor_list[0].ndim == 3: 51 | if torchvision._is_tracing(): 52 | # nested_tensor_from_tensor_list() does not export well to ONNX 53 | # call _onnx_nested_tensor_from_tensor_list() instead 54 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 55 | 56 | # TODO make it support different-sized images 57 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 58 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 59 | batch_shape = [len(tensor_list)] + max_size 60 | b, c, h, w = batch_shape 61 | dtype = tensor_list[0].dtype 62 | device = tensor_list[0].device 63 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 64 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 65 | for img, pad_img, m in zip(tensor_list, tensor, mask): 66 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 67 | m[: img.shape[1], : img.shape[2]] = False 68 | else: 69 | raise ValueError("not supported") 70 | return NestedTensor(tensor, mask) 71 | 72 | 73 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 74 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
75 | @torch.jit.unused 76 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 77 | max_size = [] 78 | for i in range(tensor_list[0].dim()): 79 | max_size_i = torch.max( 80 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 81 | ).to(torch.int64) 82 | max_size.append(max_size_i) 83 | max_size = tuple(max_size) 84 | 85 | # work around for 86 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 87 | # m[: img.shape[1], :img.shape[2]] = False 88 | # which is not yet supported in onnx 89 | padded_imgs = [] 90 | padded_masks = [] 91 | for img in tensor_list: 92 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 93 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 94 | padded_imgs.append(padded_img) 95 | 96 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 97 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 98 | padded_masks.append(padded_mask.to(torch.bool)) 99 | 100 | tensor = torch.stack(padded_imgs) 101 | mask = torch.stack(padded_masks) 102 | 103 | return NestedTensor(tensor, mask=mask) 104 | 105 | 106 | def is_dist_avail_and_initialized(): 107 | if not dist.is_available(): 108 | return False 109 | if not dist.is_initialized(): 110 | return False 111 | return True 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.7.0 2 | timm==0.4.12 3 | ftfy==6.0.1 4 | opencv-python==4.5.1.48 5 | setuptools==59.5.0 6 | pillow==8.2.0 7 | imageio==2.4.1 8 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | This directory contains a few tools for converting ImageNet pre-trained weights. 2 | 3 | * `convert-torchvision-to-d2.py` 4 | 5 | Tool to convert torchvision pre-trained weights for D2. 6 | 7 | ``` 8 | wget https://download.pytorch.org/models/resnet101-63fe2227.pth 9 | python tools/convert-torchvision-to-d2.py resnet101-63fe2227.pth R-101.pkl 10 | ``` 11 |
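* `convert-pretrained-swin-model-to-d2.py`

  Tool to convert Swin Transformer pre-trained weights for D2. The commands below simply mirror the usage string embedded in that script:

```
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
python tools/convert-pretrained-swin-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl
```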
-------------------------------------------------------------------------------- /tools/convert-pretrained-swin-model-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download pretrained swin model: 12 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth 13 | # run the conversion 14 | ./convert-pretrained-swin-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl 15 | # Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl" 18 | INPUT: 19 | FORMAT: "RGB" 20 | """ 21 | 22 | if __name__ == "__main__": 23 | input = sys.argv[1] 24 | 25 | obj = torch.load(input, map_location="cpu")["model"] 26 | 27 | res = {"model": obj, "__author__": "third_party", "matching_heuristics": True} 28 | 29 | with open(sys.argv[2], "wb") as f: 30 | pkl.dump(res, f) 31 | -------------------------------------------------------------------------------- /tools/convert-torchvision-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download one of the ResNet{18,34,50,101,152} models from torchvision: 12 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth 13 | # run the conversion 14 | ./convert-torchvision-to-d2.py r50.pth r50.pkl 15 | # Then, use r50.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/r50.pkl" 18 | PIXEL_MEAN: [123.675, 116.280, 103.530] 19 | PIXEL_STD: [58.395, 57.120, 57.375] 20 | RESNETS: 21 | DEPTH: 50 22 | STRIDE_IN_1X1: False 23 | INPUT: 24 | FORMAT: "RGB" 25 | These models typically produce slightly worse results than the 26 | pre-trained ResNets we use in official configs, which are the 27 | original ResNet models released by MSRA. 28 | """ 29 | 30 | if __name__ == "__main__": 31 | input = sys.argv[1] 32 | 33 | obj = torch.load(input, map_location="cpu") 34 | 35 | newmodel = {} 36 | for k in list(obj.keys()): 37 | old_k = k 38 | if "layer" not in k: 39 | k = "stem."
+ k 40 | for t in [1, 2, 3, 4]: 41 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 42 | for t in [1, 2, 3]: 43 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 44 | k = k.replace("downsample.0", "shortcut") 45 | k = k.replace("downsample.1", "shortcut.norm") 46 | print(old_k, "->", k) 47 | newmodel[k] = obj.pop(old_k).detach().numpy() 48 | 49 | res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} 50 | 51 | with open(sys.argv[2], "wb") as f: 52 | pkl.dump(res, f) 53 | if obj: 54 | print("Unconverted keys:", obj.keys()) 55 | -------------------------------------------------------------------------------- /tools/sem_seg_json2mat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | from scipy.io import savemat 4 | import json 5 | from pycocotools import mask as maskUtils 6 | from collections import defaultdict 7 | import os 8 | from tqdm import tqdm 9 | from detectron2.data.detection_utils import read_image 10 | from PIL import Image 11 | 12 | # def json_to_mat(filename, outfile): 13 | # fin = open(filename, encoding='UTF-8') 14 | # s = json.load(fin) 15 | # data = dict() 16 | # for k, v in s.items(): 17 | # data[k] = v 18 | # savemat(outfile, data) 19 | # fin.close() 20 | 21 | def sem_seg_json_to_mat(filename, outdir=None, dataset_name="coco_2017_val_all_stuff_sem_seg"): 22 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 23 | assert dataset_name in ["coco_2017_val_all_stuff_sem_seg"] 24 | with open(filename, 'r') as f_in: 25 | predictions = json.load(f_in) 26 | 27 | if not os.path.exists(outdir): 28 | os.makedirs(outdir) 29 | 30 | imgToAnns = defaultdict(list) 31 | for pred in predictions: 32 | image_id = os.path.basename(pred["file_name"]).split(".")[0] 33 | imgToAnns[image_id].append( 34 | {"category_id": pred["category_id"], "segmentation": pred["segmentation"]} 35 | ) 36 | image_ids = list(imgToAnns.keys()) 37 | 38 | for image_id in tqdm(image_ids): 39 | if dataset_name == "coco_2017_val_all_stuff_sem_seg": 40 | gt_dir = os.path.join(_root, "coco/coco_stuff", "annotations_detectron2", "val2017_all") 41 | segm_gt = read_image(os.path.join(gt_dir, image_id + ".png")).copy().astype(np.int64) 42 | # import ipdb; ipdb.set_trace() 43 | # get predictions 44 | segm_dt = np.zeros_like(segm_gt) 45 | anns = imgToAnns[image_id] 46 | # import ipdb; ipdb.set_trace() 47 | for ann in anns: 48 | # map back category_id 49 | category_id = ann["category_id"] 50 | mask = maskUtils.decode(ann["segmentation"]) 51 | # TODO: keep it in imind, that ther id here just represent a partition, and not the real category_id 52 | segm_dt[mask > 0] = category_id + 1 53 | # import ipdb; 54 | # ipdb.set_trace() 55 | Image.fromarray(segm_dt.astype(np.uint16)).save(os.path.join(outdir, image_id + '.tif')) 56 | 57 | if __name__ == '__main__': 58 | # sem_seg_json_to_mat(r'/home/dj/code/MaskFormer/work_dirs/' 59 | # r'maskformer_R50_bs32_60k_zeroshot_gzss_eval_clipcls_vit16_coco-stuff' 60 | # r'/inference/sem_seg_predictions.json', 61 | # r'/home/dj/code/MaskFormer/work_dirs/' 62 | # r'maskformer_R50_bs32_60k_zeroshot_gzss_eval_clipcls_vit16_coco-stuff' 63 | # r'/inference/pngs') 64 | # sem_seg_json_to_mat(r'work_dirs/maskformer_R50_bs32_60k_zeroshot_vit16_gzss_eval_coco-stuff' 65 | # r'/inference/sem_seg_predictions.json', 66 | # r'work_dirs/maskformer_R50_bs32_60k_zeroshot_vit16_gzss_eval_coco-stuff' 67 | # r'/inference/images') 68 | # 
sem_seg_json_to_mat(r'work_dirs/per_pixel_baseline_R50_bs32_60k_zeroshot_gzss_eval' 69 | # r'/inference/sem_seg_predictions.json', 70 | # r'work_dirs/per_pixel_baseline_R50_bs32_60k_zeroshot_gzss_eval' 71 | # r'/inference/images') 72 | # sem_seg_json_to_mat(r'work_dirs/per_pixel_baseline_R50_bs32_60k_zeroshot_vit16_coco-stuff_gzss_eval' 73 | # r'/inference/sem_seg_predictions.json', 74 | # r'work_dirs/per_pixel_baseline_R50_bs32_60k_zeroshot_vit16_coco-stuff_gzss_eval' 75 | # r'/inference/images') 76 | 77 | # sem_seg_json_to_mat(r'work_dirs/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16_zeroshot_coco-stuff_gzss_eval' 78 | # r'/inference/sem_seg_predictions.json', 79 | # r'work_dirs/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16_zeroshot_coco-stuff_gzss_eval' 80 | # r'/inference/images') 81 | 82 | # sem_seg_json_to_mat(r'work_dirs/maskformer_R50_bs32_60k_zeroshot_vit16_gzss_eval_coco-stuff_groupeval' 83 | # r'/inference/sem_seg_predictions.json', 84 | # r'work_dirs/maskformer_R50_bs32_60k_zeroshot_vit16_gzss_eval_coco-stuff_groupeval' 85 | # r'/inference/images') 86 | 87 | # sem_seg_json_to_mat(r'work_dirs/maskformer_R50_bs32_60k_zeroshot_wordvec_coco-stuff_gzss_eval_group_eval' 88 | # r'/inference/sem_seg_predictions.json', 89 | # r'work_dirs/maskformer_R50_bs32_60k_zeroshot_wordvec_coco-stuff_gzss_eval_group_eval' 90 | # r'/inference/images') 91 | 92 | # sem_seg_json_to_mat(r'work_dirs/per_pixel_baseline_R50_bs32_60k_zeroshot_wordvec_coco-stuff_gzss_eval_group_eval' 93 | # r'/inference/sem_seg_predictions.json', 94 | # r'work_dirs/per_pixel_baseline_R50_bs32_60k_zeroshot_wordvec_coco-stuff_gzss_eval_group_eval' 95 | # r'/inference/images') 96 | 97 | sem_seg_json_to_mat(r'work_dirs/maskformer_R50_bs32_60k_zeroshot_vit16_gzss_eval_coco-stuff_group_eval_tx1' 98 | r'/inference/sem_seg_predictions.json', 99 | r'work_dirs/maskformer_R50_bs32_60k_zeroshot_vit16_gzss_eval_coco-stuff_group_eval_tx1' 100 | r'/inference/images') 101 | --------------------------------------------------------------------------------
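For reference, a minimal sketch of calling the converter in `tools/sem_seg_json2mat.py` above on your own inference dump; the `work_dirs/my_experiment/...` paths are placeholders, and the call assumes it is run from the repository root with the `DETECTRON2_DATASETS` environment variable pointing at the prepared datasets (the script falls back to `./datasets`):

```python
from tools.sem_seg_json2mat import sem_seg_json_to_mat

sem_seg_json_to_mat(
    "work_dirs/my_experiment/inference/sem_seg_predictions.json",  # JSON dumped by the evaluator
    outdir="work_dirs/my_experiment/inference/images",             # one uint16 .tif label map per image
    dataset_name="coco_2017_val_all_stuff_sem_seg",                # currently the only supported dataset
)
```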