├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── GETTING_STARTED.md ├── INSTALL.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── configs ├── ade20k-150-panoptic │ ├── maskformer_panoptic_R101_bs16_720k.yaml │ └── maskformer_panoptic_R50_bs16_720k.yaml ├── ade20k-150 │ ├── Base-ADE20K-150.yaml │ ├── maskformer_R101_bs16_160k.yaml │ ├── maskformer_R101c_bs16_160k.yaml │ ├── maskformer_R50_bs16_160k.yaml │ ├── per_pixel_baseline_R50_bs16_160k.yaml │ ├── per_pixel_baseline_plus_R50_bs16_160k.yaml │ └── swin │ │ ├── maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml │ │ ├── maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml │ │ ├── maskformer_swin_small_bs16_160k.yaml │ │ └── maskformer_swin_tiny_bs16_160k.yaml ├── ade20k-full-847 │ ├── Base-ADE20KFull-847.yaml │ ├── maskformer_R101_bs16_200k.yaml │ ├── maskformer_R101c_bs16_200k.yaml │ ├── maskformer_R50_bs16_200k.yaml │ ├── per_pixel_baseline_R50_bs16_200k.yaml │ └── per_pixel_baseline_plus_R50_bs16_200k.yaml ├── cityscapes-19 │ ├── Base-Cityscapes-19.yaml │ ├── maskformer_R101_bs16_90k.yaml │ └── maskformer_R101c_bs16_90k.yaml ├── coco-panoptic │ ├── Base-COCO-PanopticSegmentation.yaml │ ├── maskformer_panoptic_R101_bs64_554k.yaml │ ├── maskformer_panoptic_R50_bs64_554k.yaml │ └── swin │ │ ├── maskformer_panoptic_swin_base_IN21k_384_bs64_554k.yaml │ │ ├── maskformer_panoptic_swin_large_IN21k_384_bs64_554k.yaml │ │ ├── maskformer_panoptic_swin_small_bs64_554k.yaml │ │ └── maskformer_panoptic_swin_tiny_bs64_554k.yaml ├── coco-stuff-10k-171 │ ├── Base-COCOStuff10K-171.yaml │ ├── maskformer_R101_bs32_60k.yaml │ ├── maskformer_R101c_bs32_60k.yaml │ ├── maskformer_R50_bs32_60k.yaml │ ├── per_pixel_baseline_R50_bs32_60k.yaml │ └── per_pixel_baseline_plus_R50_bs32_60k.yaml └── mapillary-vistas-65 │ ├── Base-MapillaryVistas-65.yaml │ └── maskformer_R50_bs16_300k.yaml ├── datasets ├── README.md ├── ade20k_instance_catid_mapping.txt ├── ade20k_instance_imgCatIds.json ├── prepare_ade20k_full_sem_seg.py ├── prepare_ade20k_pan_seg.py ├── prepare_ade20k_sem_seg.py └── prepare_coco_stuff_10k_v1.0_sem_seg.py ├── demo ├── README.md ├── demo.py └── predictor.py ├── mask_former ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ ├── dataset_mappers │ │ ├── __init__.py │ │ ├── detr_panoptic_dataset_mapper.py │ │ ├── mask_former_panoptic_dataset_mapper.py │ │ └── mask_former_semantic_dataset_mapper.py │ └── datasets │ │ ├── __init__.py │ │ ├── register_ade20k_full.py │ │ ├── register_ade20k_panoptic.py │ │ ├── register_coco_stuff_10k.py │ │ └── register_mapillary_vistas.py ├── mask_former_model.py ├── modeling │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ └── swin.py │ ├── criterion.py │ ├── heads │ │ ├── __init__.py │ │ ├── mask_former_head.py │ │ ├── per_pixel_baseline.py │ │ └── pixel_decoder.py │ ├── matcher.py │ └── transformer │ │ ├── __init__.py │ │ ├── position_encoding.py │ │ ├── transformer.py │ │ └── transformer_predictor.py ├── test_time_augmentation.py └── utils │ ├── __init__.py │ └── misc.py ├── requirements.txt ├── tools ├── README.md ├── analyze_model.py ├── convert-pretrained-swin-model-to-d2.py ├── convert-torchvision-to-d2.py └── evaluate_pq_for_semantic_segmentation.py └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | instant_test_output 4 | inference_test_output 5 | 6 | 7 | *.png 8 | *.json 9 | *.diff 10 | *.jpg 11 | !/projects/DensePose/doc/images/*.jpg 12 | 13 | # compilation and distribution 14 | __pycache__ 15 
| _ext 16 | *.pyc 17 | *.pyd 18 | *.so 19 | *.dll 20 | *.egg-info/ 21 | build/ 22 | dist/ 23 | wheels/ 24 | 25 | # pytorch/python/numpy formats 26 | *.pth 27 | *.pkl 28 | *.npy 29 | *.ts 30 | model_ts*.txt 31 | 32 | # ipython/jupyter notebooks 33 | *.ipynb 34 | **/.ipynb_checkpoints/ 35 | 36 | # Editor temporaries 37 | *.swn 38 | *.swo 39 | *.swp 40 | *~ 41 | 42 | # editor settings 43 | .idea 44 | .vscode 45 | _darcs 46 | 47 | # project dirs 48 | /detectron2/model_zoo/configs 49 | /datasets/* 50 | !/datasets/*.* 51 | /projects/*/datasets 52 | /models 53 | /snippet -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MaskFormer 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 4 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) 36 | 37 | ## License 38 | By contributing to MaskFormer, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | ## Getting Started with MaskFormer 2 | 3 | This document provides a brief intro of the usage of MaskFormer. 4 | 5 | Please see [Getting Started with Detectron2](https://github.com/facebookresearch/detectron2/blob/master/GETTING_STARTED.md) for full usage. 6 | 7 | 8 | ### Inference Demo with Pre-trained Models 9 | 10 | 1. 
Pick a model and its config file from 11 | [model zoo](MODEL_ZOO.md), 12 | for example, `ade20k-150/maskformer_R50_bs16_160k.yaml`. 13 | 2. We provide `demo.py` that is able to demo builtin configs. Run it with: 14 | ``` 15 | cd demo/ 16 | python demo.py --config-file ../configs/ade20k-150/maskformer_R50_bs16_160k.yaml \ 17 | --input input1.jpg input2.jpg \ 18 | [--other-options] 19 | --opts MODEL.WEIGHTS /path/to/checkpoint_file 20 | ``` 21 | The configs are made for training, therefore we need to specify `MODEL.WEIGHTS` to a model from model zoo for evaluation. 22 | This command will run the inference and show visualizations in an OpenCV window. 23 | 24 | For details of the command line arguments, see `demo.py -h` or look at its source code 25 | to understand its behavior. Some common arguments are: 26 | * To run __on your webcam__, replace `--input files` with `--webcam`. 27 | * To run __on a video__, replace `--input files` with `--video-input video.mp4`. 28 | * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`. 29 | * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`. 30 | 31 | 32 | ### Training & Evaluation in Command Line 33 | 34 | We provide two scripts in `train_net.py`, that are made to train all the configs provided in MaskFormer. 35 | 36 | To train a model with "train_net.py", first 37 | setup the corresponding datasets following 38 | [datasets/README.md](./datasets/README.md), 39 | then run: 40 | ``` 41 | ./train_net.py --num-gpus 8 \ 42 | --config-file configs/ade20k-150/maskformer_R50_bs16_160k.yaml 43 | ``` 44 | 45 | The configs are made for 8-GPU training. 46 | Since we use ADAMW optimizer, it is not clear how to scale learning rate with batch size. 47 | To train on 1 GPU, you need to figure out learning rate and batch size by yourself: 48 | ``` 49 | ./train_net.py \ 50 | --config-file configs/ade20k-150/maskformer_R50_bs16_160k.yaml \ 51 | --num-gpus 1 SOLVER.IMS_PER_BATCH SET_TO_SOME_REASONABLE_VALUE SOLVER.BASE_LR SET_TO_SOME_REASONABLE_VALUE 52 | ``` 53 | 54 | To evaluate a model's performance, use 55 | ``` 56 | ./train_net.py \ 57 | --config-file configs/ade20k-150/maskformer_R50_bs16_160k.yaml \ 58 | --eval-only MODEL.WEIGHTS /path/to/checkpoint_file 59 | ``` 60 | For more options, see `./train_net.py -h`. 61 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Requirements 4 | - Linux or macOS with Python ≥ 3.6 5 | - PyTorch ≥ 1.7 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. 6 | Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note, please check 7 | PyTorch version matches that is required by Detectron2. 8 | - Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). 9 | - OpenCV is optional but needed by demo and visualization 10 | - `pip install -r requirements.txt` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MaskFormer: Per-Pixel Classification is Not All You Need for Semantic Segmentation 2 | 3 | [Bowen Cheng](https://bowenc0221.github.io/), [Alexander G. 
Schwing](https://alexander-schwing.de/), [Alexander Kirillov](https://alexander-kirillov.github.io/) 4 | 5 | [[`arXiv`](http://arxiv.org/abs/2107.06278)] [[`Project`](https://bowenc0221.github.io/maskformer)] [[`BibTeX`](#CitingMaskFormer)] 6 | 7 |
8 | 9 |

10 | 11 | ### Mask2Former 12 | Checkout [Mask2Former](https://github.com/facebookresearch/Mask2Former), a universal architecture based on MaskFormer meta-architecture that 13 | achieves SOTA on panoptic, instance and semantic segmentation across four popular datasets (ADE20K, Cityscapes, COCO, Mapillary Vistas). 14 | 15 | ### Features 16 | * Better results while being more efficient. 17 | * Unified view of semantic- and instance-level segmentation tasks. 18 | * Support major semantic segmentation datasets: ADE20K, Cityscapes, COCO-Stuff, Mapillary Vistas. 19 | * Support **ALL** Detectron2 models. 20 | 21 | ## Installation 22 | 23 | See [installation instructions](INSTALL.md). 24 | 25 | ## Getting Started 26 | 27 | See [Preparing Datasets for MaskFormer](datasets/README.md). 28 | 29 | See [Getting Started with MaskFormer](GETTING_STARTED.md). 30 | 31 | ## Model Zoo and Baselines 32 | 33 | We provide a large set of baseline results and trained models available for download in the [MaskFormer Model Zoo](MODEL_ZOO.md). 34 | 35 | ## License 36 | 37 | Shield: [![CC BY-NC 4.0][cc-by-nc-shield]][cc-by-nc] 38 | 39 | The majority of MaskFormer is licensed under a 40 | [Creative Commons Attribution-NonCommercial 4.0 International License](LICENSE). 41 | 42 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc] 43 | 44 | [cc-by-nc]: http://creativecommons.org/licenses/by-nc/4.0/ 45 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png 46 | [cc-by-nc-shield]: https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg 47 | 48 | 49 | However portions of the project are available under separate license terms: Swin-Transformer-Semantic-Segmentation is licensed under the [MIT license](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/LICENSE). 50 | 51 | ## Citing MaskFormer 52 | 53 | If you use MaskFormer in your research or wish to refer to the baseline results published in the [Model Zoo](MODEL_ZOO.md), please use the following BibTeX entry. 54 | 55 | ```BibTeX 56 | @inproceedings{cheng2021maskformer, 57 | title={Per-Pixel Classification is Not All You Need for Semantic Segmentation}, 58 | author={Bowen Cheng and Alexander G. 
Schwing and Alexander Kirillov}, 59 | journal={NeurIPS}, 60 | year={2021} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /configs/ade20k-150-panoptic/maskformer_panoptic_R101_bs16_720k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_panoptic_R50_bs16_720k.yaml 2 | MODEL: 3 | WEIGHTS: "R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | STEM_TYPE: "basic" # not used 7 | STEM_OUT_CHANNELS: 64 8 | STRIDE_IN_1X1: False 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | # NORM: "SyncBN" 11 | RES5_MULTI_GRID: [1, 1, 1] # not used 12 | -------------------------------------------------------------------------------- /configs/ade20k-150-panoptic/maskformer_panoptic_R50_bs16_720k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../ade20k-150/maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | SEM_SEG_HEAD: 4 | PIXEL_DECODER_NAME: "TransformerEncoderPixelDecoder" 5 | TRANSFORMER_ENC_LAYERS: 6 6 | MASK_FORMER: 7 | TRANSFORMER_IN_FEATURE: "transformer_encoder" 8 | TEST: 9 | PANOPTIC_ON: True 10 | OVERLAP_THRESHOLD: 0.8 11 | OBJECT_MASK_THRESHOLD: 0.7 12 | DATASETS: 13 | TRAIN: ("ade20k_panoptic_train",) 14 | TEST: ("ade20k_panoptic_val",) 15 | SOLVER: 16 | MAX_ITER: 720000 17 | INPUT: 18 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] 19 | MIN_SIZE_TRAIN_SAMPLING: "choice" 20 | MIN_SIZE_TEST: 640 21 | MAX_SIZE_TRAIN: 2560 22 | MAX_SIZE_TEST: 2560 23 | CROP: 24 | ENABLED: True 25 | TYPE: "absolute" 26 | SIZE: (640, 640) 27 | SINGLE_CATEGORY_MAX_AREA: 1.0 28 | COLOR_AUG_SSD: True 29 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 30 | FORMAT: "RGB" 31 | DATASET_MAPPER_NAME: "mask_former_panoptic" 32 | TEST: 33 | EVAL_PERIOD: 0 34 | -------------------------------------------------------------------------------- /configs/ade20k-150/Base-ADE20K-150.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("ade20k_sem_seg_train",) 18 | TEST: ("ade20k_sem_seg_val",) 19 | SOLVER: 20 | IMS_PER_BATCH: 16 21 | BASE_LR: 0.0001 22 | MAX_ITER: 160000 23 | WARMUP_FACTOR: 1.0 24 | WARMUP_ITERS: 0 25 | WEIGHT_DECAY: 0.0001 26 | OPTIMIZER: "ADAMW" 27 | LR_SCHEDULER_NAME: "WarmupPolyLR" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"] 36 | MIN_SIZE_TRAIN_SAMPLING: "choice" 37 | MIN_SIZE_TEST: 512 38 | MAX_SIZE_TRAIN: 2048 39 | MAX_SIZE_TEST: 2048 40 | CROP: 41 | ENABLED: True 42 | TYPE: "absolute" 43 | SIZE: (512, 512) 44 | SINGLE_CATEGORY_MAX_AREA: 1.0 45 | COLOR_AUG_SSD: True 46 | SIZE_DIVISIBILITY: 512 # used in dataset mapper 47 | FORMAT: "RGB" 48 | DATASET_MAPPER_NAME: "mask_former_semantic" 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | AUG: 52 | ENABLED: False 53 | MIN_SIZES: [256, 384, 512, 640, 768, 896] 54 | MAX_SIZE: 
3584 55 | FLIP: True 56 | DATALOADER: 57 | FILTER_EMPTY_ANNOTATIONS: True 58 | NUM_WORKERS: 4 59 | VERSION: 2 60 | -------------------------------------------------------------------------------- /configs/ade20k-150/maskformer_R101_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | WEIGHTS: "R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | STEM_TYPE: "basic" # not used 7 | STEM_OUT_CHANNELS: 64 8 | STRIDE_IN_1X1: False 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | # NORM: "SyncBN" 11 | RES5_MULTI_GRID: [1, 1, 1] # not used 12 | -------------------------------------------------------------------------------- /configs/ade20k-150/maskformer_R101c_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "detectron2://DeepLab/R-103.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "deeplab" 9 | STEM_OUT_CHANNELS: 128 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 2, 4] 14 | -------------------------------------------------------------------------------- /configs/ade20k-150/maskformer_R50_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20K-150.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "MaskFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 150 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | NO_OBJECT_WEIGHT: 0.1 18 | DICE_WEIGHT: 1.0 19 | MASK_WEIGHT: 20.0 20 | HIDDEN_DIM: 256 21 | NUM_OBJECT_QUERIES: 100 22 | NHEADS: 8 23 | DROPOUT: 0.1 24 | DIM_FEEDFORWARD: 2048 25 | ENC_LAYERS: 0 26 | DEC_LAYERS: 6 27 | PRE_NORM: False 28 | -------------------------------------------------------------------------------- /configs/ade20k-150/per_pixel_baseline_R50_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20K-150.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | SEM_SEG_HEAD: 5 | NAME: "PerPixelBaselineHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 150 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | -------------------------------------------------------------------------------- /configs/ade20k-150/per_pixel_baseline_plus_R50_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20K-150.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | SEM_SEG_HEAD: 5 | NAME: "PerPixelBaselinePlusHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 150 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | HIDDEN_DIM: 256 18 | NUM_OBJECT_QUERIES: 150 # remember to set this to NUM_CLASSES 19 | NHEADS: 8 20 | DROPOUT: 0.1 21 | DIM_FEEDFORWARD: 2048 22 | ENC_LAYERS: 0 23 | DEC_LAYERS: 6 24 | PRE_NORM: 
False 25 | -------------------------------------------------------------------------------- /configs/ade20k-150/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [2, 2, 18, 2] 8 | NUM_HEADS: [4, 8, 16, 32] 9 | WINDOW_SIZE: 12 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | PRETRAIN_IMG_SIZE: 384 14 | WEIGHTS: "swin_base_patch4_window12_384_22k.pkl" 15 | PIXEL_MEAN: [123.675, 116.280, 103.530] 16 | PIXEL_STD: [58.395, 57.120, 57.375] 17 | SOLVER: 18 | BASE_LR: 0.00006 19 | WARMUP_FACTOR: 1e-6 20 | WARMUP_ITERS: 1500 21 | WEIGHT_DECAY: 0.01 22 | WEIGHT_DECAY_NORM: 0.0 23 | WEIGHT_DECAY_EMBED: 0.0 24 | BACKBONE_MULTIPLIER: 1.0 25 | INPUT: 26 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] 27 | MIN_SIZE_TRAIN_SAMPLING: "choice" 28 | MIN_SIZE_TEST: 640 29 | MAX_SIZE_TRAIN: 2560 30 | MAX_SIZE_TEST: 2560 31 | CROP: 32 | ENABLED: True 33 | TYPE: "absolute" 34 | SIZE: (640, 640) 35 | SINGLE_CATEGORY_MAX_AREA: 1.0 36 | COLOR_AUG_SSD: True 37 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 38 | FORMAT: "RGB" 39 | TEST: 40 | EVAL_PERIOD: 5000 41 | AUG: 42 | ENABLED: False 43 | MIN_SIZES: [320, 480, 640, 800, 960, 1120] 44 | MAX_SIZE: 4480 45 | FLIP: True 46 | -------------------------------------------------------------------------------- /configs/ade20k-150/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [2, 2, 18, 2] 8 | NUM_HEADS: [6, 12, 24, 48] 9 | WINDOW_SIZE: 12 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | PRETRAIN_IMG_SIZE: 384 14 | WEIGHTS: "swin_large_patch4_window12_384_22k.pkl" 15 | PIXEL_MEAN: [123.675, 116.280, 103.530] 16 | PIXEL_STD: [58.395, 57.120, 57.375] 17 | SOLVER: 18 | BASE_LR: 0.00006 19 | WARMUP_FACTOR: 1e-6 20 | WARMUP_ITERS: 1500 21 | WEIGHT_DECAY: 0.01 22 | WEIGHT_DECAY_NORM: 0.0 23 | WEIGHT_DECAY_EMBED: 0.0 24 | BACKBONE_MULTIPLIER: 1.0 25 | INPUT: 26 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] 27 | MIN_SIZE_TRAIN_SAMPLING: "choice" 28 | MIN_SIZE_TEST: 640 29 | MAX_SIZE_TRAIN: 2560 30 | MAX_SIZE_TEST: 2560 31 | CROP: 32 | ENABLED: True 33 | TYPE: "absolute" 34 | SIZE: (640, 640) 35 | SINGLE_CATEGORY_MAX_AREA: 1.0 36 | COLOR_AUG_SSD: True 37 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 38 | FORMAT: "RGB" 39 | TEST: 40 | EVAL_PERIOD: 5000 41 | AUG: 42 | ENABLED: False 43 | MIN_SIZES: [320, 480, 640, 800, 960, 1120] 44 | MAX_SIZE: 4480 45 | FLIP: True 46 | -------------------------------------------------------------------------------- /configs/ade20k-150/swin/maskformer_swin_small_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [2, 2, 18, 2] 8 | NUM_HEADS: [3, 6, 12, 24] 9 | WINDOW_SIZE: 7 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | WEIGHTS: "swin_small_patch4_window7_224.pkl" 14 | PIXEL_MEAN: [123.675, 116.280, 103.530] 15 | PIXEL_STD: [58.395, 57.120, 57.375] 16 | 
SOLVER: 17 | BASE_LR: 0.00006 18 | WARMUP_FACTOR: 1e-6 19 | WARMUP_ITERS: 1500 20 | WEIGHT_DECAY: 0.01 21 | WEIGHT_DECAY_NORM: 0.0 22 | WEIGHT_DECAY_EMBED: 0.0 23 | BACKBONE_MULTIPLIER: 1.0 24 | -------------------------------------------------------------------------------- /configs/ade20k-150/swin/maskformer_swin_tiny_bs16_160k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_R50_bs16_160k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [2, 2, 6, 2] 8 | NUM_HEADS: [3, 6, 12, 24] 9 | WINDOW_SIZE: 7 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | WEIGHTS: "swin_tiny_patch4_window7_224.pkl" 14 | PIXEL_MEAN: [123.675, 116.280, 103.530] 15 | PIXEL_STD: [58.395, 57.120, 57.375] 16 | SOLVER: 17 | BASE_LR: 0.00006 18 | WARMUP_FACTOR: 1e-6 19 | WARMUP_ITERS: 1500 20 | WEIGHT_DECAY: 0.01 21 | WEIGHT_DECAY_NORM: 0.0 22 | WEIGHT_DECAY_EMBED: 0.0 23 | BACKBONE_MULTIPLIER: 1.0 24 | -------------------------------------------------------------------------------- /configs/ade20k-full-847/Base-ADE20KFull-847.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("ade20k_full_sem_seg_train",) 18 | TEST: ("ade20k_full_sem_seg_val",) 19 | SOLVER: 20 | IMS_PER_BATCH: 16 21 | BASE_LR: 0.0001 22 | MAX_ITER: 200000 23 | WARMUP_FACTOR: 1.0 24 | WARMUP_ITERS: 0 25 | WEIGHT_DECAY: 0.0001 26 | OPTIMIZER: "ADAMW" 27 | LR_SCHEDULER_NAME: "WarmupPolyLR" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"] 36 | MIN_SIZE_TRAIN_SAMPLING: "choice" 37 | MIN_SIZE_TEST: 512 38 | MAX_SIZE_TRAIN: 2048 39 | MAX_SIZE_TEST: 2048 40 | CROP: 41 | ENABLED: True 42 | TYPE: "absolute" 43 | SIZE: (512, 512) 44 | SINGLE_CATEGORY_MAX_AREA: 1.0 45 | COLOR_AUG_SSD: True 46 | SIZE_DIVISIBILITY: 512 # used in dataset mapper 47 | FORMAT: "RGB" 48 | DATASET_MAPPER_NAME: "mask_former_semantic" 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | DATALOADER: 52 | FILTER_EMPTY_ANNOTATIONS: True 53 | NUM_WORKERS: 4 54 | VERSION: 2 55 | -------------------------------------------------------------------------------- /configs/ade20k-full-847/maskformer_R101_bs16_200k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R50_bs16_200k.yaml 2 | MODEL: 3 | WEIGHTS: "R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | STEM_TYPE: "basic" # not used 7 | STEM_OUT_CHANNELS: 64 8 | STRIDE_IN_1X1: False 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | # NORM: "SyncBN" 11 | RES5_MULTI_GRID: [1, 1, 1] # not used 12 | -------------------------------------------------------------------------------- /configs/ade20k-full-847/maskformer_R101c_bs16_200k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R50_bs16_200k.yaml 2 | MODEL: 3 
| BACKBONE: 4 | NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "detectron2://DeepLab/R-103.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "deeplab" 9 | STEM_OUT_CHANNELS: 128 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 2, 4] 14 | -------------------------------------------------------------------------------- /configs/ade20k-full-847/maskformer_R50_bs16_200k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20KFull-847.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "MaskFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 65535 8 | NUM_CLASSES: 847 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | NO_OBJECT_WEIGHT: 0.1 18 | DICE_WEIGHT: 1.0 19 | MASK_WEIGHT: 20.0 20 | HIDDEN_DIM: 256 21 | NUM_OBJECT_QUERIES: 100 22 | NHEADS: 8 23 | DROPOUT: 0.1 24 | DIM_FEEDFORWARD: 2048 25 | ENC_LAYERS: 0 26 | DEC_LAYERS: 6 27 | PRE_NORM: False 28 | -------------------------------------------------------------------------------- /configs/ade20k-full-847/per_pixel_baseline_R50_bs16_200k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20KFull-847.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | SEM_SEG_HEAD: 5 | NAME: "PerPixelBaselineHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 65535 8 | NUM_CLASSES: 847 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | -------------------------------------------------------------------------------- /configs/ade20k-full-847/per_pixel_baseline_plus_R50_bs16_200k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-ADE20KFull-847.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | SEM_SEG_HEAD: 5 | NAME: "PerPixelBaselinePlusHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 65535 8 | NUM_CLASSES: 847 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | HIDDEN_DIM: 256 18 | NUM_OBJECT_QUERIES: 847 # remember to set this to NUM_CLASSES 19 | NHEADS: 8 20 | DROPOUT: 0.1 21 | DIM_FEEDFORWARD: 2048 22 | ENC_LAYERS: 0 23 | DEC_LAYERS: 6 24 | PRE_NORM: False 25 | -------------------------------------------------------------------------------- /configs/cityscapes-19/Base-Cityscapes-19.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("cityscapes_fine_sem_seg_train",) 18 | TEST: ("cityscapes_fine_sem_seg_val",) 19 | SOLVER: 20 | IMS_PER_BATCH: 16 21 | BASE_LR: 0.0001 22 | MAX_ITER: 90000 23 | 
WARMUP_FACTOR: 1.0 24 | WARMUP_ITERS: 0 25 | WEIGHT_DECAY: 0.0001 26 | OPTIMIZER: "ADAMW" 27 | LR_SCHEDULER_NAME: "WarmupPolyLR" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 1024) for x in range(5, 21)]"] 36 | MIN_SIZE_TRAIN_SAMPLING: "choice" 37 | MIN_SIZE_TEST: 1024 38 | MAX_SIZE_TRAIN: 4096 39 | MAX_SIZE_TEST: 2048 40 | CROP: 41 | ENABLED: True 42 | TYPE: "absolute" 43 | SIZE: (512, 1024) 44 | SINGLE_CATEGORY_MAX_AREA: 1.0 45 | COLOR_AUG_SSD: True 46 | SIZE_DIVISIBILITY: -1 47 | FORMAT: "RGB" 48 | DATASET_MAPPER_NAME: "mask_former_semantic" 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | AUG: 52 | ENABLED: False 53 | MIN_SIZES: [512, 768, 1024, 1280, 1536, 1792] 54 | MAX_SIZE: 4096 55 | FLIP: True 56 | DATALOADER: 57 | FILTER_EMPTY_ANNOTATIONS: True 58 | NUM_WORKERS: 4 59 | VERSION: 2 60 | -------------------------------------------------------------------------------- /configs/cityscapes-19/maskformer_R101_bs16_90k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-Cityscapes-19.yaml 2 | MODEL: 3 | WEIGHTS: "R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | STEM_TYPE: "basic" # not used 7 | STEM_OUT_CHANNELS: 64 8 | STRIDE_IN_1X1: False 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | # NORM: "SyncBN" 11 | RES5_MULTI_GRID: [1, 1, 1] # not used 12 | META_ARCHITECTURE: "MaskFormer" 13 | SEM_SEG_HEAD: 14 | NAME: "MaskFormerHead" 15 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 16 | IGNORE_VALUE: 255 17 | NUM_CLASSES: 19 18 | COMMON_STRIDE: 4 # not used, hard-coded 19 | LOSS_WEIGHT: 1.0 20 | CONVS_DIM: 256 21 | MASK_DIM: 256 22 | NORM: "GN" 23 | MASK_FORMER: 24 | TRANSFORMER_IN_FEATURE: "res5" 25 | DEEP_SUPERVISION: True 26 | NO_OBJECT_WEIGHT: 0.1 27 | DICE_WEIGHT: 1.0 28 | MASK_WEIGHT: 20.0 29 | HIDDEN_DIM: 256 30 | NUM_OBJECT_QUERIES: 100 31 | NHEADS: 8 32 | DROPOUT: 0.1 33 | DIM_FEEDFORWARD: 2048 34 | ENC_LAYERS: 0 35 | DEC_LAYERS: 6 36 | PRE_NORM: False 37 | -------------------------------------------------------------------------------- /configs/cityscapes-19/maskformer_R101c_bs16_90k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R101_bs16_90k.yaml 2 | MODEL: 3 | BACKBONE: 4 | FREEZE_AT: 0 5 | NAME: "build_resnet_deeplab_backbone" 6 | WEIGHTS: "detectron2://DeepLab/R-103.pkl" 7 | PIXEL_MEAN: [123.675, 116.280, 103.530] 8 | PIXEL_STD: [58.395, 57.120, 57.375] 9 | RESNETS: 10 | DEPTH: 101 11 | STEM_TYPE: "deeplab" 12 | STEM_OUT_CHANNELS: 128 13 | STRIDE_IN_1X1: False 14 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 15 | # NORM: "SyncBN" 16 | RES5_MULTI_GRID: [1, 2, 4] 17 | -------------------------------------------------------------------------------- /configs/coco-panoptic/Base-COCO-PanopticSegmentation.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("coco_2017_train_panoptic",) 18 | 
TEST: ("coco_2017_val_panoptic",) 19 | SOLVER: 20 | IMS_PER_BATCH: 64 21 | BASE_LR: 0.0001 22 | STEPS: (369600,) 23 | MAX_ITER: 554400 24 | WARMUP_FACTOR: 1.0 25 | WARMUP_ITERS: 10 26 | WEIGHT_DECAY: 0.0001 27 | OPTIMIZER: "ADAMW" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 36 | CROP: 37 | ENABLED: True 38 | TYPE: "absolute_range" 39 | SIZE: (384, 600) 40 | FORMAT: "RGB" 41 | DATASET_MAPPER_NAME: "detr_panoptic" 42 | TEST: 43 | EVAL_PERIOD: 0 44 | DATALOADER: 45 | FILTER_EMPTY_ANNOTATIONS: True 46 | NUM_WORKERS: 4 47 | VERSION: 2 48 | -------------------------------------------------------------------------------- /configs/coco-panoptic/maskformer_panoptic_R101_bs64_554k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_panoptic_R50_bs64_554k.yaml 2 | MODEL: 3 | WEIGHTS: "R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | STEM_TYPE: "basic" # not used 7 | STEM_OUT_CHANNELS: 64 8 | STRIDE_IN_1X1: False 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | # NORM: "SyncBN" 11 | RES5_MULTI_GRID: [1, 1, 1] # not used 12 | -------------------------------------------------------------------------------- /configs/coco-panoptic/maskformer_panoptic_R50_bs64_554k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-COCO-PanopticSegmentation.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "MaskFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 133 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | # add additional 6 encoder layers 15 | PIXEL_DECODER_NAME: "TransformerEncoderPixelDecoder" 16 | TRANSFORMER_ENC_LAYERS: 6 17 | MASK_FORMER: 18 | TRANSFORMER_IN_FEATURE: "transformer_encoder" 19 | DEEP_SUPERVISION: True 20 | NO_OBJECT_WEIGHT: 0.1 21 | DICE_WEIGHT: 1.0 22 | MASK_WEIGHT: 20.0 23 | HIDDEN_DIM: 256 24 | NUM_OBJECT_QUERIES: 100 25 | NHEADS: 8 26 | DROPOUT: 0.1 27 | DIM_FEEDFORWARD: 2048 28 | ENC_LAYERS: 0 29 | DEC_LAYERS: 6 30 | PRE_NORM: False 31 | # COCO model should not pad image 32 | SIZE_DIVISIBILITY: 0 33 | TEST: 34 | PANOPTIC_ON: True 35 | OVERLAP_THRESHOLD: 0.8 36 | OBJECT_MASK_THRESHOLD: 0.8 37 | -------------------------------------------------------------------------------- /configs/coco-panoptic/swin/maskformer_panoptic_swin_base_IN21k_384_bs64_554k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_panoptic_R50_bs64_554k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [2, 2, 18, 2] 8 | NUM_HEADS: [4, 8, 16, 32] 9 | WINDOW_SIZE: 12 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | PRETRAIN_IMG_SIZE: 384 14 | WEIGHTS: "swin_base_patch4_window12_384_22k.pkl" 15 | PIXEL_MEAN: [123.675, 116.280, 103.530] 16 | PIXEL_STD: [58.395, 57.120, 57.375] 17 | SEM_SEG_HEAD: 18 | PIXEL_DECODER_NAME: "BasePixelDecoder" 19 | MASK_FORMER: 20 | TRANSFORMER_IN_FEATURE: "res5" 21 | ENFORCE_INPUT_PROJ: True 22 | TEST: 23 | PANOPTIC_ON: True 24 | OVERLAP_THRESHOLD: 0.8 25 | OBJECT_MASK_THRESHOLD: 0.8 26 | SOLVER: 27 | BASE_LR: 0.00006 28 | WARMUP_FACTOR: 1e-6 29 | WARMUP_ITERS: 1500 30 | WEIGHT_DECAY: 0.01 31 | WEIGHT_DECAY_NORM: 0.0 32 
| WEIGHT_DECAY_EMBED: 0.0 33 | BACKBONE_MULTIPLIER: 1.0 -------------------------------------------------------------------------------- /configs/coco-panoptic/swin/maskformer_panoptic_swin_large_IN21k_384_bs64_554k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_panoptic_R50_bs64_554k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [2, 2, 18, 2] 8 | NUM_HEADS: [6, 12, 24, 48] 9 | WINDOW_SIZE: 12 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | PRETRAIN_IMG_SIZE: 384 14 | WEIGHTS: "swin_large_patch4_window12_384_22k.pkl" 15 | PIXEL_MEAN: [123.675, 116.280, 103.530] 16 | PIXEL_STD: [58.395, 57.120, 57.375] 17 | SEM_SEG_HEAD: 18 | PIXEL_DECODER_NAME: "BasePixelDecoder" 19 | MASK_FORMER: 20 | TRANSFORMER_IN_FEATURE: "res5" 21 | ENFORCE_INPUT_PROJ: True 22 | TEST: 23 | PANOPTIC_ON: True 24 | OVERLAP_THRESHOLD: 0.8 25 | OBJECT_MASK_THRESHOLD: 0.8 26 | SOLVER: 27 | BASE_LR: 0.00006 28 | WARMUP_FACTOR: 1e-6 29 | WARMUP_ITERS: 1500 30 | WEIGHT_DECAY: 0.01 31 | WEIGHT_DECAY_NORM: 0.0 32 | WEIGHT_DECAY_EMBED: 0.0 33 | BACKBONE_MULTIPLIER: 1.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 36 | MAX_SIZE_TRAIN: 1000 37 | CROP: 38 | ENABLED: True 39 | TYPE: "absolute_range" 40 | SIZE: (384, 600) 41 | FORMAT: "RGB" 42 | -------------------------------------------------------------------------------- /configs/coco-panoptic/swin/maskformer_panoptic_swin_small_bs64_554k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_panoptic_R50_bs64_554k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [2, 2, 18, 2] 8 | NUM_HEADS: [3, 6, 12, 24] 9 | WINDOW_SIZE: 7 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | WEIGHTS: "swin_small_patch4_window7_224.pkl" 14 | PIXEL_MEAN: [123.675, 116.280, 103.530] 15 | PIXEL_STD: [58.395, 57.120, 57.375] 16 | SEM_SEG_HEAD: 17 | PIXEL_DECODER_NAME: "BasePixelDecoder" 18 | MASK_FORMER: 19 | TRANSFORMER_IN_FEATURE: "res5" 20 | ENFORCE_INPUT_PROJ: True 21 | TEST: 22 | PANOPTIC_ON: True 23 | OVERLAP_THRESHOLD: 0.8 24 | OBJECT_MASK_THRESHOLD: 0.8 25 | SOLVER: 26 | BASE_LR: 0.00006 27 | WARMUP_FACTOR: 1e-6 28 | WARMUP_ITERS: 1500 29 | WEIGHT_DECAY: 0.01 30 | WEIGHT_DECAY_NORM: 0.0 31 | WEIGHT_DECAY_EMBED: 0.0 32 | BACKBONE_MULTIPLIER: 1.0 33 | -------------------------------------------------------------------------------- /configs/coco-panoptic/swin/maskformer_panoptic_swin_tiny_bs64_554k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ../maskformer_panoptic_R50_bs64_554k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "D2SwinTransformer" 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [2, 2, 6, 2] 8 | NUM_HEADS: [3, 6, 12, 24] 9 | WINDOW_SIZE: 7 10 | APE: False 11 | DROP_PATH_RATE: 0.3 12 | PATCH_NORM: True 13 | WEIGHTS: "swin_tiny_patch4_window7_224.pkl" 14 | PIXEL_MEAN: [123.675, 116.280, 103.530] 15 | PIXEL_STD: [58.395, 57.120, 57.375] 16 | SEM_SEG_HEAD: 17 | PIXEL_DECODER_NAME: "BasePixelDecoder" 18 | MASK_FORMER: 19 | TRANSFORMER_IN_FEATURE: "res5" 20 | ENFORCE_INPUT_PROJ: True 21 | TEST: 22 | PANOPTIC_ON: True 23 | OVERLAP_THRESHOLD: 0.8 24 | OBJECT_MASK_THRESHOLD: 0.8 25 | SOLVER: 26 | BASE_LR: 0.00006 27 | WARMUP_FACTOR: 1e-6 28 | WARMUP_ITERS: 1500 29 | WEIGHT_DECAY: 0.01 30 | WEIGHT_DECAY_NORM: 0.0 31 | WEIGHT_DECAY_EMBED: 0.0 32 | 
BACKBONE_MULTIPLIER: 1.0 33 | -------------------------------------------------------------------------------- /configs/coco-stuff-10k-171/Base-COCOStuff10K-171.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("coco_2017_train_stuff_10k_sem_seg",) 18 | TEST: ("coco_2017_test_stuff_10k_sem_seg",) 19 | SOLVER: 20 | IMS_PER_BATCH: 32 21 | BASE_LR: 0.0001 22 | MAX_ITER: 60000 23 | WARMUP_FACTOR: 1.0 24 | WARMUP_ITERS: 0 25 | WEIGHT_DECAY: 0.0001 26 | OPTIMIZER: "ADAMW" 27 | LR_SCHEDULER_NAME: "WarmupPolyLR" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 16)]"] 36 | MIN_SIZE_TRAIN_SAMPLING: "choice" 37 | MIN_SIZE_TEST: 640 38 | MAX_SIZE_TRAIN: 2560 39 | MAX_SIZE_TEST: 2560 40 | CROP: 41 | ENABLED: True 42 | TYPE: "absolute" 43 | SIZE: (640, 640) 44 | SINGLE_CATEGORY_MAX_AREA: 1.0 45 | COLOR_AUG_SSD: True 46 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 47 | FORMAT: "RGB" 48 | DATASET_MAPPER_NAME: "mask_former_semantic" 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | AUG: 52 | ENABLED: False 53 | MIN_SIZES: [320, 480, 640, 800, 960, 1120] 54 | MAX_SIZE: 4480 55 | FLIP: True 56 | DATALOADER: 57 | FILTER_EMPTY_ANNOTATIONS: True 58 | NUM_WORKERS: 4 59 | VERSION: 2 60 | -------------------------------------------------------------------------------- /configs/coco-stuff-10k-171/maskformer_R101_bs32_60k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R50_bs32_60k.yaml 2 | MODEL: 3 | WEIGHTS: "R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | STEM_TYPE: "basic" # not used 7 | STEM_OUT_CHANNELS: 64 8 | STRIDE_IN_1X1: False 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | # NORM: "SyncBN" 11 | RES5_MULTI_GRID: [1, 1, 1] # not used 12 | -------------------------------------------------------------------------------- /configs/coco-stuff-10k-171/maskformer_R101c_bs32_60k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: maskformer_R50_bs32_60k.yaml 2 | MODEL: 3 | BACKBONE: 4 | NAME: "build_resnet_deeplab_backbone" 5 | WEIGHTS: "detectron2://DeepLab/R-103.pkl" 6 | RESNETS: 7 | DEPTH: 101 8 | STEM_TYPE: "deeplab" 9 | STEM_OUT_CHANNELS: 128 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | # NORM: "SyncBN" 13 | RES5_MULTI_GRID: [1, 2, 4] 14 | -------------------------------------------------------------------------------- /configs/coco-stuff-10k-171/maskformer_R50_bs32_60k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-COCOStuff10K-171.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "MaskFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 171 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | 
MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | NO_OBJECT_WEIGHT: 0.1 18 | DICE_WEIGHT: 1.0 19 | MASK_WEIGHT: 20.0 20 | HIDDEN_DIM: 256 21 | NUM_OBJECT_QUERIES: 100 22 | NHEADS: 8 23 | DROPOUT: 0.1 24 | DIM_FEEDFORWARD: 2048 25 | ENC_LAYERS: 0 26 | DEC_LAYERS: 6 27 | PRE_NORM: False 28 | -------------------------------------------------------------------------------- /configs/coco-stuff-10k-171/per_pixel_baseline_R50_bs32_60k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-COCOStuff10K-171.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | SEM_SEG_HEAD: 5 | NAME: "PerPixelBaselineHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 171 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | -------------------------------------------------------------------------------- /configs/coco-stuff-10k-171/per_pixel_baseline_plus_R50_bs32_60k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-COCOStuff10K-171.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | SEM_SEG_HEAD: 5 | NAME: "PerPixelBaselinePlusHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 255 8 | NUM_CLASSES: 171 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | HIDDEN_DIM: 256 18 | NUM_OBJECT_QUERIES: 171 # remember to set this to NUM_CLASSES 19 | NHEADS: 8 20 | DROPOUT: 0.1 21 | DIM_FEEDFORWARD: 2048 22 | ENC_LAYERS: 0 23 | DEC_LAYERS: 6 24 | PRE_NORM: False 25 | -------------------------------------------------------------------------------- /configs/mapillary-vistas-65/Base-MapillaryVistas-65.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | RESNETS: 9 | DEPTH: 50 10 | STEM_TYPE: "basic" # not used 11 | STEM_OUT_CHANNELS: 64 12 | STRIDE_IN_1X1: False 13 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 14 | # NORM: "SyncBN" 15 | RES5_MULTI_GRID: [1, 1, 1] # not used 16 | DATASETS: 17 | TRAIN: ("mapillary_vistas_sem_seg_train",) 18 | TEST: ("mapillary_vistas_sem_seg_val",) 19 | SOLVER: 20 | IMS_PER_BATCH: 16 21 | BASE_LR: 0.0001 22 | MAX_ITER: 300000 23 | WARMUP_FACTOR: 1.0 24 | WARMUP_ITERS: 0 25 | WEIGHT_DECAY: 0.0001 26 | OPTIMIZER: "ADAMW" 27 | LR_SCHEDULER_NAME: "WarmupPolyLR" 28 | BACKBONE_MULTIPLIER: 0.1 29 | CLIP_GRADIENTS: 30 | ENABLED: True 31 | CLIP_TYPE: "full_model" 32 | CLIP_VALUE: 0.01 33 | NORM_TYPE: 2.0 34 | INPUT: 35 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 2048) for x in range(5, 21)]"] 36 | MIN_SIZE_TRAIN_SAMPLING: "choice" 37 | MIN_SIZE_TEST: 2048 38 | MAX_SIZE_TRAIN: 8192 39 | MAX_SIZE_TEST: 2048 40 | CROP: 41 | ENABLED: True 42 | TYPE: "absolute" 43 | SIZE: (1280, 1280) 44 | SINGLE_CATEGORY_MAX_AREA: 1.0 45 | COLOR_AUG_SSD: True 46 | SIZE_DIVISIBILITY: 1280 # used in dataset mapper 47 | FORMAT: "RGB" 48 | DATASET_MAPPER_NAME: "mask_former_semantic" 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | DATALOADER: 52 | FILTER_EMPTY_ANNOTATIONS: True 53 | 
NUM_WORKERS: 10 54 | VERSION: 2 55 | -------------------------------------------------------------------------------- /configs/mapillary-vistas-65/maskformer_R50_bs16_300k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-MapillaryVistas-65.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "MaskFormer" 4 | SEM_SEG_HEAD: 5 | NAME: "MaskFormerHead" 6 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 7 | IGNORE_VALUE: 65 8 | NUM_CLASSES: 65 9 | COMMON_STRIDE: 4 # not used, hard-coded 10 | LOSS_WEIGHT: 1.0 11 | CONVS_DIM: 256 12 | MASK_DIM: 256 13 | NORM: "GN" 14 | MASK_FORMER: 15 | TRANSFORMER_IN_FEATURE: "res5" 16 | DEEP_SUPERVISION: True 17 | NO_OBJECT_WEIGHT: 0.1 18 | DICE_WEIGHT: 1.0 19 | MASK_WEIGHT: 20.0 20 | HIDDEN_DIM: 256 21 | NUM_OBJECT_QUERIES: 100 22 | NHEADS: 8 23 | DROPOUT: 0.1 24 | DIM_FEEDFORWARD: 2048 25 | ENC_LAYERS: 0 26 | DEC_LAYERS: 6 27 | PRE_NORM: False 28 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Prepare Datasets for MaskFormer 2 | 3 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) 4 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc). 5 | This document explains how to setup the builtin datasets so they can be used by the above APIs. 6 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, 7 | and how to add new datasets to them. 8 | 9 | MaskFormer has builtin support for a few datasets. 10 | The datasets are assumed to exist in a directory specified by the environment variable 11 | `DETECTRON2_DATASETS`. 12 | Under this directory, detectron2 will look for datasets in the structure described below, if needed. 13 | ``` 14 | $DETECTRON2_DATASETS/ 15 | ADEChallengeData2016/ 16 | ADE20K_2021_17_01/ 17 | coco/ 18 | cityscapes/ 19 | mapillary_vistas/ 20 | ``` 21 | 22 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. 23 | If left unset, the default is `./datasets` relative to your current working directory. 24 | 25 | The [model zoo](https://github.com/facebookresearch/MaskFormer/blob/master/MODEL_ZOO.md) 26 | contains configs and models that use these builtin datasets. 27 | 28 | ## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/): 29 | ``` 30 | ADEChallengeData2016/ 31 | annotations/ 32 | annotations_detectron2/ 33 | images/ 34 | objectInfo150.txt 35 | ``` 36 | The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`. 
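As a quick sanity check after running the preparation script, the registered split can be inspected through the `DatasetCatalog`/`MetadataCatalog` APIs mentioned above. The sketch below is illustrative only, not part of the official preparation steps: it assumes `DETECTRON2_DATASETS` points at the directory layout above and relies on the builtin `ade20k_sem_seg_train` registration that the configs in this repository reference.

```python
# Illustrative sanity check (assumes DETECTRON2_DATASETS is set and
# annotations_detectron2/ has been generated by prepare_ade20k_sem_seg.py).
from detectron2.data import DatasetCatalog, MetadataCatalog

dataset_dicts = DatasetCatalog.get("ade20k_sem_seg_train")  # builtin ADE20K-150 train split
metadata = MetadataCatalog.get("ade20k_sem_seg_train")

print(len(dataset_dicts), "training images")   # ADE20K-150 train has 20,210 images
print(len(metadata.stuff_classes), "classes")  # expected: 150
sample = dataset_dicts[0]
print(sample["file_name"], "->", sample["sem_seg_file_name"])
```

If the image count or class count looks wrong, re-check that `annotations_detectron2` was generated under `$DETECTRON2_DATASETS/ADEChallengeData2016/`.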
37 | 38 | ## Expected dataset structure for ADE20K panoptic segmentation: 39 | ``` 40 | ADEChallengeData2016/ 41 | images/ 42 | annotations/ 43 | objectInfo150.txt 44 | # download instance annotation 45 | annotations_instance/ 46 | # generated by prepare_ade20k_sem_seg.py 47 | annotations_detectron2/ 48 | # below are generated by prepare_ade20k_panoptic_annotations.py 49 | ade20k_panoptic_train.json 50 | ade20k_panoptic_train/ 51 | ade20k_panoptic_val.json 52 | ade20k_panoptic_val/ 53 | ``` 54 | Install panopticapi by: 55 | ```bash 56 | pip install git+https://github.com/cocodataset/panopticapi.git 57 | ``` 58 | 59 | Download the instance annotation from http://sceneparsing.csail.mit.edu/: 60 | ```bash 61 | wget http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar 62 | ``` 63 | 64 | Then, run `python datasets/prepare_ade20k_pan_seg.py`, to combine semantic and instance annotations for panoptic annotations. 65 | 66 | ## Expected dataset structure for [ADE20k-Full](https://groups.csail.mit.edu/vision/datasets/ADE20K/): 67 | ``` 68 | ADE20K_2021_17_01/ 69 | images/ 70 | images_detectron2/ 71 | annotations_detectron2/ 72 | index_ade20k.pkl 73 | objects.txt 74 | ``` 75 | The directories `images_detectron2` and `annotations_detectron2` are generated by running `python datasets/prepare_ade20k_full_sem_seg.py`. 76 | 77 | ## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/): 78 | ``` 79 | cityscapes/ 80 | gtFine/ 81 | train/ 82 | aachen/ 83 | color.png, instanceIds.png, labelIds.png, polygons.json, 84 | labelTrainIds.png 85 | ... 86 | val/ 87 | test/ 88 | # below are generated Cityscapes panoptic annotation 89 | cityscapes_panoptic_train.json 90 | cityscapes_panoptic_train/ 91 | cityscapes_panoptic_val.json 92 | cityscapes_panoptic_val/ 93 | cityscapes_panoptic_test.json 94 | cityscapes_panoptic_test/ 95 | leftImg8bit/ 96 | train/ 97 | val/ 98 | test/ 99 | ``` 100 | Install cityscapes scripts by: 101 | ``` 102 | pip install git+https://github.com/mcordts/cityscapesScripts.git 103 | ``` 104 | 105 | Note: to create labelTrainIds.png, first prepare the above structure, then run cityscapesescript with: 106 | ``` 107 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py 108 | ``` 109 | These files are not needed for instance segmentation. 110 | 111 | Note: to generate Cityscapes panoptic dataset, run cityscapesescript with: 112 | ``` 113 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py 114 | ``` 115 | These files are not needed for semantic and instance segmentation. 116 | 117 | ## Expected dataset structure for [COCO-Stuff-10K](https://github.com/nightrome/cocostuff10k): 118 | 119 | ``` 120 | coco/ 121 | coco_stuff_10k/ 122 | annotations/ 123 | COCO_train2014_000000000077.mat 124 | ... 125 | imageLists/ 126 | all.txt 127 | test.txt 128 | train.txt 129 | images/ 130 | COCO_train2014_000000000077.jpg 131 | ... 132 | # below are generated by prepare_coco_stuff_10k_v1.0_sem_seg.py 133 | annotations_detectron2/ 134 | train/ 135 | test/ 136 | images_detectron2/ 137 | train/ 138 | test/ 139 | ``` 140 | 141 | Get the COCO-Stuff-10k **v1.0** annotation from https://github.com/nightrome/cocostuff10k. 
142 | ```bash 143 | wget http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/cocostuff-10k-v1.0.zip 144 | ``` 145 | Unzip `cocostuff-10k-v1.0.zip` and put `annotations`, `imageLists` and `images` to the correct location listed above. 146 | 147 | Generate COCO-Stuff-10k annotation by `python datasets/prepare_coco_stuff_10k_v1.0_sem_seg.py` 148 | 149 | ## Expected dataset structure for [Mapillary Vistas](https://www.mapillary.com/dataset/vistas): 150 | ``` 151 | mapillary_vistas/ 152 | training/ 153 | images/ 154 | instances/ 155 | labels/ 156 | panoptic/ 157 | validation/ 158 | images/ 159 | instances/ 160 | labels/ 161 | panoptic/ 162 | ``` 163 | 164 | No preprocessing is needed for Mapillary Vistas. 165 | -------------------------------------------------------------------------------- /datasets/ade20k_instance_catid_mapping.txt: -------------------------------------------------------------------------------- 1 | Instacne100 SceneParse150 FullADE20K 2 | 1 8 165 3 | 2 9 3055 4 | 3 11 350 5 | 4 13 1831 6 | 5 15 774 7 | 5 15 783 8 | 6 16 2684 9 | 7 19 687 10 | 8 20 471 11 | 9 21 401 12 | 10 23 1735 13 | 11 24 2473 14 | 12 25 2329 15 | 13 28 1564 16 | 14 31 57 17 | 15 32 2272 18 | 16 33 907 19 | 17 34 724 20 | 18 36 2985 21 | 18 36 533 22 | 19 37 1395 23 | 20 38 155 24 | 21 39 2053 25 | 22 40 689 26 | 23 42 266 27 | 24 43 581 28 | 25 44 2380 29 | 26 45 491 30 | 27 46 627 31 | 28 48 2388 32 | 29 50 943 33 | 30 51 2096 34 | 31 54 2530 35 | 32 56 420 36 | 33 57 1948 37 | 34 58 1869 38 | 35 59 2251 39 | 36 63 239 40 | 37 65 571 41 | 38 66 2793 42 | 39 67 978 43 | 40 68 236 44 | 41 70 181 45 | 42 71 629 46 | 43 72 2598 47 | 44 73 1744 48 | 45 74 1374 49 | 46 75 591 50 | 47 76 2679 51 | 48 77 223 52 | 49 79 47 53 | 50 81 327 54 | 51 82 2821 55 | 52 83 1451 56 | 53 84 2880 57 | 54 86 480 58 | 55 87 77 59 | 56 88 2616 60 | 57 89 246 61 | 57 89 247 62 | 58 90 2733 63 | 59 91 14 64 | 60 93 38 65 | 61 94 1936 66 | 62 96 120 67 | 63 98 1702 68 | 64 99 249 69 | 65 103 2928 70 | 66 104 2337 71 | 67 105 1023 72 | 68 108 2989 73 | 69 109 1930 74 | 70 111 2586 75 | 71 112 131 76 | 72 113 146 77 | 73 116 95 78 | 74 117 1563 79 | 75 119 1708 80 | 76 120 103 81 | 77 121 1002 82 | 78 122 2569 83 | 79 124 2833 84 | 80 125 1551 85 | 81 126 1981 86 | 82 127 29 87 | 83 128 187 88 | 84 130 747 89 | 85 131 2254 90 | 86 133 2262 91 | 87 134 1260 92 | 88 135 2243 93 | 89 136 2932 94 | 90 137 2836 95 | 91 138 2850 96 | 92 139 64 97 | 93 140 894 98 | 94 143 1919 99 | 95 144 1583 100 | 96 145 318 101 | 97 147 2046 102 | 98 148 1098 103 | 99 149 530 104 | 100 150 954 105 | -------------------------------------------------------------------------------- /datasets/prepare_ade20k_sem_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | import os 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import tqdm 9 | from PIL import Image 10 | 11 | 12 | def convert(input, output): 13 | img = np.asarray(Image.open(input)) 14 | assert img.dtype == np.uint8 15 | img = img - 1 # 0 (ignore) becomes 255. 
others are shifted by 1 16 | Image.fromarray(img).save(output) 17 | 18 | 19 | if __name__ == "__main__": 20 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016" 21 | for name in ["training", "validation"]: 22 | annotation_dir = dataset_dir / "annotations" / name 23 | output_dir = dataset_dir / "annotations_detectron2" / name 24 | output_dir.mkdir(parents=True, exist_ok=True) 25 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 26 | output_file = output_dir / file.name 27 | convert(file, output_file) 28 | -------------------------------------------------------------------------------- /datasets/prepare_coco_stuff_10k_v1.0_sem_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | import os 5 | from pathlib import Path 6 | from shutil import copyfile 7 | 8 | import h5py 9 | import numpy as np 10 | import tqdm 11 | from PIL import Image 12 | 13 | if __name__ == "__main__": 14 | dataset_dir = os.path.join( 15 | os.getenv("DETECTRON2_DATASETS", "datasets"), "coco", "coco_stuff_10k" 16 | ) 17 | for s in ["test", "train"]: 18 | image_list_file = os.path.join(dataset_dir, "imageLists", f"{s}.txt") 19 | with open(image_list_file, "r") as f: 20 | image_list = f.readlines() 21 | 22 | image_list = [f.strip() for f in image_list] 23 | 24 | image_dir = os.path.join(dataset_dir, "images_detectron2", s) 25 | Path(image_dir).mkdir(parents=True, exist_ok=True) 26 | annotation_dir = os.path.join(dataset_dir, "annotations_detectron2", s) 27 | Path(annotation_dir).mkdir(parents=True, exist_ok=True) 28 | 29 | for fname in tqdm.tqdm(image_list): 30 | copyfile( 31 | os.path.join(dataset_dir, "images", fname + ".jpg"), 32 | os.path.join(image_dir, fname + ".jpg"), 33 | ) 34 | 35 | img = np.asarray(Image.open(os.path.join(image_dir, fname + ".jpg"))) 36 | 37 | matfile = h5py.File(os.path.join(dataset_dir, "annotations", fname + ".mat")) 38 | S = np.array(matfile["S"]).astype(np.uint8) 39 | S = np.transpose(S) 40 | S = S - 2 # 1 (ignore) becomes 255. others are shifted by 2 41 | 42 | assert S.shape == img.shape[:2], "{} vs {}".format(S.shape, img.shape) 43 | 44 | Image.fromarray(S).save(os.path.join(annotation_dir, fname + ".png")) 45 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | ## MaskFormer Demo 2 | 3 | We provide a command line tool to run a simple demo of builtin configs. 4 | The usage is explained in [GETTING_STARTED.md](../GETTING_STARTED.md). 5 | -------------------------------------------------------------------------------- /demo/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
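# Example invocation (an illustrative sketch; the config, image paths and checkpoint below are
# placeholders, see GETTING_STARTED.md for the documented usage):
#
#   python demo/demo.py \
#     --config-file configs/ade20k-150/maskformer_R50_bs16_160k.yaml \
#     --input input1.jpg input2.jpg \
#     --output /path/to/output_dir \
#     --opts MODEL.WEIGHTS /path/to/checkpoint.pkl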
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py 3 | import argparse 4 | import glob 5 | import multiprocessing as mp 6 | import os 7 | 8 | # fmt: off 9 | import sys 10 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 11 | # fmt: on 12 | 13 | import tempfile 14 | import time 15 | import warnings 16 | 17 | import cv2 18 | import numpy as np 19 | import tqdm 20 | 21 | from detectron2.config import get_cfg 22 | from detectron2.data.detection_utils import read_image 23 | from detectron2.projects.deeplab import add_deeplab_config 24 | from detectron2.utils.logger import setup_logger 25 | 26 | from mask_former import add_mask_former_config 27 | from predictor import VisualizationDemo 28 | 29 | 30 | # constants 31 | WINDOW_NAME = "MaskFormer demo" 32 | 33 | 34 | def setup_cfg(args): 35 | # load config from file and command-line arguments 36 | cfg = get_cfg() 37 | add_deeplab_config(cfg) 38 | add_mask_former_config(cfg) 39 | cfg.merge_from_file(args.config_file) 40 | cfg.merge_from_list(args.opts) 41 | cfg.freeze() 42 | return cfg 43 | 44 | 45 | def get_parser(): 46 | parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") 47 | parser.add_argument( 48 | "--config-file", 49 | default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml", 50 | metavar="FILE", 51 | help="path to config file", 52 | ) 53 | parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") 54 | parser.add_argument("--video-input", help="Path to video file.") 55 | parser.add_argument( 56 | "--input", 57 | nargs="+", 58 | help="A list of space separated input images; " 59 | "or a single glob pattern such as 'directory/*.jpg'", 60 | ) 61 | parser.add_argument( 62 | "--output", 63 | help="A file or directory to save output visualizations. 
" 64 | "If not given, will show output in an OpenCV window.", 65 | ) 66 | 67 | parser.add_argument( 68 | "--confidence-threshold", 69 | type=float, 70 | default=0.5, 71 | help="Minimum score for instance predictions to be shown", 72 | ) 73 | parser.add_argument( 74 | "--opts", 75 | help="Modify config options using the command-line 'KEY VALUE' pairs", 76 | default=[], 77 | nargs=argparse.REMAINDER, 78 | ) 79 | return parser 80 | 81 | 82 | def test_opencv_video_format(codec, file_ext): 83 | with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: 84 | filename = os.path.join(dir, "test_file" + file_ext) 85 | writer = cv2.VideoWriter( 86 | filename=filename, 87 | fourcc=cv2.VideoWriter_fourcc(*codec), 88 | fps=float(30), 89 | frameSize=(10, 10), 90 | isColor=True, 91 | ) 92 | [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)] 93 | writer.release() 94 | if os.path.isfile(filename): 95 | return True 96 | return False 97 | 98 | 99 | if __name__ == "__main__": 100 | mp.set_start_method("spawn", force=True) 101 | args = get_parser().parse_args() 102 | setup_logger(name="fvcore") 103 | logger = setup_logger() 104 | logger.info("Arguments: " + str(args)) 105 | 106 | cfg = setup_cfg(args) 107 | 108 | demo = VisualizationDemo(cfg) 109 | 110 | if args.input: 111 | if len(args.input) == 1: 112 | args.input = glob.glob(os.path.expanduser(args.input[0])) 113 | assert args.input, "The input path(s) was not found" 114 | for path in tqdm.tqdm(args.input, disable=not args.output): 115 | # use PIL, to be consistent with evaluation 116 | img = read_image(path, format="BGR") 117 | start_time = time.time() 118 | predictions, visualized_output = demo.run_on_image(img) 119 | logger.info( 120 | "{}: {} in {:.2f}s".format( 121 | path, 122 | "detected {} instances".format(len(predictions["instances"])) 123 | if "instances" in predictions 124 | else "finished", 125 | time.time() - start_time, 126 | ) 127 | ) 128 | 129 | if args.output: 130 | if os.path.isdir(args.output): 131 | assert os.path.isdir(args.output), args.output 132 | out_filename = os.path.join(args.output, os.path.basename(path)) 133 | else: 134 | assert len(args.input) == 1, "Please specify a directory with args.output" 135 | out_filename = args.output 136 | visualized_output.save(out_filename) 137 | else: 138 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 139 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) 140 | if cv2.waitKey(0) == 27: 141 | break # esc to quit 142 | elif args.webcam: 143 | assert args.input is None, "Cannot have both --input and --webcam!" 144 | assert args.output is None, "output not yet supported with --webcam!" 
145 | cam = cv2.VideoCapture(0) 146 | for vis in tqdm.tqdm(demo.run_on_video(cam)): 147 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 148 | cv2.imshow(WINDOW_NAME, vis) 149 | if cv2.waitKey(1) == 27: 150 | break # esc to quit 151 | cam.release() 152 | cv2.destroyAllWindows() 153 | elif args.video_input: 154 | video = cv2.VideoCapture(args.video_input) 155 | width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 156 | height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 157 | frames_per_second = video.get(cv2.CAP_PROP_FPS) 158 | num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 159 | basename = os.path.basename(args.video_input) 160 | codec, file_ext = ( 161 | ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") 162 | ) 163 | if codec == ".mp4v": 164 | warnings.warn("x264 codec not available, switching to mp4v") 165 | if args.output: 166 | if os.path.isdir(args.output): 167 | output_fname = os.path.join(args.output, basename) 168 | output_fname = os.path.splitext(output_fname)[0] + file_ext 169 | else: 170 | output_fname = args.output 171 | assert not os.path.isfile(output_fname), output_fname 172 | output_file = cv2.VideoWriter( 173 | filename=output_fname, 174 | # some installation of opencv may not support x264 (due to its license), 175 | # you can try other format (e.g. MPEG) 176 | fourcc=cv2.VideoWriter_fourcc(*codec), 177 | fps=float(frames_per_second), 178 | frameSize=(width, height), 179 | isColor=True, 180 | ) 181 | assert os.path.isfile(args.video_input) 182 | for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): 183 | if args.output: 184 | output_file.write(vis_frame) 185 | else: 186 | cv2.namedWindow(basename, cv2.WINDOW_NORMAL) 187 | cv2.imshow(basename, vis_frame) 188 | if cv2.waitKey(1) == 27: 189 | break # esc to quit 190 | video.release() 191 | if args.output: 192 | output_file.release() 193 | else: 194 | cv2.destroyAllWindows() 195 | -------------------------------------------------------------------------------- /demo/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py 3 | import atexit 4 | import bisect 5 | import multiprocessing as mp 6 | from collections import deque 7 | 8 | import cv2 9 | import torch 10 | 11 | from detectron2.data import MetadataCatalog 12 | from detectron2.engine.defaults import DefaultPredictor 13 | from detectron2.utils.video_visualizer import VideoVisualizer 14 | from detectron2.utils.visualizer import ColorMode, Visualizer 15 | 16 | 17 | class VisualizationDemo(object): 18 | def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): 19 | """ 20 | Args: 21 | cfg (CfgNode): 22 | instance_mode (ColorMode): 23 | parallel (bool): whether to run the model in different processes from visualization. 24 | Useful since the visualization logic can be slow. 25 | """ 26 | self.metadata = MetadataCatalog.get( 27 | cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" 28 | ) 29 | self.cpu_device = torch.device("cpu") 30 | self.instance_mode = instance_mode 31 | 32 | self.parallel = parallel 33 | if parallel: 34 | num_gpu = torch.cuda.device_count() 35 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) 36 | else: 37 | self.predictor = DefaultPredictor(cfg) 38 | 39 | def run_on_image(self, image): 40 | """ 41 | Args: 42 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 
43 | This is the format used by OpenCV. 44 | Returns: 45 | predictions (dict): the output of the model. 46 | vis_output (VisImage): the visualized image output. 47 | """ 48 | vis_output = None 49 | predictions = self.predictor(image) 50 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 51 | image = image[:, :, ::-1] 52 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) 53 | if "panoptic_seg" in predictions: 54 | panoptic_seg, segments_info = predictions["panoptic_seg"] 55 | vis_output = visualizer.draw_panoptic_seg_predictions( 56 | panoptic_seg.to(self.cpu_device), segments_info 57 | ) 58 | else: 59 | if "sem_seg" in predictions: 60 | vis_output = visualizer.draw_sem_seg( 61 | predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 62 | ) 63 | if "instances" in predictions: 64 | instances = predictions["instances"].to(self.cpu_device) 65 | vis_output = visualizer.draw_instance_predictions(predictions=instances) 66 | 67 | return predictions, vis_output 68 | 69 | def _frame_from_video(self, video): 70 | while video.isOpened(): 71 | success, frame = video.read() 72 | if success: 73 | yield frame 74 | else: 75 | break 76 | 77 | def run_on_video(self, video): 78 | """ 79 | Visualizes predictions on frames of the input video. 80 | Args: 81 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be 82 | either a webcam or a video file. 83 | Yields: 84 | ndarray: BGR visualizations of each video frame. 85 | """ 86 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) 87 | 88 | def process_predictions(frame, predictions): 89 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 90 | if "panoptic_seg" in predictions: 91 | panoptic_seg, segments_info = predictions["panoptic_seg"] 92 | vis_frame = video_visualizer.draw_panoptic_seg_predictions( 93 | frame, panoptic_seg.to(self.cpu_device), segments_info 94 | ) 95 | elif "instances" in predictions: 96 | predictions = predictions["instances"].to(self.cpu_device) 97 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) 98 | elif "sem_seg" in predictions: 99 | vis_frame = video_visualizer.draw_sem_seg( 100 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 101 | ) 102 | 103 | # Converts Matplotlib RGB format to OpenCV BGR format 104 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) 105 | return vis_frame 106 | 107 | frame_gen = self._frame_from_video(video) 108 | if self.parallel: 109 | buffer_size = self.predictor.default_buffer_size 110 | 111 | frame_data = deque() 112 | 113 | for cnt, frame in enumerate(frame_gen): 114 | frame_data.append(frame) 115 | self.predictor.put(frame) 116 | 117 | if cnt >= buffer_size: 118 | frame = frame_data.popleft() 119 | predictions = self.predictor.get() 120 | yield process_predictions(frame, predictions) 121 | 122 | while len(frame_data): 123 | frame = frame_data.popleft() 124 | predictions = self.predictor.get() 125 | yield process_predictions(frame, predictions) 126 | else: 127 | for frame in frame_gen: 128 | yield process_predictions(frame, self.predictor(frame)) 129 | 130 | 131 | class AsyncPredictor: 132 | """ 133 | A predictor that runs the model asynchronously, possibly on >1 GPUs. 134 | Because rendering the visualization takes considerably amount of time, 135 | this helps improve throughput a little bit when rendering videos. 
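    Example (an illustrative sketch; assumes a valid ``cfg`` and at least one GPU)::

        predictor = AsyncPredictor(cfg, num_gpus=1)
        result = predictor(frame)   # blocking: put(frame) followed by get()

    For higher throughput, interleave ``put()`` and ``get()`` with a small buffer, as
    :meth:`VisualizationDemo.run_on_video` does using ``default_buffer_size``.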
136 | """ 137 | 138 | class _StopToken: 139 | pass 140 | 141 | class _PredictWorker(mp.Process): 142 | def __init__(self, cfg, task_queue, result_queue): 143 | self.cfg = cfg 144 | self.task_queue = task_queue 145 | self.result_queue = result_queue 146 | super().__init__() 147 | 148 | def run(self): 149 | predictor = DefaultPredictor(self.cfg) 150 | 151 | while True: 152 | task = self.task_queue.get() 153 | if isinstance(task, AsyncPredictor._StopToken): 154 | break 155 | idx, data = task 156 | result = predictor(data) 157 | self.result_queue.put((idx, result)) 158 | 159 | def __init__(self, cfg, num_gpus: int = 1): 160 | """ 161 | Args: 162 | cfg (CfgNode): 163 | num_gpus (int): if 0, will run on CPU 164 | """ 165 | num_workers = max(num_gpus, 1) 166 | self.task_queue = mp.Queue(maxsize=num_workers * 3) 167 | self.result_queue = mp.Queue(maxsize=num_workers * 3) 168 | self.procs = [] 169 | for gpuid in range(max(num_gpus, 1)): 170 | cfg = cfg.clone() 171 | cfg.defrost() 172 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" 173 | self.procs.append( 174 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) 175 | ) 176 | 177 | self.put_idx = 0 178 | self.get_idx = 0 179 | self.result_rank = [] 180 | self.result_data = [] 181 | 182 | for p in self.procs: 183 | p.start() 184 | atexit.register(self.shutdown) 185 | 186 | def put(self, image): 187 | self.put_idx += 1 188 | self.task_queue.put((self.put_idx, image)) 189 | 190 | def get(self): 191 | self.get_idx += 1 # the index needed for this request 192 | if len(self.result_rank) and self.result_rank[0] == self.get_idx: 193 | res = self.result_data[0] 194 | del self.result_data[0], self.result_rank[0] 195 | return res 196 | 197 | while True: 198 | # make sure the results are returned in the correct order 199 | idx, res = self.result_queue.get() 200 | if idx == self.get_idx: 201 | return res 202 | insert = bisect.bisect(self.result_rank, idx) 203 | self.result_rank.insert(insert, idx) 204 | self.result_data.insert(insert, res) 205 | 206 | def __len__(self): 207 | return self.put_idx - self.get_idx 208 | 209 | def __call__(self, image): 210 | self.put(image) 211 | return self.get() 212 | 213 | def shutdown(self): 214 | for _ in self.procs: 215 | self.task_queue.put(AsyncPredictor._StopToken()) 216 | 217 | @property 218 | def default_buffer_size(self): 219 | return len(self.procs) * 5 220 | -------------------------------------------------------------------------------- /mask_former/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . import data # register all new datasets 3 | from . import modeling 4 | 5 | # config 6 | from .config import add_mask_former_config 7 | 8 | # dataset loading 9 | from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper 10 | from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import ( 11 | MaskFormerPanopticDatasetMapper, 12 | ) 13 | from .data.dataset_mappers.mask_former_semantic_dataset_mapper import ( 14 | MaskFormerSemanticDatasetMapper, 15 | ) 16 | 17 | # models 18 | from .mask_former_model import MaskFormer 19 | from .test_time_augmentation import SemanticSegmentorWithTTA 20 | -------------------------------------------------------------------------------- /mask_former/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. 
and its affiliates. 3 | from detectron2.config import CfgNode as CN 4 | 5 | 6 | def add_mask_former_config(cfg): 7 | """ 8 | Add config for MASK_FORMER. 9 | """ 10 | # data config 11 | # select the dataset mapper 12 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" 13 | # Color augmentation 14 | cfg.INPUT.COLOR_AUG_SSD = False 15 | # We retry random cropping until no single category in semantic segmentation GT occupies more 16 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 17 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 18 | # Pad image and segmentation GT in dataset mapper. 19 | cfg.INPUT.SIZE_DIVISIBILITY = -1 20 | 21 | # solver config 22 | # weight decay on embedding 23 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 24 | # optimizer 25 | cfg.SOLVER.OPTIMIZER = "ADAMW" 26 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 27 | 28 | # mask_former model config 29 | cfg.MODEL.MASK_FORMER = CN() 30 | 31 | # loss 32 | cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True 33 | cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 34 | cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 35 | cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 36 | 37 | # transformer config 38 | cfg.MODEL.MASK_FORMER.NHEADS = 8 39 | cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 40 | cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 41 | cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 42 | cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 43 | cfg.MODEL.MASK_FORMER.PRE_NORM = False 44 | 45 | cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 46 | cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 47 | 48 | cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" 49 | cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False 50 | 51 | # mask_former inference config 52 | cfg.MODEL.MASK_FORMER.TEST = CN() 53 | cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False 54 | cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 55 | cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 56 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False 57 | 58 | # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet) 59 | # you can use this config to override 60 | cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 61 | 62 | # pixel decoder config 63 | cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 64 | # adding transformer in pixel decoder 65 | cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 66 | # pixel decoder 67 | cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" 68 | 69 | # swin transformer backbone 70 | cfg.MODEL.SWIN = CN() 71 | cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 72 | cfg.MODEL.SWIN.PATCH_SIZE = 4 73 | cfg.MODEL.SWIN.EMBED_DIM = 96 74 | cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 75 | cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 76 | cfg.MODEL.SWIN.WINDOW_SIZE = 7 77 | cfg.MODEL.SWIN.MLP_RATIO = 4.0 78 | cfg.MODEL.SWIN.QKV_BIAS = True 79 | cfg.MODEL.SWIN.QK_SCALE = None 80 | cfg.MODEL.SWIN.DROP_RATE = 0.0 81 | cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 82 | cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 83 | cfg.MODEL.SWIN.APE = False 84 | cfg.MODEL.SWIN.PATCH_NORM = True 85 | cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] 86 | -------------------------------------------------------------------------------- /mask_former/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . 
import datasets 3 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py 3 | import copy 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from detectron2.config import configurable 10 | from detectron2.data import detection_utils as utils 11 | from detectron2.data import transforms as T 12 | from detectron2.data.transforms import TransformGen 13 | from detectron2.structures import BitMasks, Instances 14 | 15 | __all__ = ["DETRPanopticDatasetMapper"] 16 | 17 | 18 | def build_transform_gen(cfg, is_train): 19 | """ 20 | Create a list of :class:`TransformGen` from config. 21 | Returns: 22 | list[TransformGen] 23 | """ 24 | if is_train: 25 | min_size = cfg.INPUT.MIN_SIZE_TRAIN 26 | max_size = cfg.INPUT.MAX_SIZE_TRAIN 27 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 28 | else: 29 | min_size = cfg.INPUT.MIN_SIZE_TEST 30 | max_size = cfg.INPUT.MAX_SIZE_TEST 31 | sample_style = "choice" 32 | if sample_style == "range": 33 | assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format( 34 | len(min_size) 35 | ) 36 | 37 | logger = logging.getLogger(__name__) 38 | tfm_gens = [] 39 | if is_train: 40 | tfm_gens.append(T.RandomFlip()) 41 | tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 42 | if is_train: 43 | logger.info("TransformGens used in training: " + str(tfm_gens)) 44 | return tfm_gens 45 | 46 | 47 | # This is specifically designed for the COCO dataset. 48 | class DETRPanopticDatasetMapper: 49 | """ 50 | A callable which takes a dataset dict in Detectron2 Dataset format, 51 | and map it into a format used by MaskFormer. 52 | 53 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 54 | 55 | The callable currently does the following: 56 | 57 | 1. Read the image from "file_name" 58 | 2. Applies geometric transforms to the image and annotation 59 | 3. Find and applies suitable cropping to the image and annotation 60 | 4. Prepare image and annotation to Tensors 61 | """ 62 | 63 | @configurable 64 | def __init__( 65 | self, 66 | is_train=True, 67 | *, 68 | crop_gen, 69 | tfm_gens, 70 | image_format, 71 | ): 72 | """ 73 | NOTE: this interface is experimental. 74 | Args: 75 | is_train: for training or inference 76 | augmentations: a list of augmentations or deterministic transforms to apply 77 | crop_gen: crop augmentation 78 | tfm_gens: data augmentation 79 | image_format: an image format supported by :func:`detection_utils.read_image`. 
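            Note: when ``crop_gen`` is provided, the crop augmentations are inserted before the
            last transform in ``tfm_gens`` for roughly half of the training samples; see
            :meth:`__call__`.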
80 | """ 81 | self.crop_gen = crop_gen 82 | self.tfm_gens = tfm_gens 83 | logging.getLogger(__name__).info( 84 | "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format( 85 | str(self.tfm_gens), str(self.crop_gen) 86 | ) 87 | ) 88 | 89 | self.img_format = image_format 90 | self.is_train = is_train 91 | 92 | @classmethod 93 | def from_config(cls, cfg, is_train=True): 94 | # Build augmentation 95 | if cfg.INPUT.CROP.ENABLED and is_train: 96 | crop_gen = [ 97 | T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), 98 | T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), 99 | ] 100 | else: 101 | crop_gen = None 102 | 103 | tfm_gens = build_transform_gen(cfg, is_train) 104 | 105 | ret = { 106 | "is_train": is_train, 107 | "crop_gen": crop_gen, 108 | "tfm_gens": tfm_gens, 109 | "image_format": cfg.INPUT.FORMAT, 110 | } 111 | return ret 112 | 113 | def __call__(self, dataset_dict): 114 | """ 115 | Args: 116 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 117 | 118 | Returns: 119 | dict: a format that builtin models in detectron2 accept 120 | """ 121 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 122 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 123 | utils.check_image_size(dataset_dict, image) 124 | 125 | if self.crop_gen is None: 126 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 127 | else: 128 | if np.random.rand() > 0.5: 129 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 130 | else: 131 | image, transforms = T.apply_transform_gens( 132 | self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image 133 | ) 134 | 135 | image_shape = image.shape[:2] # h, w 136 | 137 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 138 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 139 | # Therefore it's important to use torch.Tensor. 140 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 141 | 142 | if not self.is_train: 143 | # USER: Modify this if you want to keep them for some reason. 
144 | dataset_dict.pop("annotations", None) 145 | return dataset_dict 146 | 147 | if "pan_seg_file_name" in dataset_dict: 148 | pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") 149 | segments_info = dataset_dict["segments_info"] 150 | 151 | # apply the same transformation to panoptic segmentation 152 | pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) 153 | 154 | from panopticapi.utils import rgb2id 155 | 156 | pan_seg_gt = rgb2id(pan_seg_gt) 157 | 158 | instances = Instances(image_shape) 159 | classes = [] 160 | masks = [] 161 | for segment_info in segments_info: 162 | class_id = segment_info["category_id"] 163 | if not segment_info["iscrowd"]: 164 | classes.append(class_id) 165 | masks.append(pan_seg_gt == segment_info["id"]) 166 | 167 | classes = np.array(classes) 168 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 169 | if len(masks) == 0: 170 | # Some image does not have annotation (all ignored) 171 | instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) 172 | else: 173 | masks = BitMasks( 174 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 175 | ) 176 | instances.gt_masks = masks.tensor 177 | 178 | dataset_dict["instances"] = instances 179 | 180 | return dataset_dict 181 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.data import detection_utils as utils 11 | from detectron2.data import transforms as T 12 | from detectron2.structures import BitMasks, Instances 13 | 14 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper 15 | 16 | __all__ = ["MaskFormerPanopticDatasetMapper"] 17 | 18 | 19 | class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper): 20 | """ 21 | A callable which takes a dataset dict in Detectron2 Dataset format, 22 | and map it into a format used by MaskFormer for panoptic segmentation. 23 | 24 | The callable currently does the following: 25 | 26 | 1. Read the image from "file_name" 27 | 2. Applies geometric transforms to the image and annotation 28 | 3. Find and applies suitable cropping to the image and annotation 29 | 4. Prepare image and annotation to Tensors 30 | """ 31 | 32 | @configurable 33 | def __init__( 34 | self, 35 | is_train=True, 36 | *, 37 | augmentations, 38 | image_format, 39 | ignore_label, 40 | size_divisibility, 41 | ): 42 | """ 43 | NOTE: this interface is experimental. 44 | Args: 45 | is_train: for training or inference 46 | augmentations: a list of augmentations or deterministic transforms to apply 47 | image_format: an image format supported by :func:`detection_utils.read_image`. 48 | ignore_label: the label that is ignored to evaluation 49 | size_divisibility: pad image size to be divisible by this value 50 | """ 51 | super().__init__( 52 | is_train, 53 | augmentations=augmentations, 54 | image_format=image_format, 55 | ignore_label=ignore_label, 56 | size_divisibility=size_divisibility, 57 | ) 58 | 59 | def __call__(self, dataset_dict): 60 | """ 61 | Args: 62 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
63 | 64 | Returns: 65 | dict: a format that builtin models in detectron2 accept 66 | """ 67 | assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!" 68 | 69 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 70 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 71 | utils.check_image_size(dataset_dict, image) 72 | 73 | # semantic segmentation 74 | if "sem_seg_file_name" in dataset_dict: 75 | # PyTorch transformation not implemented for uint16, so converting it to double first 76 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double") 77 | else: 78 | sem_seg_gt = None 79 | 80 | # panoptic segmentation 81 | if "pan_seg_file_name" in dataset_dict: 82 | pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") 83 | segments_info = dataset_dict["segments_info"] 84 | else: 85 | pan_seg_gt = None 86 | segments_info = None 87 | 88 | if pan_seg_gt is None: 89 | raise ValueError( 90 | "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format( 91 | dataset_dict["file_name"] 92 | ) 93 | ) 94 | 95 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 96 | aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) 97 | image = aug_input.image 98 | if sem_seg_gt is not None: 99 | sem_seg_gt = aug_input.sem_seg 100 | 101 | # apply the same transformation to panoptic segmentation 102 | pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) 103 | 104 | from panopticapi.utils import rgb2id 105 | 106 | pan_seg_gt = rgb2id(pan_seg_gt) 107 | 108 | # Pad image and segmentation label here! 109 | image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 110 | if sem_seg_gt is not None: 111 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 112 | pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long")) 113 | 114 | if self.size_divisibility > 0: 115 | image_size = (image.shape[-2], image.shape[-1]) 116 | padding_size = [ 117 | 0, 118 | self.size_divisibility - image_size[1], 119 | 0, 120 | self.size_divisibility - image_size[0], 121 | ] 122 | image = F.pad(image, padding_size, value=128).contiguous() 123 | if sem_seg_gt is not None: 124 | sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() 125 | pan_seg_gt = F.pad( 126 | pan_seg_gt, padding_size, value=0 127 | ).contiguous() # 0 is the VOID panoptic label 128 | 129 | image_shape = (image.shape[-2], image.shape[-1]) # h, w 130 | 131 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 132 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 133 | # Therefore it's important to use torch.Tensor. 
134 | dataset_dict["image"] = image 135 | if sem_seg_gt is not None: 136 | dataset_dict["sem_seg"] = sem_seg_gt.long() 137 | 138 | if "annotations" in dataset_dict: 139 | raise ValueError("Pemantic segmentation dataset should not have 'annotations'.") 140 | 141 | # Prepare per-category binary masks 142 | pan_seg_gt = pan_seg_gt.numpy() 143 | instances = Instances(image_shape) 144 | classes = [] 145 | masks = [] 146 | for segment_info in segments_info: 147 | class_id = segment_info["category_id"] 148 | if not segment_info["iscrowd"]: 149 | classes.append(class_id) 150 | masks.append(pan_seg_gt == segment_info["id"]) 151 | 152 | classes = np.array(classes) 153 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 154 | if len(masks) == 0: 155 | # Some image does not have annotation (all ignored) 156 | instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) 157 | else: 158 | masks = BitMasks( 159 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 160 | ) 161 | instances.gt_masks = masks.tensor 162 | 163 | dataset_dict["instances"] = instances 164 | 165 | return dataset_dict 166 | -------------------------------------------------------------------------------- /mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import copy 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.data import MetadataCatalog 11 | from detectron2.data import detection_utils as utils 12 | from detectron2.data import transforms as T 13 | from detectron2.projects.point_rend import ColorAugSSDTransform 14 | from detectron2.structures import BitMasks, Instances 15 | 16 | __all__ = ["MaskFormerSemanticDatasetMapper"] 17 | 18 | 19 | class MaskFormerSemanticDatasetMapper: 20 | """ 21 | A callable which takes a dataset dict in Detectron2 Dataset format, 22 | and map it into a format used by MaskFormer for semantic segmentation. 23 | 24 | The callable currently does the following: 25 | 26 | 1. Read the image from "file_name" 27 | 2. Applies geometric transforms to the image and annotation 28 | 3. Find and applies suitable cropping to the image and annotation 29 | 4. Prepare image and annotation to Tensors 30 | """ 31 | 32 | @configurable 33 | def __init__( 34 | self, 35 | is_train=True, 36 | *, 37 | augmentations, 38 | image_format, 39 | ignore_label, 40 | size_divisibility, 41 | ): 42 | """ 43 | NOTE: this interface is experimental. 44 | Args: 45 | is_train: for training or inference 46 | augmentations: a list of augmentations or deterministic transforms to apply 47 | image_format: an image format supported by :func:`detection_utils.read_image`. 
48 | ignore_label: the label that is ignored to evaluation 49 | size_divisibility: pad image size to be divisible by this value 50 | """ 51 | self.is_train = is_train 52 | self.tfm_gens = augmentations 53 | self.img_format = image_format 54 | self.ignore_label = ignore_label 55 | self.size_divisibility = size_divisibility 56 | 57 | logger = logging.getLogger(__name__) 58 | mode = "training" if is_train else "inference" 59 | logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}") 60 | 61 | @classmethod 62 | def from_config(cls, cfg, is_train=True): 63 | # Build augmentation 64 | augs = [ 65 | T.ResizeShortestEdge( 66 | cfg.INPUT.MIN_SIZE_TRAIN, 67 | cfg.INPUT.MAX_SIZE_TRAIN, 68 | cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING, 69 | ) 70 | ] 71 | if cfg.INPUT.CROP.ENABLED: 72 | augs.append( 73 | T.RandomCrop_CategoryAreaConstraint( 74 | cfg.INPUT.CROP.TYPE, 75 | cfg.INPUT.CROP.SIZE, 76 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA, 77 | cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 78 | ) 79 | ) 80 | if cfg.INPUT.COLOR_AUG_SSD: 81 | augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT)) 82 | augs.append(T.RandomFlip()) 83 | 84 | # Assume always applies to the training set. 85 | dataset_names = cfg.DATASETS.TRAIN 86 | meta = MetadataCatalog.get(dataset_names[0]) 87 | ignore_label = meta.ignore_label 88 | 89 | ret = { 90 | "is_train": is_train, 91 | "augmentations": augs, 92 | "image_format": cfg.INPUT.FORMAT, 93 | "ignore_label": ignore_label, 94 | "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY, 95 | } 96 | return ret 97 | 98 | def __call__(self, dataset_dict): 99 | """ 100 | Args: 101 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 102 | 103 | Returns: 104 | dict: a format that builtin models in detectron2 accept 105 | """ 106 | assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!" 107 | 108 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 109 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 110 | utils.check_image_size(dataset_dict, image) 111 | 112 | if "sem_seg_file_name" in dataset_dict: 113 | # PyTorch transformation not implemented for uint16, so converting it to double first 114 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double") 115 | else: 116 | sem_seg_gt = None 117 | 118 | if sem_seg_gt is None: 119 | raise ValueError( 120 | "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( 121 | dataset_dict["file_name"] 122 | ) 123 | ) 124 | 125 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 126 | aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) 127 | image = aug_input.image 128 | sem_seg_gt = aug_input.sem_seg 129 | 130 | # Pad image and segmentation label here! 
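        # Note on the padding below: F.pad takes (left, right, top, bottom) for the last two
        # dimensions, so padding is added on the right/bottom only. The image is padded with the
        # constant 128 and the label map with `ignore_label`, so padded pixels never become part
        # of a ground-truth category mask.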
131 | image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 132 | if sem_seg_gt is not None: 133 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 134 | 135 | if self.size_divisibility > 0: 136 | image_size = (image.shape[-2], image.shape[-1]) 137 | padding_size = [ 138 | 0, 139 | self.size_divisibility - image_size[1], 140 | 0, 141 | self.size_divisibility - image_size[0], 142 | ] 143 | image = F.pad(image, padding_size, value=128).contiguous() 144 | if sem_seg_gt is not None: 145 | sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() 146 | 147 | image_shape = (image.shape[-2], image.shape[-1]) # h, w 148 | 149 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 150 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 151 | # Therefore it's important to use torch.Tensor. 152 | dataset_dict["image"] = image 153 | 154 | if sem_seg_gt is not None: 155 | dataset_dict["sem_seg"] = sem_seg_gt.long() 156 | 157 | if "annotations" in dataset_dict: 158 | raise ValueError("Semantic segmentation dataset should not have 'annotations'.") 159 | 160 | # Prepare per-category binary masks 161 | if sem_seg_gt is not None: 162 | sem_seg_gt = sem_seg_gt.numpy() 163 | instances = Instances(image_shape) 164 | classes = np.unique(sem_seg_gt) 165 | # remove ignored region 166 | classes = classes[classes != self.ignore_label] 167 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 168 | 169 | masks = [] 170 | for class_id in classes: 171 | masks.append(sem_seg_gt == class_id) 172 | 173 | if len(masks) == 0: 174 | # Some image does not have annotation (all ignored) 175 | instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])) 176 | else: 177 | masks = BitMasks( 178 | torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) 179 | ) 180 | instances.gt_masks = masks.tensor 181 | 182 | dataset_dict["instances"] = instances 183 | 184 | return dataset_dict 185 | -------------------------------------------------------------------------------- /mask_former/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . import ( 3 | register_ade20k_full, 4 | register_ade20k_panoptic, 5 | register_coco_stuff_10k, 6 | register_mapillary_vistas, 7 | ) 8 | -------------------------------------------------------------------------------- /mask_former/mask_former_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from detectron2.config import configurable 9 | from detectron2.data import MetadataCatalog 10 | from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head 11 | from detectron2.modeling.backbone import Backbone 12 | from detectron2.modeling.postprocessing import sem_seg_postprocess 13 | from detectron2.structures import ImageList 14 | 15 | from .modeling.criterion import SetCriterion 16 | from .modeling.matcher import HungarianMatcher 17 | 18 | 19 | @META_ARCH_REGISTRY.register() 20 | class MaskFormer(nn.Module): 21 | """ 22 | Main class for mask classification semantic segmentation architectures. 
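    The model predicts a fixed set of ``num_queries`` (class distribution, mask) pairs.
    At inference time the per-pixel semantic scores are recovered by combining the two,
    roughly as follows (see :meth:`semantic_inference` for the exact code)::

        mask_cls = F.softmax(pred_logits, dim=-1)[..., :-1]   # drop the no-object class
        mask_pred = pred_masks.sigmoid()
        sem_seg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)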
23 | """ 24 | 25 | @configurable 26 | def __init__( 27 | self, 28 | *, 29 | backbone: Backbone, 30 | sem_seg_head: nn.Module, 31 | criterion: nn.Module, 32 | num_queries: int, 33 | panoptic_on: bool, 34 | object_mask_threshold: float, 35 | overlap_threshold: float, 36 | metadata, 37 | size_divisibility: int, 38 | sem_seg_postprocess_before_inference: bool, 39 | pixel_mean: Tuple[float], 40 | pixel_std: Tuple[float], 41 | ): 42 | """ 43 | Args: 44 | backbone: a backbone module, must follow detectron2's backbone interface 45 | sem_seg_head: a module that predicts semantic segmentation from backbone features 46 | criterion: a module that defines the loss 47 | num_queries: int, number of queries 48 | panoptic_on: bool, whether to output panoptic segmentation prediction 49 | object_mask_threshold: float, threshold to filter query based on classification score 50 | for panoptic segmentation inference 51 | overlap_threshold: overlap threshold used in general inference for panoptic segmentation 52 | metadata: dataset meta, get `thing` and `stuff` category names for panoptic 53 | segmentation inference 54 | size_divisibility: Some backbones require the input height and width to be divisible by a 55 | specific integer. We can use this to override such requirement. 56 | sem_seg_postprocess_before_inference: whether to resize the prediction back 57 | to original input size before semantic segmentation inference or after. 58 | For high-resolution dataset like Mapillary, resizing predictions before 59 | inference will cause OOM error. 60 | pixel_mean, pixel_std: list or tuple with #channels element, representing 61 | the per-channel mean and std to be used to normalize the input image 62 | """ 63 | super().__init__() 64 | self.backbone = backbone 65 | self.sem_seg_head = sem_seg_head 66 | self.criterion = criterion 67 | self.num_queries = num_queries 68 | self.overlap_threshold = overlap_threshold 69 | self.panoptic_on = panoptic_on 70 | self.object_mask_threshold = object_mask_threshold 71 | self.metadata = metadata 72 | if size_divisibility < 0: 73 | # use backbone size_divisibility if not set 74 | size_divisibility = self.backbone.size_divisibility 75 | self.size_divisibility = size_divisibility 76 | self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference 77 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 78 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 79 | 80 | @classmethod 81 | def from_config(cls, cfg): 82 | backbone = build_backbone(cfg) 83 | sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) 84 | 85 | # Loss parameters: 86 | deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 87 | no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT 88 | dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT 89 | mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT 90 | 91 | # building criterion 92 | matcher = HungarianMatcher( 93 | cost_class=1, 94 | cost_mask=mask_weight, 95 | cost_dice=dice_weight, 96 | ) 97 | 98 | weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight} 99 | if deep_supervision: 100 | dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS 101 | aux_weight_dict = {} 102 | for i in range(dec_layers - 1): 103 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 104 | weight_dict.update(aux_weight_dict) 105 | 106 | losses = ["labels", "masks"] 107 | 108 | criterion = SetCriterion( 109 | sem_seg_head.num_classes, 110 | matcher=matcher, 111 | 
weight_dict=weight_dict, 112 | eos_coef=no_object_weight, 113 | losses=losses, 114 | ) 115 | 116 | return { 117 | "backbone": backbone, 118 | "sem_seg_head": sem_seg_head, 119 | "criterion": criterion, 120 | "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, 121 | "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON, 122 | "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD, 123 | "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD, 124 | "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), 125 | "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, 126 | "sem_seg_postprocess_before_inference": ( 127 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE 128 | or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON 129 | ), 130 | "pixel_mean": cfg.MODEL.PIXEL_MEAN, 131 | "pixel_std": cfg.MODEL.PIXEL_STD, 132 | } 133 | 134 | @property 135 | def device(self): 136 | return self.pixel_mean.device 137 | 138 | def forward(self, batched_inputs): 139 | """ 140 | Args: 141 | batched_inputs: a list, batched outputs of :class:`DatasetMapper`. 142 | Each item in the list contains the inputs for one image. 143 | For now, each item in the list is a dict that contains: 144 | * "image": Tensor, image in (C, H, W) format. 145 | * "instances": per-region ground truth 146 | * Other information that's included in the original dicts, such as: 147 | "height", "width" (int): the output resolution of the model (may be different 148 | from input resolution), used in inference. 149 | Returns: 150 | list[dict]: 151 | each dict has the results for one image. The dict contains the following keys: 152 | 153 | * "sem_seg": 154 | A Tensor that represents the 155 | per-pixel segmentation prediced by the head. 156 | The prediction has shape KxHxW that represents the logits of 157 | each class for each pixel. 158 | * "panoptic_seg": 159 | A tuple that represent panoptic output 160 | panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. 161 | segments_info (list[dict]): Describe each segment in `panoptic_seg`. 162 | Each dict contains keys "id", "category_id", "isthing". 
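                Note: the "panoptic_seg" entry is only produced when panoptic inference is
                enabled via ``MODEL.MASK_FORMER.TEST.PANOPTIC_ON``.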
163 | """ 164 | images = [x["image"].to(self.device) for x in batched_inputs] 165 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 166 | images = ImageList.from_tensors(images, self.size_divisibility) 167 | 168 | features = self.backbone(images.tensor) 169 | outputs = self.sem_seg_head(features) 170 | 171 | if self.training: 172 | # mask classification target 173 | if "instances" in batched_inputs[0]: 174 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 175 | targets = self.prepare_targets(gt_instances, images) 176 | else: 177 | targets = None 178 | 179 | # bipartite matching-based loss 180 | losses = self.criterion(outputs, targets) 181 | 182 | for k in list(losses.keys()): 183 | if k in self.criterion.weight_dict: 184 | losses[k] *= self.criterion.weight_dict[k] 185 | else: 186 | # remove this loss if not specified in `weight_dict` 187 | losses.pop(k) 188 | 189 | return losses 190 | else: 191 | mask_cls_results = outputs["pred_logits"] 192 | mask_pred_results = outputs["pred_masks"] 193 | # upsample masks 194 | mask_pred_results = F.interpolate( 195 | mask_pred_results, 196 | size=(images.tensor.shape[-2], images.tensor.shape[-1]), 197 | mode="bilinear", 198 | align_corners=False, 199 | ) 200 | 201 | processed_results = [] 202 | for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( 203 | mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes 204 | ): 205 | height = input_per_image.get("height", image_size[0]) 206 | width = input_per_image.get("width", image_size[1]) 207 | 208 | if self.sem_seg_postprocess_before_inference: 209 | mask_pred_result = sem_seg_postprocess( 210 | mask_pred_result, image_size, height, width 211 | ) 212 | 213 | # semantic segmentation inference 214 | r = self.semantic_inference(mask_cls_result, mask_pred_result) 215 | if not self.sem_seg_postprocess_before_inference: 216 | r = sem_seg_postprocess(r, image_size, height, width) 217 | processed_results.append({"sem_seg": r}) 218 | 219 | # panoptic segmentation inference 220 | if self.panoptic_on: 221 | panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result) 222 | processed_results[-1]["panoptic_seg"] = panoptic_r 223 | 224 | return processed_results 225 | 226 | def prepare_targets(self, targets, images): 227 | h, w = images.tensor.shape[-2:] 228 | new_targets = [] 229 | for targets_per_image in targets: 230 | # pad gt 231 | gt_masks = targets_per_image.gt_masks 232 | padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device) 233 | padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks 234 | new_targets.append( 235 | { 236 | "labels": targets_per_image.gt_classes, 237 | "masks": padded_masks, 238 | } 239 | ) 240 | return new_targets 241 | 242 | def semantic_inference(self, mask_cls, mask_pred): 243 | mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] 244 | mask_pred = mask_pred.sigmoid() 245 | semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) 246 | return semseg 247 | 248 | def panoptic_inference(self, mask_cls, mask_pred): 249 | scores, labels = F.softmax(mask_cls, dim=-1).max(-1) 250 | mask_pred = mask_pred.sigmoid() 251 | 252 | keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) 253 | cur_scores = scores[keep] 254 | cur_classes = labels[keep] 255 | cur_masks = mask_pred[keep] 256 | cur_mask_cls = mask_cls[keep] 257 | cur_mask_cls = cur_mask_cls[:, :-1] 258 | 259 | cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks 260 
| 261 | h, w = cur_masks.shape[-2:] 262 | panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) 263 | segments_info = [] 264 | 265 | current_segment_id = 0 266 | 267 | if cur_masks.shape[0] == 0: 268 | # We didn't detect any mask :( 269 | return panoptic_seg, segments_info 270 | else: 271 | # take argmax 272 | cur_mask_ids = cur_prob_masks.argmax(0) 273 | stuff_memory_list = {} 274 | for k in range(cur_classes.shape[0]): 275 | pred_class = cur_classes[k].item() 276 | isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values() 277 | mask = cur_mask_ids == k 278 | mask_area = mask.sum().item() 279 | original_area = (cur_masks[k] >= 0.5).sum().item() 280 | 281 | if mask_area > 0 and original_area > 0: 282 | if mask_area / original_area < self.overlap_threshold: 283 | continue 284 | 285 | # merge stuff regions 286 | if not isthing: 287 | if int(pred_class) in stuff_memory_list.keys(): 288 | panoptic_seg[mask] = stuff_memory_list[int(pred_class)] 289 | continue 290 | else: 291 | stuff_memory_list[int(pred_class)] = current_segment_id + 1 292 | 293 | current_segment_id += 1 294 | panoptic_seg[mask] = current_segment_id 295 | 296 | segments_info.append( 297 | { 298 | "id": current_segment_id, 299 | "isthing": bool(isthing), 300 | "category_id": int(pred_class), 301 | } 302 | ) 303 | 304 | return panoptic_seg, segments_info 305 | -------------------------------------------------------------------------------- /mask_former/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .backbone.swin import D2SwinTransformer 3 | from .heads.mask_former_head import MaskFormerHead 4 | from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead 5 | from .heads.pixel_decoder import BasePixelDecoder 6 | -------------------------------------------------------------------------------- /mask_former/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/modeling/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | """ 4 | MaskFormer criterion. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from detectron2.utils.comm import get_world_size 11 | 12 | from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list 13 | 14 | 15 | def dice_loss(inputs, targets, num_masks): 16 | """ 17 | Compute the DICE loss, similar to generalized IOU for masks 18 | Args: 19 | inputs: A float tensor of arbitrary shape. 20 | The predictions for each example. 21 | targets: A float tensor with the same shape as inputs. Stores the binary 22 | classification label for each element in inputs 23 | (0 for the negative class and 1 for the positive class). 
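        num_masks: the number of masks used to normalize the loss.
    Returns:
        The dice loss: per mask ``1 - (2 * (p * t).sum() + 1) / (p.sum() + t.sum() + 1)``
        with ``p = inputs.sigmoid()`` flattened over pixels, summed over all masks and
        divided by ``num_masks``.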
24 | """ 25 | inputs = inputs.sigmoid() 26 | inputs = inputs.flatten(1) 27 | numerator = 2 * (inputs * targets).sum(-1) 28 | denominator = inputs.sum(-1) + targets.sum(-1) 29 | loss = 1 - (numerator + 1) / (denominator + 1) 30 | return loss.sum() / num_masks 31 | 32 | 33 | def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2): 34 | """ 35 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 36 | Args: 37 | inputs: A float tensor of arbitrary shape. 38 | The predictions for each example. 39 | targets: A float tensor with the same shape as inputs. Stores the binary 40 | classification label for each element in inputs 41 | (0 for the negative class and 1 for the positive class). 42 | alpha: (optional) Weighting factor in range (0,1) to balance 43 | positive vs negative examples. Default = -1 (no weighting). 44 | gamma: Exponent of the modulating factor (1 - p_t) to 45 | balance easy vs hard examples. 46 | Returns: 47 | Loss tensor 48 | """ 49 | prob = inputs.sigmoid() 50 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 51 | p_t = prob * targets + (1 - prob) * (1 - targets) 52 | loss = ce_loss * ((1 - p_t) ** gamma) 53 | 54 | if alpha >= 0: 55 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 56 | loss = alpha_t * loss 57 | 58 | return loss.mean(1).sum() / num_masks 59 | 60 | 61 | class SetCriterion(nn.Module): 62 | """This class computes the loss for DETR. 63 | The process happens in two steps: 64 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 65 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 66 | """ 67 | 68 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): 69 | """Create the criterion. 70 | Parameters: 71 | num_classes: number of object categories, omitting the special no-object category 72 | matcher: module able to compute a matching between targets and proposals 73 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 74 | eos_coef: relative classification weight applied to the no-object category 75 | losses: list of all the losses to be applied. See get_loss for list of available losses. 76 | """ 77 | super().__init__() 78 | self.num_classes = num_classes 79 | self.matcher = matcher 80 | self.weight_dict = weight_dict 81 | self.eos_coef = eos_coef 82 | self.losses = losses 83 | empty_weight = torch.ones(self.num_classes + 1) 84 | empty_weight[-1] = self.eos_coef 85 | self.register_buffer("empty_weight", empty_weight) 86 | 87 | def loss_labels(self, outputs, targets, indices, num_masks): 88 | """Classification loss (NLL) 89 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 90 | """ 91 | assert "pred_logits" in outputs 92 | src_logits = outputs["pred_logits"] 93 | 94 | idx = self._get_src_permutation_idx(indices) 95 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 96 | target_classes = torch.full( 97 | src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device 98 | ) 99 | target_classes[idx] = target_classes_o 100 | 101 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) 102 | losses = {"loss_ce": loss_ce} 103 | return losses 104 | 105 | def loss_masks(self, outputs, targets, indices, num_masks): 106 | """Compute the losses related to the masks: the focal loss and the dice loss. 
107 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 108 | """ 109 | assert "pred_masks" in outputs 110 | 111 | src_idx = self._get_src_permutation_idx(indices) 112 | tgt_idx = self._get_tgt_permutation_idx(indices) 113 | src_masks = outputs["pred_masks"] 114 | src_masks = src_masks[src_idx] 115 | masks = [t["masks"] for t in targets] 116 | # TODO use valid to mask invalid areas due to padding in loss 117 | target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() 118 | target_masks = target_masks.to(src_masks) 119 | target_masks = target_masks[tgt_idx] 120 | 121 | # upsample predictions to the target size 122 | src_masks = F.interpolate( 123 | src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False 124 | ) 125 | src_masks = src_masks[:, 0].flatten(1) 126 | 127 | target_masks = target_masks.flatten(1) 128 | target_masks = target_masks.view(src_masks.shape) 129 | losses = { 130 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks), 131 | "loss_dice": dice_loss(src_masks, target_masks, num_masks), 132 | } 133 | return losses 134 | 135 | def _get_src_permutation_idx(self, indices): 136 | # permute predictions following indices 137 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 138 | src_idx = torch.cat([src for (src, _) in indices]) 139 | return batch_idx, src_idx 140 | 141 | def _get_tgt_permutation_idx(self, indices): 142 | # permute targets following indices 143 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 144 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 145 | return batch_idx, tgt_idx 146 | 147 | def get_loss(self, loss, outputs, targets, indices, num_masks): 148 | loss_map = {"labels": self.loss_labels, "masks": self.loss_masks} 149 | assert loss in loss_map, f"do you really want to compute {loss} loss?" 150 | return loss_map[loss](outputs, targets, indices, num_masks) 151 | 152 | def forward(self, outputs, targets): 153 | """This performs the loss computation. 154 | Parameters: 155 | outputs: dict of tensors, see the output specification of the model for the format 156 | targets: list of dicts, such that len(targets) == batch_size. 157 | The expected keys in each dict depend on the losses applied, see each loss's doc 158 | """ 159 | outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} 160 | 161 | # Retrieve the matching between the outputs of the last layer and the targets 162 | indices = self.matcher(outputs_without_aux, targets) 163 | 164 | # Compute the average number of target masks across all nodes, for normalization purposes 165 | num_masks = sum(len(t["labels"]) for t in targets) 166 | num_masks = torch.as_tensor( 167 | [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device 168 | ) 169 | if is_dist_avail_and_initialized(): 170 | torch.distributed.all_reduce(num_masks) 171 | num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() 172 | 173 | # Compute all the requested losses 174 | losses = {} 175 | for loss in self.losses: 176 | losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) 177 | 178 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
179 | if "aux_outputs" in outputs: 180 | for i, aux_outputs in enumerate(outputs["aux_outputs"]): 181 | indices = self.matcher(aux_outputs, targets) 182 | for loss in self.losses: 183 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks) 184 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 185 | losses.update(l_dict) 186 | 187 | return losses 188 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/mask_former_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | from copy import deepcopy 4 | from typing import Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import fvcore.nn.weight_init as weight_init 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from detectron2.config import configurable 11 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 12 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 13 | 14 | from ..transformer.transformer_predictor import TransformerPredictor 15 | from .pixel_decoder import build_pixel_decoder 16 | 17 | 18 | @SEM_SEG_HEADS_REGISTRY.register() 19 | class MaskFormerHead(nn.Module): 20 | 21 | _version = 2 22 | 23 | def _load_from_state_dict( 24 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 25 | ): 26 | version = local_metadata.get("version", None) 27 | if version is None or version < 2: 28 | # Do not warn if train from scratch 29 | scratch = True 30 | logger = logging.getLogger(__name__) 31 | for k in list(state_dict.keys()): 32 | newk = k 33 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 34 | newk = k.replace(prefix, prefix + "pixel_decoder.") 35 | # logger.debug(f"{k} ==> {newk}") 36 | if newk != k: 37 | state_dict[newk] = state_dict[k] 38 | del state_dict[k] 39 | scratch = False 40 | 41 | if not scratch: 42 | logger.warning( 43 | f"Weight format of {self.__class__.__name__} have changed! " 44 | "Please upgrade your models. Applying automatic conversion now ..." 45 | ) 46 | 47 | @configurable 48 | def __init__( 49 | self, 50 | input_shape: Dict[str, ShapeSpec], 51 | *, 52 | num_classes: int, 53 | pixel_decoder: nn.Module, 54 | loss_weight: float = 1.0, 55 | ignore_value: int = -1, 56 | # extra parameters 57 | transformer_predictor: nn.Module, 58 | transformer_in_feature: str, 59 | ): 60 | """ 61 | NOTE: this interface is experimental. 62 | Args: 63 | input_shape: shapes (channels and stride) of the input features 64 | num_classes: number of classes to predict 65 | pixel_decoder: the pixel decoder module 66 | loss_weight: loss weight 67 | ignore_value: category id to be ignored during training. 
68 | transformer_predictor: the transformer decoder that makes prediction 69 | transformer_in_feature: input feature name to the transformer_predictor 70 | """ 71 | super().__init__() 72 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 73 | self.in_features = [k for k, v in input_shape] 74 | feature_strides = [v.stride for k, v in input_shape] 75 | feature_channels = [v.channels for k, v in input_shape] 76 | 77 | self.ignore_value = ignore_value 78 | self.common_stride = 4 79 | self.loss_weight = loss_weight 80 | 81 | self.pixel_decoder = pixel_decoder 82 | self.predictor = transformer_predictor 83 | self.transformer_in_feature = transformer_in_feature 84 | 85 | self.num_classes = num_classes 86 | 87 | @classmethod 88 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 89 | return { 90 | "input_shape": { 91 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 92 | }, 93 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 94 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 95 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 96 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 97 | "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, 98 | "transformer_predictor": TransformerPredictor( 99 | cfg, 100 | cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 101 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" 102 | else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, 103 | mask_classification=True, 104 | ), 105 | } 106 | 107 | def forward(self, features): 108 | return self.layers(features) 109 | 110 | def layers(self, features): 111 | mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) 112 | if self.transformer_in_feature == "transformer_encoder": 113 | assert ( 114 | transformer_encoder_features is not None 115 | ), "Please use the TransformerEncoderPixelDecoder." 116 | predictions = self.predictor(transformer_encoder_features, mask_features) 117 | else: 118 | predictions = self.predictor(features[self.transformer_in_feature], mask_features) 119 | return predictions 120 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/per_pixel_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import logging 3 | from typing import Callable, Dict, List, Optional, Tuple, Union 4 | 5 | import fvcore.nn.weight_init as weight_init 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 11 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 12 | 13 | from ..transformer.transformer_predictor import TransformerPredictor 14 | from .pixel_decoder import build_pixel_decoder 15 | 16 | 17 | @SEM_SEG_HEADS_REGISTRY.register() 18 | class PerPixelBaselineHead(nn.Module): 19 | 20 | _version = 2 21 | 22 | def _load_from_state_dict( 23 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 24 | ): 25 | version = local_metadata.get("version", None) 26 | if version is None or version < 2: 27 | logger = logging.getLogger(__name__) 28 | # Do not warn if train from scratch 29 | scratch = True 30 | logger = logging.getLogger(__name__) 31 | for k in list(state_dict.keys()): 32 | newk = k 33 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 34 | newk = k.replace(prefix, prefix + "pixel_decoder.") 35 | # logger.warning(f"{k} ==> {newk}") 36 | if newk != k: 37 | state_dict[newk] = state_dict[k] 38 | del state_dict[k] 39 | scratch = False 40 | 41 | if not scratch: 42 | logger.warning( 43 | f"Weight format of {self.__class__.__name__} have changed! " 44 | "Please upgrade your models. Applying automatic conversion now ..." 45 | ) 46 | 47 | @configurable 48 | def __init__( 49 | self, 50 | input_shape: Dict[str, ShapeSpec], 51 | *, 52 | num_classes: int, 53 | pixel_decoder: nn.Module, 54 | loss_weight: float = 1.0, 55 | ignore_value: int = -1, 56 | ): 57 | """ 58 | NOTE: this interface is experimental. 59 | Args: 60 | input_shape: shapes (channels and stride) of the input features 61 | num_classes: number of classes to predict 62 | pixel_decoder: the pixel decoder module 63 | loss_weight: loss weight 64 | ignore_value: category id to be ignored during training. 
65 | """ 66 | super().__init__() 67 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 68 | self.in_features = [k for k, v in input_shape] 69 | feature_strides = [v.stride for k, v in input_shape] 70 | feature_channels = [v.channels for k, v in input_shape] 71 | 72 | self.ignore_value = ignore_value 73 | self.common_stride = 4 74 | self.loss_weight = loss_weight 75 | 76 | self.pixel_decoder = pixel_decoder 77 | self.predictor = Conv2d( 78 | self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0 79 | ) 80 | weight_init.c2_msra_fill(self.predictor) 81 | 82 | @classmethod 83 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 84 | return { 85 | "input_shape": { 86 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 87 | }, 88 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 89 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 90 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 91 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 92 | } 93 | 94 | def forward(self, features, targets=None): 95 | """ 96 | Returns: 97 | In training, returns (None, dict of losses) 98 | In inference, returns (CxHxW logits, {}) 99 | """ 100 | x = self.layers(features) 101 | if self.training: 102 | return None, self.losses(x, targets) 103 | else: 104 | x = F.interpolate( 105 | x, scale_factor=self.common_stride, mode="bilinear", align_corners=False 106 | ) 107 | return x, {} 108 | 109 | def layers(self, features): 110 | x, _ = self.pixel_decoder.forward_features(features) 111 | x = self.predictor(x) 112 | return x 113 | 114 | def losses(self, predictions, targets): 115 | predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 116 | predictions = F.interpolate( 117 | predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False 118 | ) 119 | loss = F.cross_entropy( 120 | predictions, targets, reduction="mean", ignore_index=self.ignore_value 121 | ) 122 | losses = {"loss_sem_seg": loss * self.loss_weight} 123 | return losses 124 | 125 | 126 | @SEM_SEG_HEADS_REGISTRY.register() 127 | class PerPixelBaselinePlusHead(PerPixelBaselineHead): 128 | def _load_from_state_dict( 129 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 130 | ): 131 | version = local_metadata.get("version", None) 132 | if version is None or version < 2: 133 | # Do not warn if train from scratch 134 | scratch = True 135 | logger = logging.getLogger(__name__) 136 | for k in list(state_dict.keys()): 137 | newk = k 138 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 139 | newk = k.replace(prefix, prefix + "pixel_decoder.") 140 | logger.debug(f"{k} ==> {newk}") 141 | if newk != k: 142 | state_dict[newk] = state_dict[k] 143 | del state_dict[k] 144 | scratch = False 145 | 146 | if not scratch: 147 | logger.warning( 148 | f"Weight format of {self.__class__.__name__} have changed! " 149 | "Please upgrade your models. Applying automatic conversion now ..." 150 | ) 151 | 152 | @configurable 153 | def __init__( 154 | self, 155 | input_shape: Dict[str, ShapeSpec], 156 | *, 157 | # extra parameters 158 | transformer_predictor: nn.Module, 159 | transformer_in_feature: str, 160 | deep_supervision: bool, 161 | # inherit parameters 162 | num_classes: int, 163 | pixel_decoder: nn.Module, 164 | loss_weight: float = 1.0, 165 | ignore_value: int = -1, 166 | ): 167 | """ 168 | NOTE: this interface is experimental. 
169 | Args: 170 | input_shape: shapes (channels and stride) of the input features 171 | transformer_predictor: the transformer decoder that makes prediction 172 | transformer_in_feature: input feature name to the transformer_predictor 173 | deep_supervision: whether or not to add supervision to the output of 174 | every transformer decoder layer 175 | num_classes: number of classes to predict 176 | pixel_decoder: the pixel decoder module 177 | loss_weight: loss weight 178 | ignore_value: category id to be ignored during training. 179 | """ 180 | super().__init__( 181 | input_shape, 182 | num_classes=num_classes, 183 | pixel_decoder=pixel_decoder, 184 | loss_weight=loss_weight, 185 | ignore_value=ignore_value, 186 | ) 187 | 188 | del self.predictor 189 | 190 | self.predictor = transformer_predictor 191 | self.transformer_in_feature = transformer_in_feature 192 | self.deep_supervision = deep_supervision 193 | 194 | @classmethod 195 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 196 | ret = super().from_config(cfg, input_shape) 197 | ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE 198 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder": 199 | in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 200 | else: 201 | in_channels = input_shape[ret["transformer_in_feature"]].channels 202 | ret["transformer_predictor"] = TransformerPredictor( 203 | cfg, in_channels, mask_classification=False 204 | ) 205 | ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 206 | return ret 207 | 208 | def forward(self, features, targets=None): 209 | """ 210 | Returns: 211 | In training, returns (None, dict of losses) 212 | In inference, returns (CxHxW logits, {}) 213 | """ 214 | x, aux_outputs = self.layers(features) 215 | if self.training: 216 | if self.deep_supervision: 217 | losses = self.losses(x, targets) 218 | for i, aux_output in enumerate(aux_outputs): 219 | losses["loss_sem_seg" + f"_{i}"] = self.losses( 220 | aux_output["pred_masks"], targets 221 | )["loss_sem_seg"] 222 | return None, losses 223 | else: 224 | return None, self.losses(x, targets) 225 | else: 226 | x = F.interpolate( 227 | x, scale_factor=self.common_stride, mode="bilinear", align_corners=False 228 | ) 229 | return x, {} 230 | 231 | def layers(self, features): 232 | mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) 233 | if self.transformer_in_feature == "transformer_encoder": 234 | assert ( 235 | transformer_encoder_features is not None 236 | ), "Please use the TransformerEncoderPixelDecoder." 237 | predictions = self.predictor(transformer_encoder_features, mask_features) 238 | else: 239 | predictions = self.predictor(features[self.transformer_in_feature], mask_features) 240 | if self.deep_supervision: 241 | return predictions["pred_masks"], predictions["aux_outputs"] 242 | else: 243 | return predictions["pred_masks"], None 244 | -------------------------------------------------------------------------------- /mask_former/modeling/heads/pixel_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import logging 3 | from typing import Callable, Dict, List, Optional, Tuple, Union 4 | 5 | import fvcore.nn.weight_init as weight_init 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from detectron2.config import configurable 10 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 11 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 12 | 13 | from ..transformer.position_encoding import PositionEmbeddingSine 14 | from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer 15 | 16 | 17 | def build_pixel_decoder(cfg, input_shape): 18 | """ 19 | Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. 20 | """ 21 | name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME 22 | model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape) 23 | forward_features = getattr(model, "forward_features", None) 24 | if not callable(forward_features): 25 | raise ValueError( 26 | "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " 27 | f"Please implement forward_features for {name} to only return mask features." 28 | ) 29 | return model 30 | 31 | 32 | @SEM_SEG_HEADS_REGISTRY.register() 33 | class BasePixelDecoder(nn.Module): 34 | @configurable 35 | def __init__( 36 | self, 37 | input_shape: Dict[str, ShapeSpec], 38 | *, 39 | conv_dim: int, 40 | mask_dim: int, 41 | norm: Optional[Union[str, Callable]] = None, 42 | ): 43 | """ 44 | NOTE: this interface is experimental. 45 | Args: 46 | input_shape: shapes (channels and stride) of the input features 47 | conv_dims: number of output channels for the intermediate conv layers. 48 | mask_dim: number of output channels for the final conv layer. 49 | norm (str or callable): normalization for all conv layers 50 | """ 51 | super().__init__() 52 | 53 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 54 | self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" 55 | feature_channels = [v.channels for k, v in input_shape] 56 | 57 | lateral_convs = [] 58 | output_convs = [] 59 | 60 | use_bias = norm == "" 61 | for idx, in_channels in enumerate(feature_channels): 62 | if idx == len(self.in_features) - 1: 63 | output_norm = get_norm(norm, conv_dim) 64 | output_conv = Conv2d( 65 | in_channels, 66 | conv_dim, 67 | kernel_size=3, 68 | stride=1, 69 | padding=1, 70 | bias=use_bias, 71 | norm=output_norm, 72 | activation=F.relu, 73 | ) 74 | weight_init.c2_xavier_fill(output_conv) 75 | self.add_module("layer_{}".format(idx + 1), output_conv) 76 | 77 | lateral_convs.append(None) 78 | output_convs.append(output_conv) 79 | else: 80 | lateral_norm = get_norm(norm, conv_dim) 81 | output_norm = get_norm(norm, conv_dim) 82 | 83 | lateral_conv = Conv2d( 84 | in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm 85 | ) 86 | output_conv = Conv2d( 87 | conv_dim, 88 | conv_dim, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1, 92 | bias=use_bias, 93 | norm=output_norm, 94 | activation=F.relu, 95 | ) 96 | weight_init.c2_xavier_fill(lateral_conv) 97 | weight_init.c2_xavier_fill(output_conv) 98 | self.add_module("adapter_{}".format(idx + 1), lateral_conv) 99 | self.add_module("layer_{}".format(idx + 1), output_conv) 100 | 101 | lateral_convs.append(lateral_conv) 102 | output_convs.append(output_conv) 103 | # Place convs into top-down order (from low to high resolution) 104 | # to make the top-down computation in forward clearer. 
105 | self.lateral_convs = lateral_convs[::-1] 106 | self.output_convs = output_convs[::-1] 107 | 108 | self.mask_dim = mask_dim 109 | self.mask_features = Conv2d( 110 | conv_dim, 111 | mask_dim, 112 | kernel_size=3, 113 | stride=1, 114 | padding=1, 115 | ) 116 | weight_init.c2_xavier_fill(self.mask_features) 117 | 118 | @classmethod 119 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 120 | ret = {} 121 | ret["input_shape"] = { 122 | k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 123 | } 124 | ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 125 | ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM 126 | ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM 127 | return ret 128 | 129 | def forward_features(self, features): 130 | # Reverse feature maps into top-down order (from low to high resolution) 131 | for idx, f in enumerate(self.in_features[::-1]): 132 | x = features[f] 133 | lateral_conv = self.lateral_convs[idx] 134 | output_conv = self.output_convs[idx] 135 | if lateral_conv is None: 136 | y = output_conv(x) 137 | else: 138 | cur_fpn = lateral_conv(x) 139 | # Following FPN implementation, we use nearest upsampling here 140 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") 141 | y = output_conv(y) 142 | return self.mask_features(y), None 143 | 144 | def forward(self, features, targets=None): 145 | logger = logging.getLogger(__name__) 146 | logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") 147 | return self.forward_features(features) 148 | 149 | 150 | class TransformerEncoderOnly(nn.Module): 151 | def __init__( 152 | self, 153 | d_model=512, 154 | nhead=8, 155 | num_encoder_layers=6, 156 | dim_feedforward=2048, 157 | dropout=0.1, 158 | activation="relu", 159 | normalize_before=False, 160 | ): 161 | super().__init__() 162 | 163 | encoder_layer = TransformerEncoderLayer( 164 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 165 | ) 166 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 167 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 168 | 169 | self._reset_parameters() 170 | 171 | self.d_model = d_model 172 | self.nhead = nhead 173 | 174 | def _reset_parameters(self): 175 | for p in self.parameters(): 176 | if p.dim() > 1: 177 | nn.init.xavier_uniform_(p) 178 | 179 | def forward(self, src, mask, pos_embed): 180 | # flatten NxCxHxW to HWxNxC 181 | bs, c, h, w = src.shape 182 | src = src.flatten(2).permute(2, 0, 1) 183 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 184 | if mask is not None: 185 | mask = mask.flatten(1) 186 | 187 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 188 | return memory.permute(1, 2, 0).view(bs, c, h, w) 189 | 190 | 191 | @SEM_SEG_HEADS_REGISTRY.register() 192 | class TransformerEncoderPixelDecoder(BasePixelDecoder): 193 | @configurable 194 | def __init__( 195 | self, 196 | input_shape: Dict[str, ShapeSpec], 197 | *, 198 | transformer_dropout: float, 199 | transformer_nheads: int, 200 | transformer_dim_feedforward: int, 201 | transformer_enc_layers: int, 202 | transformer_pre_norm: bool, 203 | conv_dim: int, 204 | mask_dim: int, 205 | norm: Optional[Union[str, Callable]] = None, 206 | ): 207 | """ 208 | NOTE: this interface is experimental. 
209 | Args: 210 | input_shape: shapes (channels and stride) of the input features 211 | transformer_dropout: dropout probability in transformer 212 | transformer_nheads: number of heads in transformer 213 | transformer_dim_feedforward: dimension of feedforward network 214 | transformer_enc_layers: number of transformer encoder layers 215 | transformer_pre_norm: whether to use pre-layernorm or not 216 | conv_dims: number of output channels for the intermediate conv layers. 217 | mask_dim: number of output channels for the final conv layer. 218 | norm (str or callable): normalization for all conv layers 219 | """ 220 | super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm) 221 | 222 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 223 | self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" 224 | feature_strides = [v.stride for k, v in input_shape] 225 | feature_channels = [v.channels for k, v in input_shape] 226 | 227 | in_channels = feature_channels[len(self.in_features) - 1] 228 | self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) 229 | weight_init.c2_xavier_fill(self.input_proj) 230 | self.transformer = TransformerEncoderOnly( 231 | d_model=conv_dim, 232 | dropout=transformer_dropout, 233 | nhead=transformer_nheads, 234 | dim_feedforward=transformer_dim_feedforward, 235 | num_encoder_layers=transformer_enc_layers, 236 | normalize_before=transformer_pre_norm, 237 | ) 238 | N_steps = conv_dim // 2 239 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 240 | 241 | # update layer 242 | use_bias = norm == "" 243 | output_norm = get_norm(norm, conv_dim) 244 | output_conv = Conv2d( 245 | conv_dim, 246 | conv_dim, 247 | kernel_size=3, 248 | stride=1, 249 | padding=1, 250 | bias=use_bias, 251 | norm=output_norm, 252 | activation=F.relu, 253 | ) 254 | weight_init.c2_xavier_fill(output_conv) 255 | delattr(self, "layer_{}".format(len(self.in_features))) 256 | self.add_module("layer_{}".format(len(self.in_features)), output_conv) 257 | self.output_convs[0] = output_conv 258 | 259 | @classmethod 260 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 261 | ret = super().from_config(cfg, input_shape) 262 | ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT 263 | ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS 264 | ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD 265 | ret[ 266 | "transformer_enc_layers" 267 | ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config 268 | ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM 269 | return ret 270 | 271 | def forward_features(self, features): 272 | # Reverse feature maps into top-down order (from low to high resolution) 273 | for idx, f in enumerate(self.in_features[::-1]): 274 | x = features[f] 275 | lateral_conv = self.lateral_convs[idx] 276 | output_conv = self.output_convs[idx] 277 | if lateral_conv is None: 278 | transformer = self.input_proj(x) 279 | pos = self.pe_layer(x) 280 | transformer = self.transformer(transformer, None, pos) 281 | y = output_conv(transformer) 282 | # save intermediate feature as input to Transformer decoder 283 | transformer_encoder_features = transformer 284 | else: 285 | cur_fpn = lateral_conv(x) 286 | # Following FPN implementation, we use nearest upsampling here 287 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") 288 | y = output_conv(y) 289 | return self.mask_features(y), transformer_encoder_features 290 | 291 | def 
forward(self, features, targets=None): 292 | logger = logging.getLogger(__name__) 293 | logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") 294 | return self.forward_features(features) 295 | -------------------------------------------------------------------------------- /mask_former/modeling/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py 3 | """ 4 | Modules to compute the matching cost and solve the corresponding LSAP. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from scipy.optimize import linear_sum_assignment 9 | from torch import nn 10 | 11 | 12 | def batch_dice_loss(inputs, targets): 13 | """ 14 | Compute the DICE loss, similar to generalized IOU for masks 15 | Args: 16 | inputs: A float tensor of arbitrary shape. 17 | The predictions for each example. 18 | targets: A float tensor with the same shape as inputs. Stores the binary 19 | classification label for each element in inputs 20 | (0 for the negative class and 1 for the positive class). 21 | """ 22 | inputs = inputs.sigmoid() 23 | inputs = inputs.flatten(1) 24 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) 25 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] 26 | loss = 1 - (numerator + 1) / (denominator + 1) 27 | return loss 28 | 29 | 30 | def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2): 31 | """ 32 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 33 | Args: 34 | inputs: A float tensor of arbitrary shape. 35 | The predictions for each example. 36 | targets: A float tensor with the same shape as inputs. Stores the binary 37 | classification label for each element in inputs 38 | (0 for the negative class and 1 for the positive class). 39 | alpha: (optional) Weighting factor in range (0,1) to balance 40 | positive vs negative examples. Default = -1 (no weighting). 41 | gamma: Exponent of the modulating factor (1 - p_t) to 42 | balance easy vs hard examples. 43 | Returns: 44 | Loss tensor 45 | """ 46 | hw = inputs.shape[1] 47 | 48 | prob = inputs.sigmoid() 49 | focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits( 50 | inputs, torch.ones_like(inputs), reduction="none" 51 | ) 52 | focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits( 53 | inputs, torch.zeros_like(inputs), reduction="none" 54 | ) 55 | if alpha >= 0: 56 | focal_pos = focal_pos * alpha 57 | focal_neg = focal_neg * (1 - alpha) 58 | 59 | loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum( 60 | "nc,mc->nm", focal_neg, (1 - targets) 61 | ) 62 | 63 | return loss / hw 64 | 65 | 66 | class HungarianMatcher(nn.Module): 67 | """This class computes an assignment between the targets and the predictions of the network 68 | 69 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 70 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 71 | while the others are un-matched (and thus treated as non-objects). 
72 | """ 73 | 74 | def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1): 75 | """Creates the matcher 76 | 77 | Params: 78 | cost_class: This is the relative weight of the classification error in the matching cost 79 | cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost 80 | cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost 81 | """ 82 | super().__init__() 83 | self.cost_class = cost_class 84 | self.cost_mask = cost_mask 85 | self.cost_dice = cost_dice 86 | assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" 87 | 88 | @torch.no_grad() 89 | def memory_efficient_forward(self, outputs, targets): 90 | """More memory-friendly matching""" 91 | bs, num_queries = outputs["pred_logits"].shape[:2] 92 | 93 | # Work out the mask padding size 94 | masks = [v["masks"] for v in targets] 95 | h_max = max([m.shape[1] for m in masks]) 96 | w_max = max([m.shape[2] for m in masks]) 97 | 98 | indices = [] 99 | 100 | # Iterate through batch size 101 | for b in range(bs): 102 | 103 | out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes] 104 | out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] 105 | 106 | tgt_ids = targets[b]["labels"] 107 | # gt masks are already padded when preparing target 108 | tgt_mask = targets[b]["masks"].to(out_mask) 109 | 110 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 111 | # but approximate it in 1 - proba[target class]. 112 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 113 | cost_class = -out_prob[:, tgt_ids] 114 | 115 | # Downsample gt masks to save memory 116 | tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest") 117 | 118 | # Flatten spatial dimension 119 | out_mask = out_mask.flatten(1) # [batch_size * num_queries, H*W] 120 | tgt_mask = tgt_mask[:, 0].flatten(1) # [num_total_targets, H*W] 121 | 122 | # Compute the focal loss between masks 123 | cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask) 124 | 125 | # Compute the dice loss betwen masks 126 | cost_dice = batch_dice_loss(out_mask, tgt_mask) 127 | 128 | # Final cost matrix 129 | C = ( 130 | self.cost_mask * cost_mask 131 | + self.cost_class * cost_class 132 | + self.cost_dice * cost_dice 133 | ) 134 | C = C.reshape(num_queries, -1).cpu() 135 | 136 | indices.append(linear_sum_assignment(C)) 137 | return [ 138 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) 139 | for i, j in indices 140 | ] 141 | 142 | @torch.no_grad() 143 | def forward(self, outputs, targets): 144 | """Performs the matching 145 | 146 | Params: 147 | outputs: This is a dict that contains at least these entries: 148 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 149 | "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks 150 | 151 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 152 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 153 | objects in the target) containing the class labels 154 | "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks 155 | 156 | Returns: 157 | A list of size batch_size, containing tuples of (index_i, index_j) where: 158 | - index_i is the indices of the selected predictions (in 
order) 159 | - index_j is the indices of the corresponding selected targets (in order) 160 | For each batch element, it holds: 161 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 162 | """ 163 | return self.memory_efficient_forward(outputs, targets) 164 | 165 | def __repr__(self): 166 | head = "Matcher " + self.__class__.__name__ 167 | body = [ 168 | "cost_class: {}".format(self.cost_class), 169 | "cost_mask: {}".format(self.cost_mask), 170 | "cost_dice: {}".format(self.cost_dice), 171 | ] 172 | _repr_indent = 4 173 | lines = [head] + [" " * _repr_indent + line for line in body] 174 | return "\n".join(lines) 175 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py 3 | """ 4 | Transformer class. 
5 | 6 | Copy-paste from torch.nn.Transformer with modifications: 7 | * positional encodings are passed in MHattention 8 | * extra LN at the end of encoder is removed 9 | * decoder returns a stack of activations from all decoding layers 10 | """ 11 | import copy 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import Tensor, nn 17 | 18 | 19 | class Transformer(nn.Module): 20 | def __init__( 21 | self, 22 | d_model=512, 23 | nhead=8, 24 | num_encoder_layers=6, 25 | num_decoder_layers=6, 26 | dim_feedforward=2048, 27 | dropout=0.1, 28 | activation="relu", 29 | normalize_before=False, 30 | return_intermediate_dec=False, 31 | ): 32 | super().__init__() 33 | 34 | encoder_layer = TransformerEncoderLayer( 35 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 36 | ) 37 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 38 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 39 | 40 | decoder_layer = TransformerDecoderLayer( 41 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 42 | ) 43 | decoder_norm = nn.LayerNorm(d_model) 44 | self.decoder = TransformerDecoder( 45 | decoder_layer, 46 | num_decoder_layers, 47 | decoder_norm, 48 | return_intermediate=return_intermediate_dec, 49 | ) 50 | 51 | self._reset_parameters() 52 | 53 | self.d_model = d_model 54 | self.nhead = nhead 55 | 56 | def _reset_parameters(self): 57 | for p in self.parameters(): 58 | if p.dim() > 1: 59 | nn.init.xavier_uniform_(p) 60 | 61 | def forward(self, src, mask, query_embed, pos_embed): 62 | # flatten NxCxHxW to HWxNxC 63 | bs, c, h, w = src.shape 64 | src = src.flatten(2).permute(2, 0, 1) 65 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 66 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 67 | if mask is not None: 68 | mask = mask.flatten(1) 69 | 70 | tgt = torch.zeros_like(query_embed) 71 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 72 | hs = self.decoder( 73 | tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed 74 | ) 75 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 76 | 77 | 78 | class TransformerEncoder(nn.Module): 79 | def __init__(self, encoder_layer, num_layers, norm=None): 80 | super().__init__() 81 | self.layers = _get_clones(encoder_layer, num_layers) 82 | self.num_layers = num_layers 83 | self.norm = norm 84 | 85 | def forward( 86 | self, 87 | src, 88 | mask: Optional[Tensor] = None, 89 | src_key_padding_mask: Optional[Tensor] = None, 90 | pos: Optional[Tensor] = None, 91 | ): 92 | output = src 93 | 94 | for layer in self.layers: 95 | output = layer( 96 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos 97 | ) 98 | 99 | if self.norm is not None: 100 | output = self.norm(output) 101 | 102 | return output 103 | 104 | 105 | class TransformerDecoder(nn.Module): 106 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 107 | super().__init__() 108 | self.layers = _get_clones(decoder_layer, num_layers) 109 | self.num_layers = num_layers 110 | self.norm = norm 111 | self.return_intermediate = return_intermediate 112 | 113 | def forward( 114 | self, 115 | tgt, 116 | memory, 117 | tgt_mask: Optional[Tensor] = None, 118 | memory_mask: Optional[Tensor] = None, 119 | tgt_key_padding_mask: Optional[Tensor] = None, 120 | memory_key_padding_mask: Optional[Tensor] = None, 121 | pos: Optional[Tensor] = None, 122 | query_pos: 
Optional[Tensor] = None, 123 | ): 124 | output = tgt 125 | 126 | intermediate = [] 127 | 128 | for layer in self.layers: 129 | output = layer( 130 | output, 131 | memory, 132 | tgt_mask=tgt_mask, 133 | memory_mask=memory_mask, 134 | tgt_key_padding_mask=tgt_key_padding_mask, 135 | memory_key_padding_mask=memory_key_padding_mask, 136 | pos=pos, 137 | query_pos=query_pos, 138 | ) 139 | if self.return_intermediate: 140 | intermediate.append(self.norm(output)) 141 | 142 | if self.norm is not None: 143 | output = self.norm(output) 144 | if self.return_intermediate: 145 | intermediate.pop() 146 | intermediate.append(output) 147 | 148 | if self.return_intermediate: 149 | return torch.stack(intermediate) 150 | 151 | return output.unsqueeze(0) 152 | 153 | 154 | class TransformerEncoderLayer(nn.Module): 155 | def __init__( 156 | self, 157 | d_model, 158 | nhead, 159 | dim_feedforward=2048, 160 | dropout=0.1, 161 | activation="relu", 162 | normalize_before=False, 163 | ): 164 | super().__init__() 165 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 166 | # Implementation of Feedforward model 167 | self.linear1 = nn.Linear(d_model, dim_feedforward) 168 | self.dropout = nn.Dropout(dropout) 169 | self.linear2 = nn.Linear(dim_feedforward, d_model) 170 | 171 | self.norm1 = nn.LayerNorm(d_model) 172 | self.norm2 = nn.LayerNorm(d_model) 173 | self.dropout1 = nn.Dropout(dropout) 174 | self.dropout2 = nn.Dropout(dropout) 175 | 176 | self.activation = _get_activation_fn(activation) 177 | self.normalize_before = normalize_before 178 | 179 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 180 | return tensor if pos is None else tensor + pos 181 | 182 | def forward_post( 183 | self, 184 | src, 185 | src_mask: Optional[Tensor] = None, 186 | src_key_padding_mask: Optional[Tensor] = None, 187 | pos: Optional[Tensor] = None, 188 | ): 189 | q = k = self.with_pos_embed(src, pos) 190 | src2 = self.self_attn( 191 | q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 192 | )[0] 193 | src = src + self.dropout1(src2) 194 | src = self.norm1(src) 195 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 196 | src = src + self.dropout2(src2) 197 | src = self.norm2(src) 198 | return src 199 | 200 | def forward_pre( 201 | self, 202 | src, 203 | src_mask: Optional[Tensor] = None, 204 | src_key_padding_mask: Optional[Tensor] = None, 205 | pos: Optional[Tensor] = None, 206 | ): 207 | src2 = self.norm1(src) 208 | q = k = self.with_pos_embed(src2, pos) 209 | src2 = self.self_attn( 210 | q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 211 | )[0] 212 | src = src + self.dropout1(src2) 213 | src2 = self.norm2(src) 214 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 215 | src = src + self.dropout2(src2) 216 | return src 217 | 218 | def forward( 219 | self, 220 | src, 221 | src_mask: Optional[Tensor] = None, 222 | src_key_padding_mask: Optional[Tensor] = None, 223 | pos: Optional[Tensor] = None, 224 | ): 225 | if self.normalize_before: 226 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 227 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 228 | 229 | 230 | class TransformerDecoderLayer(nn.Module): 231 | def __init__( 232 | self, 233 | d_model, 234 | nhead, 235 | dim_feedforward=2048, 236 | dropout=0.1, 237 | activation="relu", 238 | normalize_before=False, 239 | ): 240 | super().__init__() 241 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 242 | 
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 243 | # Implementation of Feedforward model 244 | self.linear1 = nn.Linear(d_model, dim_feedforward) 245 | self.dropout = nn.Dropout(dropout) 246 | self.linear2 = nn.Linear(dim_feedforward, d_model) 247 | 248 | self.norm1 = nn.LayerNorm(d_model) 249 | self.norm2 = nn.LayerNorm(d_model) 250 | self.norm3 = nn.LayerNorm(d_model) 251 | self.dropout1 = nn.Dropout(dropout) 252 | self.dropout2 = nn.Dropout(dropout) 253 | self.dropout3 = nn.Dropout(dropout) 254 | 255 | self.activation = _get_activation_fn(activation) 256 | self.normalize_before = normalize_before 257 | 258 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 259 | return tensor if pos is None else tensor + pos 260 | 261 | def forward_post( 262 | self, 263 | tgt, 264 | memory, 265 | tgt_mask: Optional[Tensor] = None, 266 | memory_mask: Optional[Tensor] = None, 267 | tgt_key_padding_mask: Optional[Tensor] = None, 268 | memory_key_padding_mask: Optional[Tensor] = None, 269 | pos: Optional[Tensor] = None, 270 | query_pos: Optional[Tensor] = None, 271 | ): 272 | q = k = self.with_pos_embed(tgt, query_pos) 273 | tgt2 = self.self_attn( 274 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 275 | )[0] 276 | tgt = tgt + self.dropout1(tgt2) 277 | tgt = self.norm1(tgt) 278 | tgt2 = self.multihead_attn( 279 | query=self.with_pos_embed(tgt, query_pos), 280 | key=self.with_pos_embed(memory, pos), 281 | value=memory, 282 | attn_mask=memory_mask, 283 | key_padding_mask=memory_key_padding_mask, 284 | )[0] 285 | tgt = tgt + self.dropout2(tgt2) 286 | tgt = self.norm2(tgt) 287 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 288 | tgt = tgt + self.dropout3(tgt2) 289 | tgt = self.norm3(tgt) 290 | return tgt 291 | 292 | def forward_pre( 293 | self, 294 | tgt, 295 | memory, 296 | tgt_mask: Optional[Tensor] = None, 297 | memory_mask: Optional[Tensor] = None, 298 | tgt_key_padding_mask: Optional[Tensor] = None, 299 | memory_key_padding_mask: Optional[Tensor] = None, 300 | pos: Optional[Tensor] = None, 301 | query_pos: Optional[Tensor] = None, 302 | ): 303 | tgt2 = self.norm1(tgt) 304 | q = k = self.with_pos_embed(tgt2, query_pos) 305 | tgt2 = self.self_attn( 306 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 307 | )[0] 308 | tgt = tgt + self.dropout1(tgt2) 309 | tgt2 = self.norm2(tgt) 310 | tgt2 = self.multihead_attn( 311 | query=self.with_pos_embed(tgt2, query_pos), 312 | key=self.with_pos_embed(memory, pos), 313 | value=memory, 314 | attn_mask=memory_mask, 315 | key_padding_mask=memory_key_padding_mask, 316 | )[0] 317 | tgt = tgt + self.dropout2(tgt2) 318 | tgt2 = self.norm3(tgt) 319 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 320 | tgt = tgt + self.dropout3(tgt2) 321 | return tgt 322 | 323 | def forward( 324 | self, 325 | tgt, 326 | memory, 327 | tgt_mask: Optional[Tensor] = None, 328 | memory_mask: Optional[Tensor] = None, 329 | tgt_key_padding_mask: Optional[Tensor] = None, 330 | memory_key_padding_mask: Optional[Tensor] = None, 331 | pos: Optional[Tensor] = None, 332 | query_pos: Optional[Tensor] = None, 333 | ): 334 | if self.normalize_before: 335 | return self.forward_pre( 336 | tgt, 337 | memory, 338 | tgt_mask, 339 | memory_mask, 340 | tgt_key_padding_mask, 341 | memory_key_padding_mask, 342 | pos, 343 | query_pos, 344 | ) 345 | return self.forward_post( 346 | tgt, 347 | memory, 348 | tgt_mask, 349 | memory_mask, 350 | tgt_key_padding_mask, 351 | 
memory_key_padding_mask, 352 | pos, 353 | query_pos, 354 | ) 355 | 356 | 357 | def _get_clones(module, N): 358 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 359 | 360 | 361 | def _get_activation_fn(activation): 362 | """Return an activation function given a string""" 363 | if activation == "relu": 364 | return F.relu 365 | if activation == "gelu": 366 | return F.gelu 367 | if activation == "glu": 368 | return F.glu 369 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 370 | -------------------------------------------------------------------------------- /mask_former/modeling/transformer/transformer_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | import fvcore.nn.weight_init as weight_init 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from detectron2.config import configurable 9 | from detectron2.layers import Conv2d 10 | 11 | from .position_encoding import PositionEmbeddingSine 12 | from .transformer import Transformer 13 | 14 | 15 | class TransformerPredictor(nn.Module): 16 | @configurable 17 | def __init__( 18 | self, 19 | in_channels, 20 | mask_classification=True, 21 | *, 22 | num_classes: int, 23 | hidden_dim: int, 24 | num_queries: int, 25 | nheads: int, 26 | dropout: float, 27 | dim_feedforward: int, 28 | enc_layers: int, 29 | dec_layers: int, 30 | pre_norm: bool, 31 | deep_supervision: bool, 32 | mask_dim: int, 33 | enforce_input_project: bool, 34 | ): 35 | """ 36 | NOTE: this interface is experimental. 37 | Args: 38 | in_channels: channels of the input features 39 | mask_classification: whether to add mask classifier or not 40 | num_classes: number of classes 41 | hidden_dim: Transformer feature dimension 42 | num_queries: number of queries 43 | nheads: number of heads 44 | dropout: dropout in Transformer 45 | dim_feedforward: feature dimension in feedforward network 46 | enc_layers: number of Transformer encoder layers 47 | dec_layers: number of Transformer decoder layers 48 | pre_norm: whether to use pre-LayerNorm or not 49 | deep_supervision: whether to add supervision to every decoder layers 50 | mask_dim: mask feature dimension 51 | enforce_input_project: add input project 1x1 conv even if input 52 | channels and hidden dim is identical 53 | """ 54 | super().__init__() 55 | 56 | self.mask_classification = mask_classification 57 | 58 | # positional encoding 59 | N_steps = hidden_dim // 2 60 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 61 | 62 | transformer = Transformer( 63 | d_model=hidden_dim, 64 | dropout=dropout, 65 | nhead=nheads, 66 | dim_feedforward=dim_feedforward, 67 | num_encoder_layers=enc_layers, 68 | num_decoder_layers=dec_layers, 69 | normalize_before=pre_norm, 70 | return_intermediate_dec=deep_supervision, 71 | ) 72 | 73 | self.num_queries = num_queries 74 | self.transformer = transformer 75 | hidden_dim = transformer.d_model 76 | 77 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 78 | 79 | if in_channels != hidden_dim or enforce_input_project: 80 | self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) 81 | weight_init.c2_xavier_fill(self.input_proj) 82 | else: 83 | self.input_proj = nn.Sequential() 84 | self.aux_loss = deep_supervision 85 | 86 | # output FFNs 87 | if self.mask_classification: 88 | self.class_embed = 
nn.Linear(hidden_dim, num_classes + 1) 89 | self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) 90 | 91 | @classmethod 92 | def from_config(cls, cfg, in_channels, mask_classification): 93 | ret = {} 94 | ret["in_channels"] = in_channels 95 | ret["mask_classification"] = mask_classification 96 | 97 | ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES 98 | ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM 99 | ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES 100 | # Transformer parameters: 101 | ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS 102 | ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT 103 | ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD 104 | ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS 105 | ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS 106 | ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM 107 | ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 108 | ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ 109 | 110 | ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM 111 | 112 | return ret 113 | 114 | def forward(self, x, mask_features): 115 | pos = self.pe_layer(x) 116 | 117 | src = x 118 | mask = None 119 | hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos) 120 | 121 | if self.mask_classification: 122 | outputs_class = self.class_embed(hs) 123 | out = {"pred_logits": outputs_class[-1]} 124 | else: 125 | out = {} 126 | 127 | if self.aux_loss: 128 | # [l, bs, queries, embed] 129 | mask_embed = self.mask_embed(hs) 130 | outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features) 131 | out["pred_masks"] = outputs_seg_masks[-1] 132 | out["aux_outputs"] = self._set_aux_loss( 133 | outputs_class if self.mask_classification else None, outputs_seg_masks 134 | ) 135 | else: 136 | # FIXME h_boxes takes the last one computed, keep this in mind 137 | # [bs, queries, embed] 138 | mask_embed = self.mask_embed(hs[-1]) 139 | outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) 140 | out["pred_masks"] = outputs_seg_masks 141 | return out 142 | 143 | @torch.jit.unused 144 | def _set_aux_loss(self, outputs_class, outputs_seg_masks): 145 | # this is a workaround to make torchscript happy, as torchscript 146 | # doesn't support dictionary with non-homogeneous values, such 147 | # as a dict having both a Tensor and a list. 148 | if self.mask_classification: 149 | return [ 150 | {"pred_logits": a, "pred_masks": b} 151 | for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) 152 | ] 153 | else: 154 | return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] 155 | 156 | 157 | class MLP(nn.Module): 158 | """Very simple multi-layer perceptron (also called FFN)""" 159 | 160 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 161 | super().__init__() 162 | self.num_layers = num_layers 163 | h = [hidden_dim] * (num_layers - 1) 164 | self.layers = nn.ModuleList( 165 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 166 | ) 167 | 168 | def forward(self, x): 169 | for i, layer in enumerate(self.layers): 170 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 171 | return x 172 | -------------------------------------------------------------------------------- /mask_former/test_time_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import copy 3 | from itertools import count 4 | 5 | import numpy as np 6 | import torch 7 | from fvcore.transforms import HFlipTransform 8 | from torch import nn 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from detectron2.data.detection_utils import read_image 12 | from detectron2.modeling import DatasetMapperTTA 13 | 14 | __all__ = [ 15 | "SemanticSegmentorWithTTA", 16 | ] 17 | 18 | 19 | class SemanticSegmentorWithTTA(nn.Module): 20 | """ 21 | A SemanticSegmentor with test-time augmentation enabled. 22 | Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. 23 | """ 24 | 25 | def __init__(self, cfg, model, tta_mapper=None, batch_size=1): 26 | """ 27 | Args: 28 | cfg (CfgNode): 29 | model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. 30 | tta_mapper (callable): takes a dataset dict and returns a list of 31 | augmented versions of the dataset dict. Defaults to 32 | `DatasetMapperTTA(cfg)`. 33 | batch_size (int): batch the augmented images into this batch size for inference. 34 | """ 35 | super().__init__() 36 | if isinstance(model, DistributedDataParallel): 37 | model = model.module 38 | self.cfg = cfg.clone() 39 | 40 | self.model = model 41 | 42 | if tta_mapper is None: 43 | tta_mapper = DatasetMapperTTA(cfg) 44 | self.tta_mapper = tta_mapper 45 | self.batch_size = batch_size 46 | 47 | def _batch_inference(self, batched_inputs): 48 | """ 49 | Execute inference on a list of inputs, 50 | using batch size = self.batch_size, instead of the length of the list. 51 | Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward` 52 | """ 53 | outputs = [] 54 | inputs = [] 55 | for idx, input in zip(count(), batched_inputs): 56 | inputs.append(input) 57 | if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: 58 | with torch.no_grad(): 59 | outputs.extend(self.model(inputs)) 60 | inputs = [] 61 | return outputs 62 | 63 | def __call__(self, batched_inputs): 64 | """ 65 | Same input/output format as :meth:`SemanticSegmentor.forward` 66 | """ 67 | 68 | def _maybe_read_image(dataset_dict): 69 | ret = copy.copy(dataset_dict) 70 | if "image" not in ret: 71 | image = read_image(ret.pop("file_name"), self.model.input_format) 72 | image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW 73 | ret["image"] = image 74 | if "height" not in ret and "width" not in ret: 75 | ret["height"] = image.shape[1] 76 | ret["width"] = image.shape[2] 77 | return ret 78 | 79 | return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] 80 | 81 | def _inference_one_image(self, input): 82 | """ 83 | Args: 84 | input (dict): one dataset dict with "image" field being a CHW tensor 85 | Returns: 86 | dict: one output dict 87 | """ 88 | augmented_inputs, tfms = self._get_augmented_inputs(input) 89 | # 1: forward with all augmented images 90 | outputs = self._batch_inference(augmented_inputs) 91 | # Delete now useless variables to avoid being out of memory 92 | del augmented_inputs 93 | # 2: merge the results 94 | # handle flip specially 95 | new_outputs = [] 96 | for output, tfm in zip(outputs, tfms): 97 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 98 | new_outputs.append(output.pop("sem_seg").flip(dims=[2])) 99 | else: 100 | new_outputs.append(output.pop("sem_seg")) 101 | del outputs 102 | # to avoid OOM with torch.stack 103 | final_predictions = new_outputs[0] 104 | for i in range(1, len(new_outputs)): 105 | final_predictions += new_outputs[i] 106 | final_predictions = 
final_predictions / len(new_outputs) 107 | del new_outputs 108 | return {"sem_seg": final_predictions} 109 | 110 | def _get_augmented_inputs(self, input): 111 | augmented_inputs = self.tta_mapper(input) 112 | tfms = [x.pop("transforms") for x in augmented_inputs] 113 | return augmented_inputs, tfms 114 | -------------------------------------------------------------------------------- /mask_former/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /mask_former/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | from typing import List, Optional 9 | 10 | import torch 11 | import torch.distributed as dist 12 | import torchvision 13 | from torch import Tensor 14 | 15 | 16 | def _max_by_axis(the_list): 17 | # type: (List[List[int]]) -> List[int] 18 | maxes = the_list[0] 19 | for sublist in the_list[1:]: 20 | for index, item in enumerate(sublist): 21 | maxes[index] = max(maxes[index], item) 22 | return maxes 23 | 24 | 25 | class NestedTensor(object): 26 | def __init__(self, tensors, mask: Optional[Tensor]): 27 | self.tensors = tensors 28 | self.mask = mask 29 | 30 | def to(self, device): 31 | # type: (Device) -> NestedTensor # noqa 32 | cast_tensor = self.tensors.to(device) 33 | mask = self.mask 34 | if mask is not None: 35 | assert mask is not None 36 | cast_mask = mask.to(device) 37 | else: 38 | cast_mask = None 39 | return NestedTensor(cast_tensor, cast_mask) 40 | 41 | def decompose(self): 42 | return self.tensors, self.mask 43 | 44 | def __repr__(self): 45 | return str(self.tensors) 46 | 47 | 48 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 49 | # TODO make this more general 50 | if tensor_list[0].ndim == 3: 51 | if torchvision._is_tracing(): 52 | # nested_tensor_from_tensor_list() does not export well to ONNX 53 | # call _onnx_nested_tensor_from_tensor_list() instead 54 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 55 | 56 | # TODO make it support different-sized images 57 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 58 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 59 | batch_shape = [len(tensor_list)] + max_size 60 | b, c, h, w = batch_shape 61 | dtype = tensor_list[0].dtype 62 | device = tensor_list[0].device 63 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 64 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 65 | for img, pad_img, m in zip(tensor_list, tensor, mask): 66 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 67 | m[: img.shape[1], : img.shape[2]] = False 68 | else: 69 | raise ValueError("not supported") 70 | return NestedTensor(tensor, mask) 71 | 72 | 73 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 74 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
75 | @torch.jit.unused 76 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 77 |     max_size = [] 78 |     for i in range(tensor_list[0].dim()): 79 |         max_size_i = torch.max( 80 |             torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 81 |         ).to(torch.int64) 82 |         max_size.append(max_size_i) 83 |     max_size = tuple(max_size) 84 | 85 |     # work around for 86 |     # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 87 |     # m[: img.shape[1], :img.shape[2]] = False 88 |     # which is not yet supported in onnx 89 |     padded_imgs = [] 90 |     padded_masks = [] 91 |     for img in tensor_list: 92 |         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 93 |         padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 94 |         padded_imgs.append(padded_img) 95 | 96 |         m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 97 |         padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 98 |         padded_masks.append(padded_mask.to(torch.bool)) 99 | 100 |     tensor = torch.stack(padded_imgs) 101 |     mask = torch.stack(padded_masks) 102 | 103 |     return NestedTensor(tensor, mask=mask) 104 | 105 | 106 | def is_dist_avail_and_initialized(): 107 |     if not dist.is_available(): 108 |         return False 109 |     if not dist.is_initialized(): 110 |         return False 111 |     return True 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | scipy 3 | shapely 4 | timm 5 | h5py -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | This directory contains a few tools for MaskFormer. 2 | 3 | * `convert-torchvision-to-d2.py` 4 | 5 | Tool to convert torchvision pre-trained weights for D2. 6 | 7 | ``` 8 | wget https://download.pytorch.org/models/resnet101-63fe2227.pth 9 | python tools/convert-torchvision-to-d2.py resnet101-63fe2227.pth R-101.pkl 10 | ``` 11 | 12 | * `convert-pretrained-swin-model-to-d2.py` 13 | 14 | Tool to convert Swin Transformer pre-trained weights for D2. 15 | 16 | ``` 17 | pip install timm 18 | 19 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth 20 | python tools/convert-pretrained-swin-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl 21 | 22 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth 23 | python tools/convert-pretrained-swin-model-to-d2.py swin_small_patch4_window7_224.pth swin_small_patch4_window7_224.pkl 24 | 25 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth 26 | python tools/convert-pretrained-swin-model-to-d2.py swin_base_patch4_window12_384_22k.pth swin_base_patch4_window12_384_22k.pkl 27 | 28 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth 29 | python tools/convert-pretrained-swin-model-to-d2.py swin_large_patch4_window12_384_22k.pth swin_large_patch4_window12_384_22k.pkl 30 | ``` 31 | 32 | * `evaluate_pq_for_semantic_segmentation.py` 33 | 34 | Tool to evaluate PQ (PQ-stuff) for semantic segmentation predictions.
35 | 36 | Usage: 37 | 38 | ``` 39 | python tools/evaluate_pq_for_semantic_segmentation.py --dataset-name ade20k_sem_seg_val --json-file OUTPUT_DIR/inference/sem_seg_predictions.json 40 | ``` 41 | 42 | where `OUTPUT_DIR` is set in the config file. 43 | 44 | * `analyze_model.py` 45 | 46 | Tool to analyze model parameters and flops. 47 | 48 | Usage for semantic segmentation: 49 | 50 | ``` 51 | python tools/analyze_model.py --num-inputs 1 --tasks flop --use-fixed-input-size --config-file CONFIG_FILE 52 | ``` 53 | 54 | Note that, for semantic segmentation, we use a dummy image with a fixed size equal to `cfg.INPUT.CROP.SIZE[0] x cfg.INPUT.CROP.SIZE[0]`. 55 | 56 | Usage for panoptic segmentation: 57 | 58 | ``` 59 | python tools/analyze_model.py --num-inputs 100 --tasks flop --config-file CONFIG_FILE 60 | ``` 61 | 62 | Note that, for panoptic segmentation, we compute the average flops over 100 real validation images. 63 | -------------------------------------------------------------------------------- /tools/analyze_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detectron2/blob/main/tools/analyze_model.py 4 | 5 | import logging 6 | import numpy as np 7 | from collections import Counter 8 | import tqdm 9 | from fvcore.nn import flop_count_table  # can also try flop_count_str 10 | 11 | from detectron2.checkpoint import DetectionCheckpointer 12 | from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate 13 | from detectron2.data import build_detection_test_loader 14 | from detectron2.engine import default_argument_parser 15 | from detectron2.modeling import build_model 16 | from detectron2.projects.deeplab import add_deeplab_config 17 | from detectron2.utils.analysis import ( 18 |     FlopCountAnalysis, 19 |     activation_count_operators, 20 |     parameter_count_table, 21 | ) 22 | from detectron2.utils.logger import setup_logger 23 | 24 | # fmt: off 25 | import os 26 | import sys 27 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 28 | # fmt: on 29 | 30 | from mask_former import add_mask_former_config 31 | 32 | logger = logging.getLogger("detectron2") 33 | 34 | 35 | def setup(args): 36 |     if args.config_file.endswith(".yaml"): 37 |         cfg = get_cfg() 38 |         add_deeplab_config(cfg) 39 |         add_mask_former_config(cfg) 40 |         cfg.merge_from_file(args.config_file) 41 |         cfg.DATALOADER.NUM_WORKERS = 0 42 |         cfg.merge_from_list(args.opts) 43 |         cfg.freeze() 44 |     else: 45 |         cfg = LazyConfig.load(args.config_file) 46 |         cfg = LazyConfig.apply_overrides(cfg, args.opts) 47 |     setup_logger(name="fvcore") 48 |     setup_logger() 49 |     return cfg 50 | 51 | 52 | def do_flop(cfg): 53 |     if isinstance(cfg, CfgNode): 54 |         data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) 55 |         model = build_model(cfg) 56 |         DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) 57 |     else: 58 |         data_loader = instantiate(cfg.dataloader.test) 59 |         model = instantiate(cfg.model) 60 |         model.to(cfg.train.device) 61 |         DetectionCheckpointer(model).load(cfg.train.init_checkpoint) 62 |     model.eval() 63 | 64 |     counts = Counter() 65 |     total_flops = [] 66 |     for idx, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa 67 |         if args.use_fixed_input_size and isinstance(cfg, CfgNode): 68 |             import torch 69 |             crop_size = cfg.INPUT.CROP.SIZE[0] 70 |             data[0]["image"] = torch.zeros((3, crop_size, crop_size)) 71 |         flops = FlopCountAnalysis(model, 
data) 72 | if idx > 0: 73 | flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False) 74 | counts += flops.by_operator() 75 | total_flops.append(flops.total()) 76 | 77 | logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops)) 78 | logger.info( 79 | "Average GFlops for each type of operators:\n" 80 | + str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()]) 81 | ) 82 | logger.info( 83 | "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9) 84 | ) 85 | 86 | 87 | def do_activation(cfg): 88 | if isinstance(cfg, CfgNode): 89 | data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) 90 | model = build_model(cfg) 91 | DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) 92 | else: 93 | data_loader = instantiate(cfg.dataloader.test) 94 | model = instantiate(cfg.model) 95 | model.to(cfg.train.device) 96 | DetectionCheckpointer(model).load(cfg.train.init_checkpoint) 97 | model.eval() 98 | 99 | counts = Counter() 100 | total_activations = [] 101 | for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa 102 | count = activation_count_operators(model, data) 103 | counts += count 104 | total_activations.append(sum(count.values())) 105 | logger.info( 106 | "(Million) Activations for Each Type of Operators:\n" 107 | + str([(k, v / idx) for k, v in counts.items()]) 108 | ) 109 | logger.info( 110 | "Total (Million) Activations: {}±{}".format( 111 | np.mean(total_activations), np.std(total_activations) 112 | ) 113 | ) 114 | 115 | 116 | def do_parameter(cfg): 117 | if isinstance(cfg, CfgNode): 118 | model = build_model(cfg) 119 | else: 120 | model = instantiate(cfg.model) 121 | logger.info("Parameter Count:\n" + parameter_count_table(model, max_depth=5)) 122 | 123 | 124 | def do_structure(cfg): 125 | if isinstance(cfg, CfgNode): 126 | model = build_model(cfg) 127 | else: 128 | model = instantiate(cfg.model) 129 | logger.info("Model Structure:\n" + str(model)) 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = default_argument_parser( 134 | epilog=""" 135 | Examples: 136 | To show parameters of a model: 137 | $ ./analyze_model.py --tasks parameter \\ 138 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml 139 | Flops and activations are data-dependent, therefore inputs and model weights 140 | are needed to count them: 141 | $ ./analyze_model.py --num-inputs 100 --tasks flop \\ 142 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\ 143 | MODEL.WEIGHTS /path/to/model.pkl 144 | """ 145 | ) 146 | parser.add_argument( 147 | "--tasks", 148 | choices=["flop", "activation", "parameter", "structure"], 149 | required=True, 150 | nargs="+", 151 | ) 152 | parser.add_argument( 153 | "-n", 154 | "--num-inputs", 155 | default=100, 156 | type=int, 157 | help="number of inputs used to compute statistics for flops/activations, " 158 | "both are data dependent.", 159 | ) 160 | parser.add_argument( 161 | "--use-fixed-input-size", 162 | action="store_true", 163 | help="use fixed input size when calculating flops", 164 | ) 165 | args = parser.parse_args() 166 | assert not args.eval_only 167 | assert args.num_gpus == 1 168 | 169 | cfg = setup(args) 170 | 171 | for task in args.tasks: 172 | { 173 | "flop": do_flop, 174 | "activation": do_activation, 175 | "parameter": do_parameter, 176 | "structure": do_structure, 177 | }[task](cfg) 178 | -------------------------------------------------------------------------------- 
/tools/convert-pretrained-swin-model-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download pretrained swin model: 12 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth 13 | # run the conversion 14 | ./convert-pretrained-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl 15 | # Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl" 18 | INPUT: 19 | FORMAT: "RGB" 20 | """ 21 | 22 | if __name__ == "__main__": 23 | input = sys.argv[1] 24 | 25 | obj = torch.load(input, map_location="cpu")["model"] 26 | 27 | res = {"model": obj, "__author__": "third_party", "matching_heuristics": True} 28 | 29 | with open(sys.argv[2], "wb") as f: 30 | pkl.dump(res, f) 31 | -------------------------------------------------------------------------------- /tools/convert-torchvision-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download one of the ResNet{18,34,50,101,152} models from torchvision: 12 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth 13 | # run the conversion 14 | ./convert-torchvision-to-d2.py r50.pth r50.pkl 15 | # Then, use r50.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/r50.pkl" 18 | PIXEL_MEAN: [123.675, 116.280, 103.530] 19 | PIXEL_STD: [58.395, 57.120, 57.375] 20 | RESNETS: 21 | DEPTH: 50 22 | STRIDE_IN_1X1: False 23 | INPUT: 24 | FORMAT: "RGB" 25 | """ 26 | 27 | if __name__ == "__main__": 28 | input = sys.argv[1] 29 | 30 | obj = torch.load(input, map_location="cpu") 31 | 32 | newmodel = {} 33 | for k in list(obj.keys()): 34 | old_k = k 35 | if "layer" not in k: 36 | k = "stem." + k 37 | for t in [1, 2, 3, 4]: 38 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 39 | for t in [1, 2, 3]: 40 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 41 | k = k.replace("downsample.0", "shortcut") 42 | k = k.replace("downsample.1", "shortcut.norm") 43 | print(old_k, "->", k) 44 | newmodel[k] = obj.pop(old_k).detach().numpy() 45 | 46 | res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} 47 | 48 | with open(sys.argv[2], "wb") as f: 49 | pkl.dump(res, f) 50 | if obj: 51 | print("Unconverted keys:", obj.keys()) 52 | -------------------------------------------------------------------------------- /tools/evaluate_pq_for_semantic_segmentation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import argparse 5 | import json 6 | import os 7 | from collections import defaultdict 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from detectron2.data import MetadataCatalog 14 | from detectron2.data.detection_utils import read_image 15 | from detectron2.utils.file_io import PathManager 16 | from pycocotools import mask as maskUtils 17 | 18 | from panopticapi.evaluation import PQStat 19 | 20 | 21 | def default_argument_parser(): 22 | """ 23 | Creates a parser with some common arguments used by analysis tools. 24 | Returns: 25 | argparse.ArgumentParser: 26 | """ 27 | parser = argparse.ArgumentParser(description="Evaluate PQ metric for semantic segmentation.") 28 | # NOTE: currently does not support Cityscapes, you need to convert 29 | # Cityscapes prediction format to Detectron2 prediction format. 30 | parser.add_argument( 31 | "--dataset-name", 32 | default="ade20k_sem_seg_val", 33 | choices=["ade20k_sem_seg_val", "coco_2017_test_stuff_10k_sem_seg", "ade20k_full_sem_seg_val"], 34 | help="dataset name you want to evaluate") 35 | parser.add_argument("--json-file", default="", help="path to detection json file") 36 | 37 | return parser 38 | 39 | 40 | # Modified from the official panoptic api: https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py 41 | def pq_compute_single_image(segm_gt, segm_dt, categories, ignore_label): 42 | pq_stat = PQStat() 43 | VOID = ignore_label 44 | OFFSET = 256 * 256 * 256 45 | 46 | pan_gt = segm_gt 47 | pan_pred = segm_dt 48 | 49 | gt_ann = {'segments_info': []} 50 | labels, labels_cnt = np.unique(segm_gt, return_counts=True) 51 | for cat_id, cnt in zip(labels, labels_cnt): 52 | if cat_id == VOID: 53 | continue 54 | gt_ann['segments_info'].append( 55 | {"id": cat_id, "category_id": cat_id, "area": cnt, "iscrowd": 0} 56 | ) 57 | 58 | pred_ann = {'segments_info': []} 59 | for cat_id in np.unique(segm_dt): 60 | pred_ann['segments_info'].append({"id": cat_id, "category_id": cat_id}) 61 | 62 | gt_segms = {el['id']: el for el in gt_ann['segments_info']} 63 | pred_segms = {el['id']: el for el in pred_ann['segments_info']} 64 | 65 | # predicted segments area calculation + prediction sanity checks 66 | pred_labels_set = set(el['id'] for el in pred_ann['segments_info']) 67 | labels, labels_cnt = np.unique(pan_pred, return_counts=True) 68 | for label, label_cnt in zip(labels, labels_cnt): 69 | if label not in pred_segms: 70 | if label == VOID: 71 | continue 72 | raise KeyError('In the image with ID {} segment with ID {} is presented in PNG and not presented in JSON.'.format(image_id, label)) 73 | pred_segms[label]['area'] = label_cnt 74 | pred_labels_set.remove(label) 75 | if pred_segms[label]['category_id'] not in categories: 76 | raise KeyError('In the image with ID {} segment with ID {} has unknown category_id {}.'.format(image_id, label, pred_segms[label]['category_id'])) 77 | if len(pred_labels_set) != 0: 78 | raise KeyError('In the image with ID {} the following segment IDs {} are presented in JSON and not presented in PNG.'.format(image_id, list(pred_labels_set))) 79 | 80 | # confusion matrix calculation 81 | pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(np.uint64) 82 | gt_pred_map = {} 83 | labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True) 84 | for label, intersection in zip(labels, labels_cnt): 85 | gt_id = label // OFFSET 86 | pred_id = label % OFFSET 87 | gt_pred_map[(gt_id, pred_id)] = intersection 88 | 89 | # count all matched pairs 90 | gt_matched = set() 
91 | pred_matched = set() 92 | for label_tuple, intersection in gt_pred_map.items(): 93 | gt_label, pred_label = label_tuple 94 | if gt_label not in gt_segms: 95 | continue 96 | if pred_label not in pred_segms: 97 | continue 98 | if gt_segms[gt_label]['iscrowd'] == 1: 99 | continue 100 | if gt_segms[gt_label]['category_id'] != pred_segms[pred_label]['category_id']: 101 | continue 102 | 103 | union = pred_segms[pred_label]['area'] + gt_segms[gt_label]['area'] - intersection - gt_pred_map.get((VOID, pred_label), 0) 104 | iou = intersection / union 105 | if iou > 0.5: 106 | pq_stat[gt_segms[gt_label]['category_id']].tp += 1 107 | pq_stat[gt_segms[gt_label]['category_id']].iou += iou 108 | gt_matched.add(gt_label) 109 | pred_matched.add(pred_label) 110 | 111 | # count false positives 112 | crowd_labels_dict = {} 113 | for gt_label, gt_info in gt_segms.items(): 114 | if gt_label in gt_matched: 115 | continue 116 | # crowd segments are ignored 117 | if gt_info['iscrowd'] == 1: 118 | crowd_labels_dict[gt_info['category_id']] = gt_label 119 | continue 120 | pq_stat[gt_info['category_id']].fn += 1 121 | 122 | # count false positives 123 | for pred_label, pred_info in pred_segms.items(): 124 | if pred_label in pred_matched: 125 | continue 126 | # intersection of the segment with VOID 127 | intersection = gt_pred_map.get((VOID, pred_label), 0) 128 | # plus intersection with corresponding CROWD region if it exists 129 | if pred_info['category_id'] in crowd_labels_dict: 130 | intersection += gt_pred_map.get((crowd_labels_dict[pred_info['category_id']], pred_label), 0) 131 | # predicted segment is ignored if more than half of the segment correspond to VOID and CROWD regions 132 | if intersection / pred_info['area'] > 0.5: 133 | continue 134 | pq_stat[pred_info['category_id']].fp += 1 135 | 136 | return pq_stat 137 | 138 | 139 | def main(): 140 | parser = default_argument_parser() 141 | args = parser.parse_args() 142 | 143 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 144 | json_file = args.json_file 145 | 146 | with open(json_file) as f: 147 | predictions = json.load(f) 148 | 149 | imgToAnns = defaultdict(list) 150 | for pred in predictions: 151 | image_id = os.path.basename(pred["file_name"]).split(".")[0] 152 | imgToAnns[image_id].append( 153 | {"category_id" : pred["category_id"], "segmentation" : pred["segmentation"]} 154 | ) 155 | 156 | image_ids = list(imgToAnns.keys()) 157 | 158 | meta = MetadataCatalog.get(args.dataset_name) 159 | class_names = meta.stuff_classes 160 | num_classes = len(meta.stuff_classes) 161 | ignore_label = meta.ignore_label 162 | conf_matrix = np.zeros((num_classes + 1, num_classes + 1), dtype=np.int64) 163 | 164 | categories = {} 165 | for i in range(num_classes): 166 | categories[i] = {"id": i, "name": class_names[i], "isthing": 0} 167 | 168 | pq_stat = PQStat() 169 | 170 | for image_id in tqdm(image_ids): 171 | if args.dataset_name == "ade20k_sem_seg_val": 172 | gt_dir = os.path.join(_root, "ADEChallengeData2016", "annotations_detectron2", "validation") 173 | segm_gt = read_image(os.path.join(gt_dir, image_id + ".png")).copy().astype(np.int64) 174 | elif args.dataset_name == "coco_2017_test_stuff_10k_sem_seg": 175 | gt_dir = os.path.join(_root, "coco", "coco_stuff_10k", "annotations_detectron2", "test") 176 | segm_gt = read_image(os.path.join(gt_dir, image_id + ".png")).copy().astype(np.int64) 177 | elif args.dataset_name == "ade20k_full_sem_seg_val": 178 | gt_dir = os.path.join(_root, "ADE20K_2021_17_01", "annotations_detectron2", "validation") 179 | segm_gt = 
read_image(os.path.join(gt_dir, image_id + ".tif")).copy().astype(np.int64) 180 | else: 181 | raise ValueError(f"Unsupported dataset {args.dataset_name}") 182 | 183 | # get predictions 184 | segm_dt = np.zeros_like(segm_gt) 185 | anns = imgToAnns[image_id] 186 | for ann in anns: 187 | # map back category_id 188 | if hasattr(meta, "stuff_dataset_id_to_contiguous_id"): 189 | if ann["category_id"] in meta.stuff_dataset_id_to_contiguous_id: 190 | category_id = meta.stuff_dataset_id_to_contiguous_id[ann["category_id"]] 191 | else: 192 | category_id = ann["category_id"] 193 | mask = maskUtils.decode(ann["segmentation"]) 194 | segm_dt[mask > 0] = category_id 195 | 196 | # miou 197 | gt = segm_gt.copy() 198 | pred = segm_dt.copy() 199 | gt[gt == ignore_label] = num_classes 200 | conf_matrix += np.bincount( 201 | (num_classes + 1) * pred.reshape(-1) + gt.reshape(-1), 202 | minlength=conf_matrix.size, 203 | ).reshape(conf_matrix.shape) 204 | 205 | # pq 206 | pq_stat_single = pq_compute_single_image(segm_gt, segm_dt, categories, meta.ignore_label) 207 | pq_stat += pq_stat_single 208 | 209 | metrics = [("All", None), ("Stuff", False)] 210 | results = {} 211 | for name, isthing in metrics: 212 | results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing) 213 | if name == 'All': 214 | results['per_class'] = per_class_results 215 | print("{:10s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N")) 216 | print("-" * (10 + 7 * 4)) 217 | 218 | for name, _isthing in metrics: 219 | print("{:10s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format( 220 | name, 221 | 100 * results[name]['pq'], 222 | 100 * results[name]['sq'], 223 | 100 * results[name]['rq'], 224 | results[name]['n']) 225 | ) 226 | 227 | # calculate miou 228 | acc = np.full(num_classes, np.nan, dtype=np.float64) 229 | iou = np.full(num_classes, np.nan, dtype=np.float64) 230 | tp = conf_matrix.diagonal()[:-1].astype(np.float64) 231 | pos_gt = np.sum(conf_matrix[:-1, :-1], axis=0).astype(np.float64) 232 | pos_pred = np.sum(conf_matrix[:-1, :-1], axis=1).astype(np.float64) 233 | acc_valid = pos_gt > 0 234 | acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid] 235 | iou_valid = (pos_gt + pos_pred) > 0 236 | union = pos_gt + pos_pred - tp 237 | iou[acc_valid] = tp[acc_valid] / union[acc_valid] 238 | miou = np.sum(iou[acc_valid]) / np.sum(iou_valid) 239 | 240 | print("") 241 | print(f"mIoU: {miou}") 242 | 243 | 244 | if __name__ == '__main__': 245 | main() 246 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MaskFormer Training Script. 4 | 5 | This script is a simplified version of the training script in detectron2/tools. 
6 | """ 7 | import copy 8 | import itertools 9 | import logging 10 | import os 11 | from collections import OrderedDict 12 | from typing import Any, Dict, List, Set 13 | 14 | import torch 15 | 16 | import detectron2.utils.comm as comm 17 | from detectron2.checkpoint import DetectionCheckpointer 18 | from detectron2.config import get_cfg 19 | from detectron2.data import MetadataCatalog, build_detection_train_loader 20 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch 21 | from detectron2.evaluation import ( 22 | CityscapesInstanceEvaluator, 23 | CityscapesSemSegEvaluator, 24 | COCOEvaluator, 25 | COCOPanopticEvaluator, 26 | DatasetEvaluators, 27 | SemSegEvaluator, 28 | verify_results, 29 | ) 30 | from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler 31 | from detectron2.solver.build import maybe_add_gradient_clipping 32 | from detectron2.utils.logger import setup_logger 33 | 34 | # MaskFormer 35 | from mask_former import ( 36 | DETRPanopticDatasetMapper, 37 | MaskFormerPanopticDatasetMapper, 38 | MaskFormerSemanticDatasetMapper, 39 | SemanticSegmentorWithTTA, 40 | add_mask_former_config, 41 | ) 42 | 43 | 44 | class Trainer(DefaultTrainer): 45 | """ 46 | Extension of the Trainer class adapted to DETR. 47 | """ 48 | 49 | @classmethod 50 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 51 | """ 52 | Create evaluator(s) for a given dataset. 53 | This uses the special metadata "evaluator_type" associated with each 54 | builtin dataset. For your own dataset, you can simply create an 55 | evaluator manually in your script and do not have to worry about the 56 | hacky if-else logic here. 57 | """ 58 | if output_folder is None: 59 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 60 | evaluator_list = [] 61 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 62 | if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]: 63 | evaluator_list.append( 64 | SemSegEvaluator( 65 | dataset_name, 66 | distributed=True, 67 | output_dir=output_folder, 68 | ) 69 | ) 70 | if evaluator_type == "coco": 71 | evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) 72 | if evaluator_type in [ 73 | "coco_panoptic_seg", 74 | "ade20k_panoptic_seg", 75 | "cityscapes_panoptic_seg", 76 | ]: 77 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 78 | if evaluator_type == "cityscapes_instance": 79 | assert ( 80 | torch.cuda.device_count() >= comm.get_rank() 81 | ), "CityscapesEvaluator currently do not work with multiple machines." 82 | return CityscapesInstanceEvaluator(dataset_name) 83 | if evaluator_type == "cityscapes_sem_seg": 84 | assert ( 85 | torch.cuda.device_count() >= comm.get_rank() 86 | ), "CityscapesEvaluator currently do not work with multiple machines." 87 | return CityscapesSemSegEvaluator(dataset_name) 88 | if evaluator_type == "cityscapes_panoptic_seg": 89 | assert ( 90 | torch.cuda.device_count() >= comm.get_rank() 91 | ), "CityscapesEvaluator currently do not work with multiple machines." 
92 | evaluator_list.append(CityscapesSemSegEvaluator(dataset_name)) 93 | if len(evaluator_list) == 0: 94 | raise NotImplementedError( 95 | "no Evaluator for the dataset {} with the type {}".format( 96 | dataset_name, evaluator_type 97 | ) 98 | ) 99 | elif len(evaluator_list) == 1: 100 | return evaluator_list[0] 101 | return DatasetEvaluators(evaluator_list) 102 | 103 | @classmethod 104 | def build_train_loader(cls, cfg): 105 | # Semantic segmentation dataset mapper 106 | if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic": 107 | mapper = MaskFormerSemanticDatasetMapper(cfg, True) 108 | # Panoptic segmentation dataset mapper 109 | elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic": 110 | mapper = MaskFormerPanopticDatasetMapper(cfg, True) 111 | # DETR-style dataset mapper for COCO panoptic segmentation 112 | elif cfg.INPUT.DATASET_MAPPER_NAME == "detr_panoptic": 113 | mapper = DETRPanopticDatasetMapper(cfg, True) 114 | else: 115 | mapper = None 116 | return build_detection_train_loader(cfg, mapper=mapper) 117 | 118 | @classmethod 119 | def build_lr_scheduler(cls, cfg, optimizer): 120 | """ 121 | It now calls :func:`detectron2.solver.build_lr_scheduler`. 122 | Overwrite it if you'd like a different scheduler. 123 | """ 124 | return build_lr_scheduler(cfg, optimizer) 125 | 126 | @classmethod 127 | def build_optimizer(cls, cfg, model): 128 | weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM 129 | weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED 130 | 131 | defaults = {} 132 | defaults["lr"] = cfg.SOLVER.BASE_LR 133 | defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY 134 | 135 | norm_module_types = ( 136 | torch.nn.BatchNorm1d, 137 | torch.nn.BatchNorm2d, 138 | torch.nn.BatchNorm3d, 139 | torch.nn.SyncBatchNorm, 140 | # NaiveSyncBatchNorm inherits from BatchNorm2d 141 | torch.nn.GroupNorm, 142 | torch.nn.InstanceNorm1d, 143 | torch.nn.InstanceNorm2d, 144 | torch.nn.InstanceNorm3d, 145 | torch.nn.LayerNorm, 146 | torch.nn.LocalResponseNorm, 147 | ) 148 | 149 | params: List[Dict[str, Any]] = [] 150 | memo: Set[torch.nn.parameter.Parameter] = set() 151 | for module_name, module in model.named_modules(): 152 | for module_param_name, value in module.named_parameters(recurse=False): 153 | if not value.requires_grad: 154 | continue 155 | # Avoid duplicating parameters 156 | if value in memo: 157 | continue 158 | memo.add(value) 159 | 160 | hyperparams = copy.copy(defaults) 161 | if "backbone" in module_name: 162 | hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER 163 | if ( 164 | "relative_position_bias_table" in module_param_name 165 | or "absolute_pos_embed" in module_param_name 166 | ): 167 | print(module_param_name) 168 | hyperparams["weight_decay"] = 0.0 169 | if isinstance(module, norm_module_types): 170 | hyperparams["weight_decay"] = weight_decay_norm 171 | if isinstance(module, torch.nn.Embedding): 172 | hyperparams["weight_decay"] = weight_decay_embed 173 | params.append({"params": [value], **hyperparams}) 174 | 175 | def maybe_add_full_model_gradient_clipping(optim): 176 | # detectron2 doesn't have full model gradient clipping now 177 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 178 | enable = ( 179 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 180 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 181 | and clip_norm_val > 0.0 182 | ) 183 | 184 | class FullModelGradientClippingOptimizer(optim): 185 | def step(self, closure=None): 186 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 187 | 
torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 188 | super().step(closure=closure) 189 | 190 | return FullModelGradientClippingOptimizer if enable else optim 191 | 192 | optimizer_type = cfg.SOLVER.OPTIMIZER 193 | if optimizer_type == "SGD": 194 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 195 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 196 | ) 197 | elif optimizer_type == "ADAMW": 198 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 199 | params, cfg.SOLVER.BASE_LR 200 | ) 201 | else: 202 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 203 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 204 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 205 | return optimizer 206 | 207 | @classmethod 208 | def test_with_TTA(cls, cfg, model): 209 | logger = logging.getLogger("detectron2.trainer") 210 | # In the end of training, run an evaluation with TTA. 211 | logger.info("Running inference with test-time augmentation ...") 212 | model = SemanticSegmentorWithTTA(cfg, model) 213 | evaluators = [ 214 | cls.build_evaluator( 215 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 216 | ) 217 | for name in cfg.DATASETS.TEST 218 | ] 219 | res = cls.test(cfg, model, evaluators) 220 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 221 | return res 222 | 223 | 224 | def setup(args): 225 | """ 226 | Create configs and perform basic setups. 227 | """ 228 | cfg = get_cfg() 229 | # for poly lr schedule 230 | add_deeplab_config(cfg) 231 | add_mask_former_config(cfg) 232 | cfg.merge_from_file(args.config_file) 233 | cfg.merge_from_list(args.opts) 234 | cfg.freeze() 235 | default_setup(cfg, args) 236 | # Setup logger for "mask_former" module 237 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="mask_former") 238 | return cfg 239 | 240 | 241 | def main(args): 242 | cfg = setup(args) 243 | 244 | if args.eval_only: 245 | model = Trainer.build_model(cfg) 246 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 247 | cfg.MODEL.WEIGHTS, resume=args.resume 248 | ) 249 | res = Trainer.test(cfg, model) 250 | if cfg.TEST.AUG.ENABLED: 251 | res.update(Trainer.test_with_TTA(cfg, model)) 252 | if comm.is_main_process(): 253 | verify_results(cfg, res) 254 | return res 255 | 256 | trainer = Trainer(cfg) 257 | trainer.resume_or_load(resume=args.resume) 258 | return trainer.train() 259 | 260 | 261 | if __name__ == "__main__": 262 | args = default_argument_parser().parse_args() 263 | print("Command Line Args:", args) 264 | launch( 265 | main, 266 | args.num_gpus, 267 | num_machines=args.num_machines, 268 | machine_rank=args.machine_rank, 269 | dist_url=args.dist_url, 270 | args=(args,), 271 | ) 272 | --------------------------------------------------------------------------------
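A minimal usage sketch for `train_net.py` above. This is a non-authoritative example: the flags come from detectron2's `default_argument_parser` used by the script, and the config path, GPU count, and checkpoint path are placeholders rather than values taken from this repository's documentation.

```
# single-machine training
python train_net.py --num-gpus 8 --config-file CONFIG_FILE

# evaluation only, with test-time augmentation enabled via a config override
python train_net.py --eval-only --config-file CONFIG_FILE MODEL.WEIGHTS /path/to/checkpoint_file TEST.AUG.ENABLED True
```

The trailing key-value pairs are consumed by `cfg.merge_from_list(args.opts)` in `setup()`; this is the same override mechanism shown in the `analyze_model.py` epilog, and `TEST.AUG.ENABLED` is the flag `main()` checks before running `Trainer.test_with_TTA`.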