├── .gitignore
├── README.md
├── configs
│   ├── finetune.yaml
│   └── pretrain.yaml
├── demo_images
│   └── demo.jpg
├── evaluation
│   └── icdar2015
│       ├── e2e
│       │   ├── prepare_results.py
│       │   ├── rrc_evaluation_funcs.py
│       │   ├── script.py
│       │   └── weighted_editdistance.py
│       └── gt.zip
├── maskrcnn_benchmark
│   ├── config
│   │   ├── __init__.py
│   │   ├── defaults.py
│   │   └── paths_catalog.py
│   ├── csrc
│   │   ├── ROIAlign.h
│   │   ├── ROIPool.h
│   │   ├── cpu
│   │   │   ├── ROIAlign_cpu.cpp
│   │   │   ├── nms_cpu.cpp
│   │   │   └── vision.h
│   │   ├── cuda
│   │   │   ├── ROIAlign_cuda.cu
│   │   │   ├── ROIPool_cuda.cu
│   │   │   ├── nms.cu
│   │   │   └── vision.h
│   │   ├── nms.h
│   │   └── vision.cpp
│   ├── data
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── collate_batch.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── coco.py
│   │   │   ├── concat_dataset.py
│   │   │   ├── icdar.py
│   │   │   ├── list_dataset.py
│   │   │   ├── scut.py
│   │   │   ├── synthtext.py
│   │   │   └── total_text.py
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   ├── distributed.py
│   │   │   ├── grouped_batch_sampler.py
│   │   │   └── iteration_based_batch_sampler.py
│   │   └── transforms
│   │       ├── __init__.py
│   │       ├── build.py
│   │       └── transforms.py
│   ├── engine
│   │   ├── inference.py
│   │   ├── text_inference.py
│   │   └── trainer.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── _utils.py
│   │   ├── batch_norm.py
│   │   ├── misc.py
│   │   ├── nms.py
│   │   ├── roi_align.py
│   │   ├── roi_pool.py
│   │   └── smooth_l1_loss.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── backbone.py
│   │   │   ├── fpn.py
│   │   │   └── resnet.py
│   │   ├── balanced_positive_negative_sampler.py
│   │   ├── box_coder.py
│   │   ├── detector
│   │   │   ├── __init__.py
│   │   │   ├── detectors.py
│   │   │   └── generalized_rcnn.py
│   │   ├── matcher.py
│   │   ├── poolers.py
│   │   ├── roi_heads
│   │   │   ├── __init__.py
│   │   │   ├── box_head
│   │   │   │   ├── __init__.py
│   │   │   │   ├── box_head.py
│   │   │   │   ├── inference.py
│   │   │   │   ├── loss.py
│   │   │   │   ├── roi_box_feature_extractors.py
│   │   │   │   └── roi_box_predictors.py
│   │   │   ├── mask_head
│   │   │   │   ├── __init__.py
│   │   │   │   ├── inference.py
│   │   │   │   ├── loss.py
│   │   │   │   ├── mask_head.py
│   │   │   │   ├── roi_mask_feature_extractors.py
│   │   │   │   ├── roi_mask_predictors.py
│   │   │   │   └── roi_seq_predictors.py
│   │   │   └── roi_heads.py
│   │   ├── rpn
│   │   │   ├── __init__.py
│   │   │   ├── anchor_generator.py
│   │   │   ├── inference.py
│   │   │   ├── loss.py
│   │   │   └── rpn.py
│   │   └── utils.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── build.py
│   │   └── lr_scheduler.py
│   ├── structures
│   │   ├── __init__.py
│   │   ├── bounding_box.py
│   │   ├── boxlist_ops.py
│   │   ├── image_list.py
│   │   └── segmentation_mask.py
│   └── utils
│       ├── README.md
│       ├── __init__.py
│       ├── c2_model_loading.py
│       ├── chars.py
│       ├── checkpoint.py
│       ├── collect_env.py
│       ├── comm.py
│       ├── env.py
│       ├── imports.py
│       ├── logging.py
│       ├── metric_logger.py
│       ├── miscellaneous.py
│       ├── model_serialization.py
│       └── model_zoo.py
├── setup.py
├── test.sh
├── tests
│   ├── checkpoint.py
│   └── test_data_samplers.py
├── tools
│   ├── demo.py
│   ├── test_net.py
│   └── train_net.py
└── train.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | # compilation and distribution
2 | __pycache__
3 | _ext
4 | *.pyc
5 | *.so
6 | maskrcnn_benchmark.egg-info/
7 | build/
8 | dist/
9 | 
10 | # pytorch/python/numpy formats
11 | *.pth
12 | *.pkl
13 | *.npy
14 | 
15 | # ipython/jupyter notebooks
16 | *.ipynb
17 | 
18 | # Editor temporaries
19 | *.swn
20 | *.swo
21 | *.swp
22 | *~
23 | 
24 | # Pycharm editor settings
25 | .idea
26 | 
27 | # project dirs
28 | /datasets
29 | /models
30 | /outputs
31 | /vis
32 | *.jpg
33 | *.JPG
34 | *.JPEG
35 | *.zip
36 | *.txt
37 | 
38 | evaluation
39 | 
40 | 
41 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MaskTextSpotter
2 | This is the code for "Mask TextSpotter: An End-to-End Trainable Neural Network for Spotting Text with Arbitrary Shapes" (TPAMI version).
3 | It is an extension of the ECCV version while sharing the same title. For more details, please refer to our [TPAMI paper](https://ieeexplore.ieee.org/document/8812908).
4 | 
5 | This repo is inherited from [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) and follows the same license.
6 | 
7 | ## ToDo List
8 | 
9 | - [x] Release code
10 | - [x] Document for Installation
11 | - [x] Trained models
12 | - [x] Document for testing
13 | - [x] Document for training
14 | - [x] Demo script
15 | - [x] Evaluation
16 | - [ ] Release the standalone recognition model
17 | 
18 | ## Installation
19 | 
20 | ### Requirements:
21 | - Python 3 (Python 3.7 is recommended)
22 | - PyTorch >= 1.0 (1.2 is recommended)
23 | - torchvision from master
24 | - cocoapi
25 | - yacs
26 | - matplotlib
27 | - GCC >= 4.9 (This is very important!)
28 | - OpenCV
29 | - CUDA >= 9.0 (10.0 is recommended)
30 | 
31 | 
32 | ```bash
33 | # first, make sure that your conda is set up properly with the right environment
34 | # for that, check that `which conda`, `which pip` and `which python` point to the
35 | # right paths. From a clean conda env, this is what you need to do
36 | 
37 | conda create --name masktextspotter -y
38 | conda activate masktextspotter
39 | 
40 | # this installs the right pip and dependencies for the fresh python
41 | conda install ipython pip
42 | 
43 | # python dependencies
44 | pip install ninja yacs cython matplotlib tqdm opencv-python shapely scipy tensorboardX
45 | 
46 | # install PyTorch
47 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
48 | 
49 | export INSTALL_DIR=$PWD
50 | 
51 | # install pycocotools
52 | cd $INSTALL_DIR
53 | git clone https://github.com/cocodataset/cocoapi.git
54 | cd cocoapi/PythonAPI
55 | python setup.py build_ext install
56 | 
57 | # install apex (optional)
58 | cd $INSTALL_DIR
59 | git clone https://github.com/NVIDIA/apex.git
60 | cd apex
61 | python setup.py install --cuda_ext --cpp_ext
62 | 
63 | # clone repo
64 | cd $INSTALL_DIR
65 | git clone https://github.com/MhLiao/MaskTextSpotter.git
66 | cd MaskTextSpotter
67 | 
68 | # build
69 | python setup.py build develop
70 | 
71 | 
72 | unset INSTALL_DIR
73 | ```
74 | 
75 | ## Models
76 | Download the trained [model](https://drive.google.com/open?id=1pPRS7qS_K1keXjSye0kksqhvoyD0SARz).
77 | 
78 | ## Demo
79 | You can run the demo script for single-image inference with ```python tools/demo.py```.
80 | 
81 | ## Datasets
82 | Download the ICDAR2013 ([Google Drive](https://drive.google.com/open?id=1sptDnAomQHFVZbjvnWt2uBvyeJ-gEl-A), [BaiduYun](https://pan.baidu.com/s/18W2aFe_qOH8YQUDg4OMZdw)) and ICDAR2015 ([Google Drive](https://drive.google.com/open?id=1HZ4Pbx6TM9cXO3gDyV04A4Gn9fTf2b5X), [BaiduYun](https://pan.baidu.com/s/16GzPPzC5kXpdgOB_76A3cA)) datasets as examples.
83 | 
84 | The SCUT dataset used for training can be downloaded [here](https://drive.google.com/file/d/1BpE2GEFF7Ay7jPqgaeHxMmlXvM-1Es5_/view?usp=sharing).
85 | 
86 | The converted labels of the Total-Text dataset can be downloaded [here](https://1drv.ms/u/s!ArsnjfK83FbXgcpti8Zq9jSzhoQrqw?e=99fukk).
87 | 
88 | The converted labels of SynthText can be downloaded [here](https://1drv.ms/u/s!ArsnjfK83FbXgb5vgOOVPYywgCWuQw?e=UPuNTa).
89 | 
90 | The root of the dataset directory should be ```MaskTextSpotter/datasets/```.
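The loaders in ```maskrcnn_benchmark/data/datasets/``` expect one sub-directory per dataset under this root. As a rough sketch (the sub-folder names below are assumptions inferred from the dataset names in the configs and the testing example, not an authoritative spec):

```
MaskTextSpotter/datasets/
├── synthtext/
├── icdar2013/
├── icdar2015/
│   ├── train_images/
│   └── test_images/
├── scut-eng-char/
└── total_text/
```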
91 | 
92 | ## Testing
93 | ### Prepare the dataset
94 | An example of the path of the test images: ```MaskTextSpotter/datasets/icdar2015/test_images```
95 | 
96 | ### Check the config file (configs/finetune.yaml) for some parameters.
97 | test dataset: ```DATASETS.TEST```;
98 | 
99 | input size: ```INPUT.MIN_SIZE_TEST```;
100 | 
101 | model path: ```MODEL.WEIGHT```;
102 | 
103 | output directory: ```OUTPUT_DIR```
104 | 
105 | ### Run ```sh test.sh```
106 | 
107 | 
108 | ## Training
109 | Place all the training sets in ```MaskTextSpotter/datasets/``` and check ```DATASETS.TRAIN``` in the config file.
110 | ### Pretrain
111 | Train with SynthText:
112 | 
113 | ```python3 -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/pretrain.yaml ```
114 | ### Finetune
115 | Train with a mixture of SynthText, ICDAR2013, ICDAR2015, SCUT-eng-char, and Total-Text.
116 | 
117 | Check the initial weights in the config file.
118 | 
119 | ```python3 -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/finetune.yaml ```
120 | 
121 | ## Evaluation
122 | ### Evaluation for the ICDAR 2015 dataset
123 | Download the [lexicons](https://drive.google.com/open?id=1u3NlpIZkE4dYmrcWo0qzU_q7ra5jvDhD) and place them as ```evaluation/lexicons/ic15/```.
124 | 
125 | ```
126 | cd evaluation/icdar2015/e2e/
127 | # edit "result_dir" in script.py
128 | python script.py
129 | ```
130 | 
131 | ### Evaluation for the Total-Text dataset (ToDo)
132 | 
133 | 
134 | 
135 | ## Citing the related works
136 | 
137 | Please cite the related works in your publications if they help your research:
138 | 
139 |     @article{liao2019mask,
140 |       author={M. {Liao} and P. {Lyu} and M. {He} and C. {Yao} and W. {Wu} and X. {Bai}},
141 |       journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
142 |       title={Mask TextSpotter: An End-to-End Trainable Neural Network for Spotting Text with Arbitrary Shapes},
143 |       year={2021},
144 |       volume={43},
145 |       number={2},
146 |       pages={532-548},
147 |       doi={10.1109/TPAMI.2019.2937086}
148 |     }
149 | 
150 |     @inproceedings{lyu2018mask,
151 |       title={Mask textspotter: An end-to-end trainable neural network for spotting text with arbitrary shapes},
152 |       author={Lyu, Pengyuan and Liao, Minghui and Yao, Cong and Wu, Wenhao and Bai, Xiang},
153 |       booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
154 |       pages={67--83},
155 |       year={2018}
156 |     }
157 | 
158 | 
159 | 
--------------------------------------------------------------------------------
/configs/finetune.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   META_ARCHITECTURE: "GeneralizedRCNN"
3 |   # WEIGHT: "./outputs/pretrain/model_pretrain.pth"
4 |   WEIGHT: "./outputs/finetune/model_finetune.pth"
5 |   BACKBONE:
6 |     CONV_BODY: "R-50-FPN"
7 |     OUT_CHANNELS: 256
8 |   RPN:
9 |     USE_FPN: True
10 |     ANCHOR_STRIDE: (4, 8, 16, 32, 64)
11 |     PRE_NMS_TOP_N_TRAIN: 2000
12 |     PRE_NMS_TOP_N_TEST: 1000
13 |     POST_NMS_TOP_N_TEST: 1000
14 |     FPN_POST_NMS_TOP_N_TEST: 1000
15 |   ROI_HEADS:
16 |     USE_FPN: True
17 |     BATCH_SIZE_PER_IMAGE: 512
18 |     SCORE_THRESH: 0.5
19 |   ROI_BOX_HEAD:
20 |     POOLER_RESOLUTION: 7
21 |     POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
22 |     POOLER_SAMPLING_RATIO: 2
23 |     FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
24 |     PREDICTOR: "FPNPredictor"
25 |     NUM_CLASSES: 2
26 |   ROI_MASK_HEAD:
27 |     POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
28 |     FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
29 |     PREDICTOR: "SeqCharMaskRCNNC4Predictor"
30 |     POOLER_RESOLUTION: 14
31 |     POOLER_RESOLUTION_H: 16
32 |     POOLER_RESOLUTION_W: 64
33 |     POOLER_SAMPLING_RATIO: 2
34 |     RESOLUTION: 28
35 |     RESOLUTION_H: 32
36 |     RESOLUTION_W: 128
37 |     SHARE_BOX_FEATURE_EXTRACTOR: False
38 |     CHAR_NUM_CLASSES: 37
39 |     USE_WEIGHTED_CHAR_MASK: True
40 |     MASK_BATCH_SIZE_PER_IM: 64
41 |   MASK_ON: True
42 |   CHAR_MASK_ON: True
43 |   SEQUENCE:
44 |     SEQ_ON: True
45 |     NUM_CHAR: 38
46 |     BOS_TOKEN: 0
47 |     MAX_LENGTH: 32
48 |     TEACHER_FORCE_RATIO: 1.0
49 |     TWO_CONV: True
50 |     MEAN_SCORE: True
51 | DATASETS:
52 |   TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train")
53 |   RATIOS: [0.25,0.25,0.25,0.125,0.125]
54 |   TEST: ("icdar_2013_test",)
55 |   # TEST: ("total_text_test",)
56 |   AUG: True
57 | DATALOADER:
58 |   SIZE_DIVISIBILITY: 32
59 |   NUM_WORKERS: 4
60 |   ASPECT_RATIO_GROUPING: False
61 | SOLVER:
62 |   BASE_LR: 0.001 # 0.02
63 |   WARMUP_FACTOR: 0.1
64 |   WEIGHT_DECAY: 0.0001
65 |   STEPS: (100000, 160000)
66 |   MAX_ITER: 300000
67 |   IMS_PER_BATCH: 8
68 |   RESUME: True
69 | OUTPUT_DIR: "./outputs/finetune"
70 | TEST:
71 |   VIS: True
72 |   CHAR_THRESH: 192
73 |   IMS_PER_BATCH: 1
74 | INPUT:
75 |   MIN_SIZE_TRAIN: (800, 1000, 1200, 1400)
76 |   MAX_SIZE_TRAIN: 2000
77 |   MIN_SIZE_TEST: 1000
78 |   MAX_SIZE_TEST: 3333
--------------------------------------------------------------------------------
/configs/pretrain.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   META_ARCHITECTURE: "GeneralizedRCNN"
3 |   WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
4 |   # WEIGHT: "./outputs/finetune/model_finetune.pth"
5 |   BACKBONE:
6 |     CONV_BODY: "R-50-FPN"
7 |     OUT_CHANNELS: 256
8 |   RPN:
9 |     USE_FPN: True
10 |     ANCHOR_STRIDE: (4, 8, 16, 32, 64)
11 |     PRE_NMS_TOP_N_TRAIN: 2000
12 |     PRE_NMS_TOP_N_TEST: 1000
13 |     POST_NMS_TOP_N_TEST: 1000
14 |     FPN_POST_NMS_TOP_N_TEST: 1000
15 |   ROI_HEADS:
16 |     USE_FPN: True
17 |     BATCH_SIZE_PER_IMAGE: 512
18 |   ROI_BOX_HEAD:
19 |     POOLER_RESOLUTION: 7
20 |     POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
21 |     POOLER_SAMPLING_RATIO: 2
22 |     FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
23 |     PREDICTOR: "FPNPredictor"
24 |     NUM_CLASSES: 2
25 |   ROI_MASK_HEAD:
26 |     POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
27 |     FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
28 |     PREDICTOR: "SeqCharMaskRCNNC4Predictor"
29 |     POOLER_RESOLUTION_H: 16
30 |     POOLER_RESOLUTION_W: 64
31 |     POOLER_SAMPLING_RATIO: 2
32 |     RESOLUTION: 28
33 |     RESOLUTION_H: 32
34 |     RESOLUTION_W: 128
35 |     SHARE_BOX_FEATURE_EXTRACTOR: False
36 |     CHAR_NUM_CLASSES: 37
37 |     USE_WEIGHTED_CHAR_MASK: True
38 |     MASK_BATCH_SIZE_PER_IM: 64
39 |   MASK_ON: True
40 |   CHAR_MASK_ON: True
41 |   SEQUENCE:
42 |     SEQ_ON: True
43 |     NUM_CHAR: 38
44 |     BOS_TOKEN: 0
45 |     MAX_LENGTH: 32
46 |     TEACHER_FORCE_RATIO: 1.0
47 |     TWO_CONV: True
48 | DATASETS:
49 |   TRAIN: ("synthtext_train",)
50 |   TEST: ("icdar_2013_test",)
51 | DATALOADER:
52 |   SIZE_DIVISIBILITY: 32
53 |   NUM_WORKERS: 4
54 |   ASPECT_RATIO_GROUPING: False
55 | SOLVER:
56 |   BASE_LR: 0.01 # 0.02
57 |   WARMUP_FACTOR: 0.1
58 |   WEIGHT_DECAY: 0.0001
59 |   STEPS: (100000, 160000)
60 |   MAX_ITER: 300000
61 |   IMS_PER_BATCH: 8
62 | OUTPUT_DIR: "./outputs/pretrain"
63 | TEST:
64 |   VIS: False
65 |   CHAR_THRESH: 192
66 |   IMS_PER_BATCH: 1
67 | INPUT:
68 |   MIN_SIZE_TRAIN: (600, 800)
69 |   MAX_SIZE_TRAIN: 2333
70 |   MIN_SIZE_TEST: 800
71 |   MAX_SIZE_TEST: 1333
72 | 
--------------------------------------------------------------------------------
/demo_images/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/demo_images/demo.jpg
--------------------------------------------------------------------------------
/evaluation/icdar2015/e2e/weighted_editdistance.py:
--------------------------------------------------------------------------------
1 | def weighted_edit_distance(word1, word2, scores):
2 |     m = len(word1)
3 |     n = len(word2)
4 |     dp = [[0 for __ in range(m + 1)] for __ in range(n + 1)]
5 |     for j in range(m + 1):
6 |         dp[0][j] = j
7 |     for i in range(n + 1):
8 |         dp[i][0] = i
9 |     for i in range(1, n + 1):  ## word2
10 |         for j in range(1, m + 1):  ## word1
11 |             delete_cost = ed_delete_cost(j-1, i-1, word1, word2, scores)  ## delete word1[j-1]
12 |             insert_cost = ed_insert_cost(j-1, i-1, word1, word2, scores)  ## insert word2[i-1]
13 |             if word1[j - 1] != word2[i - 1]:
14 |                 replace_cost = ed_replace_cost(j-1, i-1, word1, word2, scores)  ## replace word1[j-1] with word2[i-1]
15 |             else:
16 |                 replace_cost = 0
17 |             dp[i][j] = min(dp[i-1][j] + insert_cost, dp[i][j-1] + delete_cost, dp[i-1][j-1] + replace_cost)
18 | 
19 |     return dp[n][m]
20 | 
21 | def ed_delete_cost(j, i, word1, word2, scores):
22 |     ## cost of deleting word1[j]
23 |     c = char2num(word1[j])
24 |     return scores[c][j]
25 | 
26 | 
27 | def ed_insert_cost(i, j, word1, word2, scores):
28 |     ## cost of inserting word2[j], taken from the scores around word1[i]
29 |     if i < len(word1) - 1:
30 |         c1 = char2num(word1[i])
31 |         c2 = char2num(word1[i + 1])
32 |         return (scores[c1][i] + scores[c2][i+1])/2
33 |     else:
34 |         c1 = char2num(word1[i])
35 |         return scores[c1][i]
36 | 
37 | 
38 | def ed_replace_cost(i, j, word1, word2, scores):
39 |     ## cost of replacing word1[i] with word2[j]
40 |     c1 = char2num(word1[i])
41 |     c2 = char2num(word2[j])
42 |     # if word1 == "eeatpisaababarait".upper():
43 |     #     print(scores[c2][i]/scores[c1][i])
44 | 
45 |     return max(1 - scores[c2][i]/scores[c1][i]*5, 0)
46 | 
47 | def char2num(char):
48 |     if char in '0123456789':
49 |         num = ord(char) - ord('0') + 1
50 |     elif char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
51 |         num = ord(char.lower()) - ord('a') + 11
52 |     else:
53 |         print('error symbol', char)
54 |         exit()
55 |     return num - 1
--------------------------------------------------------------------------------
/evaluation/icdar2015/gt.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/evaluation/icdar2015/gt.zip
--------------------------------------------------------------------------------
/maskrcnn_benchmark/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from .defaults import _C as cfg
3 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/ROIAlign.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
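// Dispatch layer for ROIAlign: the wrappers below check whether the input
// tensor lives on the GPU and route to the CUDA kernels when the extension
// was built with WITH_CUDA, falling back to the CPU implementation otherwise
// (forward only -- the CPU backward path is not implemented).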
2 | #pragma once
3 | 
4 | #include "cpu/vision.h"
5 | 
6 | #ifdef WITH_CUDA
7 | #include "cuda/vision.h"
8 | #endif
9 | 
10 | // Interface for Python
11 | at::Tensor ROIAlign_forward(const at::Tensor& input,
12 |                             const at::Tensor& rois,
13 |                             const float spatial_scale,
14 |                             const int pooled_height,
15 |                             const int pooled_width,
16 |                             const int sampling_ratio) {
17 |   if (input.type().is_cuda()) {
18 | #ifdef WITH_CUDA
19 |     return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
20 | #else
21 |     AT_ERROR("Not compiled with GPU support");
22 | #endif
23 |   }
24 |   return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
25 | }
26 | 
27 | at::Tensor ROIAlign_backward(const at::Tensor& grad,
28 |                              const at::Tensor& rois,
29 |                              const float spatial_scale,
30 |                              const int pooled_height,
31 |                              const int pooled_width,
32 |                              const int batch_size,
33 |                              const int channels,
34 |                              const int height,
35 |                              const int width,
36 |                              const int sampling_ratio) {
37 |   if (grad.type().is_cuda()) {
38 | #ifdef WITH_CUDA
39 |     return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
40 | #else
41 |     AT_ERROR("Not compiled with GPU support");
42 | #endif
43 |   }
44 |   AT_ERROR("Not implemented on the CPU");
45 | }
46 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/ROIPool.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #pragma once
3 | 
4 | #include "cpu/vision.h"
5 | 
6 | #ifdef WITH_CUDA
7 | #include "cuda/vision.h"
8 | #endif
9 | 
10 | 
11 | std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
12 |                                                    const at::Tensor& rois,
13 |                                                    const float spatial_scale,
14 |                                                    const int pooled_height,
15 |                                                    const int pooled_width) {
16 |   if (input.type().is_cuda()) {
17 | #ifdef WITH_CUDA
18 |     return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
19 | #else
20 |     AT_ERROR("Not compiled with GPU support");
21 | #endif
22 |   }
23 |   AT_ERROR("Not implemented on the CPU");
24 | }
25 | 
26 | at::Tensor ROIPool_backward(const at::Tensor& grad,
27 |                             const at::Tensor& input,
28 |                             const at::Tensor& rois,
29 |                             const at::Tensor& argmax,
30 |                             const float spatial_scale,
31 |                             const int pooled_height,
32 |                             const int pooled_width,
33 |                             const int batch_size,
34 |                             const int channels,
35 |                             const int height,
36 |                             const int width) {
37 |   if (grad.type().is_cuda()) {
38 | #ifdef WITH_CUDA
39 |     return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
40 | #else
41 |     AT_ERROR("Not compiled with GPU support");
42 | #endif
43 |   }
44 |   AT_ERROR("Not implemented on the CPU");
45 | }
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
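// Classic greedy non-maximum suppression on the CPU: boxes are visited in
// descending score order, and every remaining box whose IoU with the current
// box reaches `threshold` is marked as suppressed. The kernel returns the
// indices of the boxes that survive.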
2 | #include "cpu/vision.h"
3 | 
4 | 
5 | template <typename scalar_t>
6 | at::Tensor nms_cpu_kernel(const at::Tensor& dets,
7 |                           const at::Tensor& scores,
8 |                           const float threshold) {
9 |   AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
10 |   AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
11 |   AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
12 | 
13 |   if (dets.numel() == 0) {
14 |     return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
15 |   }
16 | 
17 |   auto x1_t = dets.select(1, 0).contiguous();
18 |   auto y1_t = dets.select(1, 1).contiguous();
19 |   auto x2_t = dets.select(1, 2).contiguous();
20 |   auto y2_t = dets.select(1, 3).contiguous();
21 | 
22 |   at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
23 | 
24 |   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
25 | 
26 |   auto ndets = dets.size(0);
27 |   at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
28 | 
29 |   auto suppressed = suppressed_t.data<uint8_t>();
30 |   auto order = order_t.data<int64_t>();
31 |   auto x1 = x1_t.data<scalar_t>();
32 |   auto y1 = y1_t.data<scalar_t>();
33 |   auto x2 = x2_t.data<scalar_t>();
34 |   auto y2 = y2_t.data<scalar_t>();
35 |   auto areas = areas_t.data<scalar_t>();
36 | 
37 |   for (int64_t _i = 0; _i < ndets; _i++) {
38 |     auto i = order[_i];
39 |     if (suppressed[i] == 1)
40 |       continue;
41 |     auto ix1 = x1[i];
42 |     auto iy1 = y1[i];
43 |     auto ix2 = x2[i];
44 |     auto iy2 = y2[i];
45 |     auto iarea = areas[i];
46 | 
47 |     for (int64_t _j = _i + 1; _j < ndets; _j++) {
48 |       auto j = order[_j];
49 |       if (suppressed[j] == 1)
50 |         continue;
51 |       auto xx1 = std::max(ix1, x1[j]);
52 |       auto yy1 = std::max(iy1, y1[j]);
53 |       auto xx2 = std::min(ix2, x2[j]);
54 |       auto yy2 = std::min(iy2, y2[j]);
55 | 
56 |       auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
57 |       auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
58 |       auto inter = w * h;
59 |       auto ovr = inter / (iarea + areas[j] - inter);
60 |       if (ovr >= threshold)
61 |         suppressed[j] = 1;
62 |     }
63 |   }
64 |   return at::nonzero(suppressed_t == 0).squeeze(1);
65 | }
66 | 
67 | at::Tensor nms_cpu(const at::Tensor& dets,
68 |                    const at::Tensor& scores,
69 |                    const float threshold) {
70 |   at::Tensor result;
71 |   AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
72 |     result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
73 |   });
74 |   return result;
75 | }
76 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/cpu/vision.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #pragma once
3 | #include <torch/extension.h>
4 | 
5 | 
6 | at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
7 |                                 const at::Tensor& rois,
8 |                                 const float spatial_scale,
9 |                                 const int pooled_height,
10 |                                 const int pooled_width,
11 |                                 const int sampling_ratio);
12 | 
13 | 
14 | at::Tensor nms_cpu(const at::Tensor& dets,
15 |                    const at::Tensor& scores,
16 |                    const float threshold);
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/cuda/nms.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
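// Parallel NMS on the GPU: each thread block compares one group of up to
// `threadsPerBlock` (64) boxes against another group held in shared memory,
// writing a 64-bit suppression bitmask per (box, column-block) pair into
// `dev_mask`. The host then walks the boxes in descending score order and
// accumulates the bitmasks in `remv` to decide which boxes to keep -- only
// the pairwise IoU tests run on the GPU.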
2 | #include <ATen/ATen.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | 
5 | #include <THC/THC.h>
6 | #include <THC/THCDeviceUtils.cuh>
7 | 
8 | #include <vector>
9 | #include <iostream>
10 | 
11 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
12 | 
13 | __device__ inline float devIoU(float const * const a, float const * const b) {
14 |   float left = max(a[0], b[0]), right = min(a[2], b[2]);
15 |   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
16 |   float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
17 |   float interS = width * height;
18 |   float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
19 |   float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
20 |   return interS / (Sa + Sb - interS);
21 | }
22 | 
23 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
24 |                            const float *dev_boxes, unsigned long long *dev_mask) {
25 |   const int row_start = blockIdx.y;
26 |   const int col_start = blockIdx.x;
27 | 
28 |   // if (row_start > col_start) return;
29 | 
30 |   const int row_size =
31 |         min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
32 |   const int col_size =
33 |         min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
34 | 
35 |   __shared__ float block_boxes[threadsPerBlock * 5];
36 |   if (threadIdx.x < col_size) {
37 |     block_boxes[threadIdx.x * 5 + 0] =
38 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
39 |     block_boxes[threadIdx.x * 5 + 1] =
40 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
41 |     block_boxes[threadIdx.x * 5 + 2] =
42 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
43 |     block_boxes[threadIdx.x * 5 + 3] =
44 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
45 |     block_boxes[threadIdx.x * 5 + 4] =
46 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
47 |   }
48 |   __syncthreads();
49 | 
50 |   if (threadIdx.x < row_size) {
51 |     const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
52 |     const float *cur_box = dev_boxes + cur_box_idx * 5;
53 |     int i = 0;
54 |     unsigned long long t = 0;
55 |     int start = 0;
56 |     if (row_start == col_start) {
57 |       start = threadIdx.x + 1;
58 |     }
59 |     for (i = start; i < col_size; i++) {
60 |       if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
61 |         t |= 1ULL << i;
62 |       }
63 |     }
64 |     const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
65 |     dev_mask[cur_box_idx * col_blocks + col_start] = t;
66 |   }
67 | }
68 | 
69 | // boxes is a N x 5 tensor
70 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
71 |   using scalar_t = float;
72 |   AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
73 |   auto scores = boxes.select(1, 4);
74 |   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
75 |   auto boxes_sorted = boxes.index_select(0, order_t);
76 | 
77 |   int boxes_num = boxes.size(0);
78 | 
79 |   const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
80 | 
81 |   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
82 | 
83 |   THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
84 | 
85 |   unsigned long long* mask_dev = NULL;
86 |   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
87 |   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
88 | 
89 |   mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
90 | 
91 |   dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
92 |               THCCeilDiv(boxes_num, threadsPerBlock));
93 |   dim3 threads(threadsPerBlock);
94 |   nms_kernel<<<blocks, threads>>>(boxes_num,
95 |                                   nms_overlap_thresh,
96 |                                   boxes_dev,
97 |                                   mask_dev);
98 | 
99 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
100 |   THCudaCheck(cudaMemcpy(&mask_host[0],
101 |                          mask_dev,
102 |                          sizeof(unsigned long long) * boxes_num * col_blocks,
103 |                          cudaMemcpyDeviceToHost));
104 | 
105 |   std::vector<unsigned long long> remv(col_blocks);
106 |   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
107 | 
108 |   at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
109 |   int64_t* keep_out = keep.data<int64_t>();
110 | 
111 |   int num_to_keep = 0;
112 |   for (int i = 0; i < boxes_num; i++) {
113 |     int nblock = i / threadsPerBlock;
114 |     int inblock = i % threadsPerBlock;
115 | 
116 |     if (!(remv[nblock] & (1ULL << inblock))) {
117 |       keep_out[num_to_keep++] = i;
118 |       unsigned long long *p = &mask_host[0] + i * col_blocks;
119 |       for (int j = nblock; j < col_blocks; j++) {
120 |         remv[j] |= p[j];
121 |       }
122 |     }
123 |   }
124 | 
125 |   THCudaFree(state, mask_dev);
126 |   // TODO improve this part
127 |   return std::get<0>(order_t.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}).sort(0, false));
128 | }
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/cuda/vision.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #pragma once
3 | #include <torch/extension.h>
4 | 
5 | 
6 | at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
7 |                                  const at::Tensor& rois,
8 |                                  const float spatial_scale,
9 |                                  const int pooled_height,
10 |                                  const int pooled_width,
11 |                                  const int sampling_ratio);
12 | 
13 | at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
14 |                                   const at::Tensor& rois,
15 |                                   const float spatial_scale,
16 |                                   const int pooled_height,
17 |                                   const int pooled_width,
18 |                                   const int batch_size,
19 |                                   const int channels,
20 |                                   const int height,
21 |                                   const int width,
22 |                                   const int sampling_ratio);
23 | 
24 | 
25 | std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
26 |                                                         const at::Tensor& rois,
27 |                                                         const float spatial_scale,
28 |                                                         const int pooled_height,
29 |                                                         const int pooled_width);
30 | 
31 | at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
32 |                                  const at::Tensor& input,
33 |                                  const at::Tensor& rois,
34 |                                  const at::Tensor& argmax,
35 |                                  const float spatial_scale,
36 |                                  const int pooled_height,
37 |                                  const int pooled_width,
38 |                                  const int batch_size,
39 |                                  const int channels,
40 |                                  const int height,
41 |                                  const int width);
42 | 
43 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
44 | 
45 | 
46 | at::Tensor compute_flow_cuda(const at::Tensor& boxes,
47 |                              const int height,
48 |                              const int width);
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/nms.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
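// Dispatch entry point for NMS. For the CUDA path, the scores are appended
// as a fifth column so the kernel receives an N x 5 tensor of
// [x1, y1, x2, y2, score]; the CPU path takes boxes and scores separately.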
2 | #pragma once
3 | #include "cpu/vision.h"
4 | 
5 | #ifdef WITH_CUDA
6 | #include "cuda/vision.h"
7 | #endif
8 | 
9 | 
10 | at::Tensor nms(const at::Tensor& dets,
11 |                const at::Tensor& scores,
12 |                const float threshold) {
13 | 
14 |   if (dets.type().is_cuda()) {
15 | #ifdef WITH_CUDA
16 |     // TODO raise error if not compiled with CUDA
17 |     if (dets.numel() == 0)
18 |       return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
19 |     auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
20 |     return nms_cuda(b, threshold);
21 | #else
22 |     AT_ERROR("Not compiled with GPU support");
23 | #endif
24 |   }
25 | 
26 |   at::Tensor result = nms_cpu(dets, scores, threshold);
27 |   return result;
28 | }
--------------------------------------------------------------------------------
/maskrcnn_benchmark/csrc/vision.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #include "nms.h"
3 | #include "ROIAlign.h"
4 | #include "ROIPool.h"
5 | 
6 | 
7 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
8 |   m.def("nms", &nms, "non-maximum suppression");
9 |   m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
10 |   m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
11 |   m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
12 |   m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
13 | }
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from .build import make_data_loader
3 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import bisect
3 | import logging
4 | 
5 | import torch.utils.data
6 | from maskrcnn_benchmark.utils.comm import get_world_size
7 | from maskrcnn_benchmark.utils.imports import import_file
8 | 
9 | from . import datasets as D
10 | from . import samplers
11 | 
12 | from .collate_batch import BatchCollator
13 | from .transforms import build_transforms
14 | 
15 | 
16 | def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True):
17 |     """
18 |     Arguments:
19 |         dataset_list (list[str]): Contains the names of the datasets, e.g.,
20 |             coco_2014_train, coco_2014_val, etc
21 |         transforms (callable): transforms to apply to each (image, target) sample
22 |         dataset_catalog (DatasetCatalog): contains the information on how to
23 |             construct a dataset.
24 |         is_train (bool): whether to set up the dataset for training or testing
25 |     """
26 |     if not isinstance(dataset_list, (list, tuple)):
27 |         raise RuntimeError(
28 |             "dataset_list should be a list of strings, got {}".format(dataset_list))
29 |     datasets = []
30 |     for dataset_name in dataset_list:
31 |         data = dataset_catalog.get(dataset_name)
32 |         factory = getattr(D, data["factory"])
33 |         args = data["args"]
34 |         # for COCODataset, we want to remove images without annotations
35 |         # during training
36 |         if data["factory"] == "COCODataset":
37 |             args["remove_images_without_annotations"] = is_train
38 |         args["transforms"] = transforms
39 |         # make dataset from factory
40 |         dataset = factory(**args)
41 |         datasets.append(dataset)
42 | 
43 |     # for testing, return a list of datasets
44 |     if not is_train:
45 |         return datasets
46 | 
47 |     # for training, concatenate all datasets into a single one
48 |     dataset = datasets[0]
49 |     if len(datasets) > 1:
50 |         dataset = D.MixDataset(datasets, cfg.DATASETS.RATIOS)
51 |         # dataset = D.ConcatDataset(datasets)
52 | 
53 |     return [dataset]
54 | 
55 | 
56 | def make_data_sampler(dataset, shuffle, distributed):
57 |     if distributed:
58 |         return samplers.DistributedSampler(dataset, shuffle=shuffle)
59 |     if shuffle:
60 |         sampler = torch.utils.data.sampler.RandomSampler(dataset)
61 |     else:
62 |         sampler = torch.utils.data.sampler.SequentialSampler(dataset)
63 |     return sampler
64 | 
65 | 
66 | def _quantize(x, bins):
67 |     bins = sorted(bins.copy())
68 |     quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
69 |     return quantized
70 | 
71 | 
72 | def _compute_aspect_ratios(dataset):
73 |     aspect_ratios = []
74 |     for i in range(len(dataset)):
75 |         img_info = dataset.get_img_info(i)
76 |         aspect_ratio = float(img_info["height"]) / float(img_info["width"])
77 |         aspect_ratios.append(aspect_ratio)
78 |     return aspect_ratios
79 | 
80 | 
81 | def make_batch_data_sampler(
82 |     dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0
83 | ):
84 |     if aspect_grouping:
85 |         if not isinstance(aspect_grouping, (list, tuple)):
86 |             aspect_grouping = [aspect_grouping]
87 |         aspect_ratios = _compute_aspect_ratios(dataset)
88 |         group_ids = _quantize(aspect_ratios, aspect_grouping)
89 |         batch_sampler = samplers.GroupedBatchSampler(
90 |             sampler, group_ids, images_per_batch, drop_uneven=False
91 |         )
92 |     else:
93 |         batch_sampler = torch.utils.data.sampler.BatchSampler(
94 |             sampler, images_per_batch, drop_last=False
95 |         )
96 |     if num_iters is not None:
97 |         batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
98 |     return batch_sampler
99 | 
100 | 
101 | def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
102 |     num_gpus = get_world_size()
103 |     if is_train:
104 |         images_per_batch = cfg.SOLVER.IMS_PER_BATCH
105 |         assert images_per_batch % num_gpus == 0, (
106 |             "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number "
107 |             "of GPUs ({}) used.".format(images_per_batch, num_gpus)
108 |         )
109 |         images_per_gpu = images_per_batch // num_gpus
110 |         shuffle = True
111 |         num_iters = cfg.SOLVER.MAX_ITER
112 |     else:
113 |         images_per_batch = cfg.TEST.IMS_PER_BATCH
114 |         assert images_per_batch % num_gpus == 0, (
115 |             "TEST.IMS_PER_BATCH ({}) must be divisible by the number "
116 |             "of GPUs ({}) used.".format(images_per_batch, num_gpus)
117 |         )
118 |         images_per_gpu = images_per_batch // num_gpus
119 |         shuffle = is_distributed
120 |         num_iters = None
121 |         start_iter = 0
122 | 
123 |     if images_per_gpu > 1:
124 |         logger = logging.getLogger(__name__)
125 |         logger.warning(
126 |             "When using more than one image per GPU you may encounter "
127 |             "an out-of-memory (OOM) error if your GPU does not have "
128 |             "sufficient memory. If this happens, you can reduce "
129 |             "SOLVER.IMS_PER_BATCH (for training) or "
130 |             "TEST.IMS_PER_BATCH (for inference). For training, you must "
131 |             "also adjust the learning rate and schedule length according "
132 |             "to the linear scaling rule. See for example: "
133 |             "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14"
134 |         )
135 | 
136 |     # group images which have similar aspect ratio. In this case, we only
137 |     # group in two cases: those with width / height > 1, and the other way around,
138 |     # but the code supports more general grouping strategy
139 |     aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []
140 | 
141 |     paths_catalog = import_file(
142 |         "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
143 |     )
144 |     DatasetCatalog = paths_catalog.DatasetCatalog
145 |     dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
146 | 
147 |     transforms = build_transforms(cfg, is_train)
148 |     datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, is_train)
149 | 
150 |     data_loaders = []
151 |     for dataset in datasets:
152 |         sampler = make_data_sampler(dataset, shuffle, is_distributed)
153 |         batch_sampler = make_batch_data_sampler(
154 |             dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter
155 |         )
156 |         collator = BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
157 |         num_workers = cfg.DATALOADER.NUM_WORKERS
158 |         data_loader = torch.utils.data.DataLoader(
159 |             dataset,
160 |             num_workers=num_workers,
161 |             batch_sampler=batch_sampler,
162 |             collate_fn=collator,
163 |         )
164 |         data_loaders.append(data_loader)
165 |     if is_train:
166 |         # during training, a single (possibly concatenated) data_loader is returned
167 |         assert len(data_loaders) == 1
168 |         return data_loaders[0]
169 |     return data_loaders
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from maskrcnn_benchmark.structures.image_list import to_image_list
3 | 
4 | 
5 | class BatchCollator(object):
6 |     """
7 |     From a list of samples from the dataset,
8 |     returns the batched images and targets.
9 |     This should be passed to the DataLoader
10 |     """
11 | 
12 |     def __init__(self, size_divisible=0):
13 |         self.size_divisible = size_divisible
14 | 
15 |     def __call__(self, batch):
16 |         transposed_batch = list(zip(*batch))
17 |         images = to_image_list(transposed_batch[0], self.size_divisible)
18 |         targets = transposed_batch[1]
19 |         img_ids = transposed_batch[2]
20 |         return images, targets, img_ids
21 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from .coco import COCODataset
3 | from .concat_dataset import ConcatDataset, MixDataset
4 | from .icdar import IcdarDataset
5 | from .synthtext import SynthtextDataset
6 | from .scut import ScutDataset
7 | from .total_text import TotaltextDataset
8 | __all__ = ["COCODataset", "ConcatDataset", "IcdarDataset", "SynthtextDataset", "MixDataset", "ScutDataset", "TotaltextDataset"]
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/datasets/coco.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import torch
3 | import torchvision
4 | 
5 | from maskrcnn_benchmark.structures.bounding_box import BoxList
6 | from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
7 | 
8 | 
9 | class COCODataset(torchvision.datasets.coco.CocoDetection):
10 |     def __init__(
11 |         self, ann_file, root, remove_images_without_annotations, transforms=None
12 |     ):
13 |         super(COCODataset, self).__init__(root, ann_file)
14 | 
15 |         # sort indices for reproducible results
16 |         self.ids = sorted(self.ids)
17 | 
18 |         # filter images without detection annotations
19 |         if remove_images_without_annotations:
20 |             self.ids = [
21 |                 img_id
22 |                 for img_id in self.ids
23 |                 if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
24 |             ]
25 | 
26 |         self.json_category_id_to_contiguous_id = {
27 |             v: i + 1 for i, v in enumerate(self.coco.getCatIds())
28 |         }
29 |         self.contiguous_category_id_to_json_id = {
30 |             v: k for k, v in self.json_category_id_to_contiguous_id.items()
31 |         }
32 |         self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
33 |         self.transforms = transforms
34 | 
35 |     def __getitem__(self, idx):
36 |         img, anno = super(COCODataset, self).__getitem__(idx)
37 | 
38 |         # filter crowd annotations
39 |         # TODO might be better to add an extra field
40 |         anno = [obj for obj in anno if obj["iscrowd"] == 0]
41 | 
42 |         boxes = [obj["bbox"] for obj in anno]
43 |         boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
44 |         target = BoxList(boxes, img.size, mode="xywh", use_char_ann=False).convert("xyxy")
45 | 
46 |         classes = [obj["category_id"] for obj in anno]
47 |         classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
48 |         classes = torch.tensor(classes)
49 |         target.add_field("labels", classes)
50 | 
51 |         masks = [obj["segmentation"] for obj in anno]
52 |         masks = SegmentationMask(masks, img.size)
53 |         target.add_field("masks", masks)
54 | 
55 |         target = target.clip_to_image(remove_empty=True)
56 | 
57 |         if self.transforms is not None:
58 |             img, target = self.transforms(img, target)
59 | 
60 |         return img, target, idx
61 | 
62 |     def get_img_info(self, index):
63 |         img_id = self.id_to_img_map[index]
64 |         img_data = self.coco.imgs[img_id]
65 |         return img_data
66 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/datasets/concat_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
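# ConcatDataset extends torch's ConcatDataset with per-sample image-info
# lookup. MixDataset draws each sample from one of several datasets with a
# fixed probability: e.g. MixDataset([ds_a, ds_b], ratios=[0.75, 0.25])
# picks ds_a about 75% of the time and then samples an index from it
# uniformly at random, so the ratios control the dataset mix per batch
# rather than epoch boundaries; its __len__ is the summed length of the
# underlying datasets, but __getitem__ ignores the requested index.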
2 | import bisect
3 | import numpy as np
4 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
5 | 
6 | 
7 | class ConcatDataset(_ConcatDataset):
8 |     """
9 |     Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra
10 |     method for querying the sizes of the image
11 |     """
12 | 
13 |     def get_idxs(self, idx):
14 |         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
15 |         if dataset_idx == 0:
16 |             sample_idx = idx
17 |         else:
18 |             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
19 |         return dataset_idx, sample_idx
20 | 
21 |     def get_img_info(self, idx):
22 |         dataset_idx, sample_idx = self.get_idxs(idx)
23 |         return self.datasets[dataset_idx].get_img_info(sample_idx)
24 | 
25 | class MixDataset(object):
26 |     def __init__(self, datasets, ratios):
27 |         self.datasets = datasets
28 |         self.ratios = ratios
29 |         self.lengths = []
30 |         for dataset in self.datasets:
31 |             self.lengths.append(len(dataset))
32 |         self.lengths = np.array(self.lengths)
33 |         self.separate_inds = []
34 |         s = 0
35 |         for i in self.ratios[:-1]:
36 |             s += i
37 |             self.separate_inds.append(s)
38 | 
39 |     def __len__(self):
40 |         return self.lengths.sum()
41 | 
42 |     def __getitem__(self, item):
43 |         i = np.random.rand()
44 |         ind = bisect.bisect_right(self.separate_inds, i)
45 |         b_ind = np.random.randint(self.lengths[ind])
46 |         return self.datasets[ind][b_ind]
47 |     # def get_img_info(self, idx):
48 | 
49 | 
50 | 
51 | 
52 | 
53 | 
54 | 
55 | 
56 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/datasets/list_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | """
3 | Simple dataset class that wraps a list of path names
4 | """
5 | 
6 | from PIL import Image
7 | 
8 | from maskrcnn_benchmark.structures.bounding_box import BoxList
9 | 
10 | 
11 | class ListDataset(object):
12 |     def __init__(self, image_lists, transforms=None):
13 |         self.image_lists = image_lists
14 |         self.transforms = transforms
15 | 
16 |     def __getitem__(self, item):
17 |         img = Image.open(self.image_lists[item]).convert("RGB")
18 | 
19 |         # dummy target
20 |         w, h = img.size
21 |         target = BoxList([[0, 0, w, h]], img.size, mode="xyxy")
22 | 
23 |         if self.transforms is not None:
24 |             img, target = self.transforms(img, target)
25 | 
26 |         return img, target
27 | 
28 |     def __len__(self):
29 |         return len(self.image_lists)
30 | 
31 |     def get_img_info(self, item):
32 |         """
33 |         Return the image dimensions for the image, without
34 |         loading and pre-processing it
35 |         """
36 |         pass
37 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from .distributed import DistributedSampler
3 | from .grouped_batch_sampler import GroupedBatchSampler
4 | from .iteration_based_batch_sampler import IterationBasedBatchSampler
5 | 
6 | __all__ = ["DistributedSampler", "GroupedBatchSampler", "IterationBasedBatchSampler"]
7 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/samplers/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | # Code is copy-pasted exactly as in torch.utils.data.distributed,
3 | # with a modification in the import to use the deprecated backend
4 | # FIXME remove this once c10d fixes the bug it has
5 | import math
6 | import torch
7 | # import torch.distributed.deprecated as dist
8 | import torch.distributed as dist
9 | from torch.utils.data.sampler import Sampler
10 | 
11 | 
12 | class DistributedSampler(Sampler):
13 |     """Sampler that restricts data loading to a subset of the dataset.
14 |     It is especially useful in conjunction with
15 |     :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
16 |     process can pass a DistributedSampler instance as a DataLoader sampler,
17 |     and load a subset of the original dataset that is exclusive to it.
18 |     .. note::
19 |         Dataset is assumed to be of constant size.
20 |     Arguments:
21 |         dataset: Dataset used for sampling.
22 |         num_replicas (optional): Number of processes participating in
23 |             distributed training.
24 |         rank (optional): Rank of the current process within num_replicas.
25 |     """
26 | 
27 |     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
28 |         if num_replicas is None:
29 |             if not dist.is_available():
30 |                 raise RuntimeError("Requires distributed package to be available")
31 |             num_replicas = dist.get_world_size()
32 |         if rank is None:
33 |             if not dist.is_available():
34 |                 raise RuntimeError("Requires distributed package to be available")
35 |             rank = dist.get_rank()
36 |         self.dataset = dataset
37 |         self.num_replicas = num_replicas
38 |         self.rank = rank
39 |         self.epoch = 0
40 |         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
41 |         self.total_size = self.num_samples * self.num_replicas
42 |         self.shuffle = shuffle
43 | 
44 |     def __iter__(self):
45 |         if self.shuffle:
46 |             # deterministically shuffle based on epoch
47 |             g = torch.Generator()
48 |             g.manual_seed(self.epoch)
49 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
50 |         else:
51 |             indices = torch.arange(len(self.dataset)).tolist()
52 | 
53 |         # add extra samples to make it evenly divisible
54 |         indices += indices[: (self.total_size - len(indices))]
55 |         assert len(indices) == self.total_size
56 | 
57 |         # subsample
58 |         offset = self.num_samples * self.rank
59 |         indices = indices[offset : offset + self.num_samples]
60 |         assert len(indices) == self.num_samples
61 | 
62 |         return iter(indices)
63 | 
64 |     def __len__(self):
65 |         return self.num_samples
66 | 
67 |     def set_epoch(self, epoch):
68 |         self.epoch = epoch
69 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import itertools
3 | 
4 | import torch
5 | from torch.utils.data.sampler import BatchSampler
6 | from torch.utils.data.sampler import Sampler
7 | 
8 | 
9 | class GroupedBatchSampler(BatchSampler):
10 |     """
11 |     Wraps another sampler to yield a mini-batch of indices.
12 |     It enforces that elements from the same group should appear in groups of batch_size.
13 |     It also tries to provide mini-batches which follow an ordering which is
14 |     as close as possible to the ordering from the original sampler.
15 | 
16 |     Arguments:
17 |         sampler (Sampler): Base sampler.
18 |         batch_size (int): Size of mini-batch.
19 |         drop_uneven (bool): If ``True``, the sampler will drop the batches whose
20 |             size is less than ``batch_size``
21 | 
22 |     """
23 | 
24 |     def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
25 |         if not isinstance(sampler, Sampler):
26 |             raise ValueError(
27 |                 "sampler should be an instance of "
28 |                 "torch.utils.data.Sampler, but got sampler={}".format(sampler)
29 |             )
30 |         self.sampler = sampler
31 |         self.group_ids = torch.as_tensor(group_ids)
32 |         assert self.group_ids.dim() == 1
33 |         self.batch_size = batch_size
34 |         self.drop_uneven = drop_uneven
35 | 
36 |         self.groups = torch.unique(self.group_ids).sort(0)[0]
37 | 
38 |         self._can_reuse_batches = False
39 | 
40 |     def _prepare_batches(self):
41 |         dataset_size = len(self.group_ids)
42 |         # get the sampled indices from the sampler
43 |         sampled_ids = torch.as_tensor(list(self.sampler))
44 |         # potentially not all elements of the dataset were sampled
45 |         # by the sampler (e.g., DistributedSampler).
46 |         # construct a tensor which contains -1 if the element was
47 |         # not sampled, and a non-negative number indicating the
48 |         # order where the element was sampled.
49 |         # for example, if sampled_ids = [3, 1] and dataset_size = 5,
50 |         # the order is [-1, 1, -1, 0, -1]
51 |         order = torch.full((dataset_size,), -1, dtype=torch.int64)
52 |         order[sampled_ids] = torch.arange(len(sampled_ids))
53 | 
54 |         # get a mask with the elements that were sampled
55 |         mask = order >= 0
56 | 
57 |         # find the elements that belong to each individual cluster
58 |         clusters = [(self.group_ids == i) & mask for i in self.groups]
59 |         # get relative order of the elements inside each cluster
60 |         # that follows the order from the sampler
61 |         relative_order = [order[cluster] for cluster in clusters]
62 |         # with the relative order, find the absolute order in the
63 |         # sampled space
64 |         permutation_ids = [s[s.sort()[1]] for s in relative_order]
65 |         # permute each cluster so that they follow the order from
66 |         # the sampler
67 |         permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]
68 | 
69 |         # splits each cluster in batch_size, and merge as a list of tensors
70 |         splits = [c.split(self.batch_size) for c in permuted_clusters]
71 |         merged = tuple(itertools.chain.from_iterable(splits))
72 |         # now each batch internally has the right order, but
73 |         # they are grouped by clusters. Find the permutation between
74 |         # different batches that brings them as close as possible to
75 |         # the order that we have in the sampler. For that, we will consider the
76 |         # ordering as coming from the first element of each batch, and sort
77 |         # correspondingly
78 |         first_element_of_batch = [t[0].item() for t in merged]
79 |         # get an inverse mapping from sampled indices to the position where
80 |         # they occur (as returned by the sampler)
81 |         inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
82 |         # from the first element in each batch, get a relative ordering
83 |         first_index_of_batch = torch.as_tensor(
84 |             [inv_sampled_ids_map[s] for s in first_element_of_batch]
85 |         )
86 | 
87 |         # permute the batches so that they approximately follow the order
88 |         # from the sampler
89 |         permutation_order = first_index_of_batch.sort(0)[1].tolist()
90 |         # finally, permute the batches
91 |         batches = [merged[i].tolist() for i in permutation_order]
92 | 
93 |         if self.drop_uneven:
94 |             kept = []
95 |             for batch in batches:
96 |                 if len(batch) == self.batch_size:
97 |                     kept.append(batch)
98 |             batches = kept
99 |         return batches
100 | 
101 |     def __iter__(self):
102 |         if self._can_reuse_batches:
103 |             batches = self._batches
104 |             self._can_reuse_batches = False
105 |         else:
106 |             batches = self._prepare_batches()
107 |         self._batches = batches
108 |         return iter(batches)
109 | 
110 |     def __len__(self):
111 |         if not hasattr(self, "_batches"):
112 |             self._batches = self._prepare_batches()
113 |             self._can_reuse_batches = True
114 |         return len(self._batches)
115 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from torch.utils.data.sampler import BatchSampler
3 | 
4 | 
5 | class IterationBasedBatchSampler(BatchSampler):
6 |     """
7 |     Wraps a BatchSampler, resampling from it until
8 |     a specified number of iterations have been sampled
9 |     """
10 | 
11 |     def __init__(self, batch_sampler, num_iterations, start_iter=0):
12 |         self.batch_sampler = batch_sampler
13 |         self.num_iterations = num_iterations
14 |         self.start_iter = start_iter
15 | 
16 |     def __iter__(self):
17 |         iteration = self.start_iter
18 |         while iteration <= self.num_iterations:
19 |             # if the underlying sampler has a set_epoch method, like
20 |             # DistributedSampler, used for making each process see
21 |             # a different split of the dataset, then set it
22 |             if hasattr(self.batch_sampler.sampler, "set_epoch"):
23 |                 self.batch_sampler.sampler.set_epoch(iteration)
24 |             for batch in self.batch_sampler:
25 |                 iteration += 1
26 |                 if iteration > self.num_iterations:
27 |                     break
28 |                 yield batch
29 | 
30 |     def __len__(self):
31 |         return self.num_iterations
32 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from .transforms import Compose
3 | from .transforms import Resize
4 | from .transforms import RandomHorizontalFlip
5 | from .transforms import ToTensor
6 | from .transforms import Normalize
7 | 
8 | from .build import build_transforms
9 | 
10 | 
--------------------------------------------------------------------------------
/maskrcnn_benchmark/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from . import transforms as T
3 | 
4 | 
5 | def build_transforms(cfg, is_train=True):
6 |     to_bgr255 = cfg.INPUT.TO_BGR255
7 |     normalize_transform = T.Normalize(
8 |         mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255
9 |     )
10 |     if is_train:
11 |         min_size = cfg.INPUT.MIN_SIZE_TRAIN
12 |         max_size = cfg.INPUT.MAX_SIZE_TRAIN
13 |         # flip_prob = 0.5  # cfg.INPUT.FLIP_PROB_TRAIN
14 |         flip_prob = 0
15 |         rotate_prob = 0.5
16 |         pixel_aug_prob = 0.2
17 |         random_crop_prob = cfg.DATASETS.RANDOM_CROP_PROB
18 |     else:
19 |         min_size = cfg.INPUT.MIN_SIZE_TEST
20 |         max_size = cfg.INPUT.MAX_SIZE_TEST
21 |         flip_prob = 0
22 |         rotate_prob = 0
23 |         pixel_aug_prob = 0
24 |         random_crop_prob = 0
25 | 
26 |     if cfg.DATASETS.AUG and is_train:
27 |         transform = T.Compose(
28 |             [
29 |                 T.RandomCrop(random_crop_prob),
30 |                 T.RandomBrightness(pixel_aug_prob),
31 |                 T.RandomContrast(pixel_aug_prob),
32 |                 T.RandomHue(pixel_aug_prob),
33 |                 T.RandomSaturation(pixel_aug_prob),
34 |                 T.RandomGamma(pixel_aug_prob),
35 |                 T.RandomRotate(rotate_prob),
36 |                 T.Resize(min_size, max_size),
37 |                 # T.RandomHorizontalFlip(flip_prob),
38 |                 T.ToTensor(),
39 |                 normalize_transform,
40 |             ]
41 |         )
42 |     else:
43 |         transform = T.Compose(
44 |             [
45 |                 T.Resize(min_size, max_size),
46 |                 # T.RandomHorizontalFlip(flip_prob),
47 |                 T.ToTensor(),
48 |                 normalize_transform,
49 |             ]
50 |         )
51 |     return transform
--------------------------------------------------------------------------------
/maskrcnn_benchmark/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import datetime
3 | import logging
4 | import time
5 | 
6 | import torch
7 | import torch.distributed as dist
8 | 
9 | from maskrcnn_benchmark.utils.comm import get_world_size, is_main_process
10 | from maskrcnn_benchmark.utils.metric_logger import MetricLogger
11 | 
12 | def reduce_loss_dict(loss_dict):
13 |     """
14 |     Reduce the loss dictionary from all processes so that process with rank
15 |     0 has the averaged results. Returns a dict with the same fields as
16 |     loss_dict, after reduction.
17 | """ 18 | world_size = get_world_size() 19 | if world_size < 2: 20 | return loss_dict 21 | with torch.no_grad(): 22 | loss_names = [] 23 | all_losses = [] 24 | for k, v in loss_dict.items(): 25 | loss_names.append(k) 26 | all_losses.append(v) 27 | all_losses = torch.stack(all_losses, dim=0) 28 | dist.reduce(all_losses, dst=0) 29 | if dist.get_rank() == 0: 30 | # only main process gets accumulated, so only divide by 31 | # world_size in this case 32 | all_losses /= world_size 33 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} 34 | return reduced_losses 35 | 36 | 37 | def do_train( 38 | model, 39 | data_loader, 40 | optimizer, 41 | scheduler, 42 | checkpointer, 43 | device, 44 | checkpoint_period, 45 | arguments, 46 | tb_logger, 47 | cfg, 48 | ): 49 | logger = logging.getLogger("maskrcnn_benchmark.trainer") 50 | logger.info("Start training") 51 | meters = MetricLogger(delimiter=" ") 52 | max_iter = len(data_loader) 53 | start_iter = arguments["iteration"] 54 | model.train() 55 | start_training_time = time.time() 56 | end = time.time() 57 | for iteration, (images, targets, _) in enumerate(data_loader, start_iter): 58 | data_time = time.time() - end 59 | arguments["iteration"] = iteration 60 | 61 | scheduler.step() 62 | 63 | images = images.to(device) 64 | targets = [target.to(device) for target in targets] 65 | 66 | loss_dict = model(images, targets) 67 | 68 | losses = sum(loss for loss in loss_dict.values()) 69 | # reduce losses over all GPUs for logging purposes 70 | loss_dict_reduced = reduce_loss_dict(loss_dict) 71 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 72 | meters.update(loss=losses_reduced, **loss_dict_reduced) 73 | 74 | optimizer.zero_grad() 75 | losses.backward() 76 | if cfg.SOLVER.USE_ADAM: 77 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 78 | optimizer.step() 79 | 80 | batch_time = time.time() - end 81 | end = time.time() 82 | meters.update(time=batch_time, data=data_time) 83 | 84 | eta_seconds = meters.time.global_avg * (max_iter - iteration) 85 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 86 | 87 | if iteration % 20 == 0 or iteration == (max_iter - 1): 88 | logger.info( 89 | meters.delimiter.join( 90 | [ 91 | "eta: {eta}", 92 | "iter: {iter}", 93 | "{meters}", 94 | "lr: {lr:.6f}", 95 | "max mem: {memory:.0f}", 96 | ] 97 | ).format( 98 | eta=eta_string, 99 | iter=iteration, 100 | meters=str(meters), 101 | lr=optimizer.param_groups[0]["lr"], 102 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, 103 | ) 104 | ) 105 | if is_main_process(): 106 | for tag, value in loss_dict_reduced.items(): 107 | tb_logger.scalar_summary(tag, value.item(), iteration) 108 | if iteration % checkpoint_period == 0 and iteration > 0: 109 | checkpointer.save("model_{:07d}".format(iteration), **arguments) 110 | 111 | checkpointer.save("model_{:07d}".format(iteration), **arguments) 112 | total_training_time = time.time() - start_training_time 113 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 114 | logger.info( 115 | "Total training time: {} ({:.4f} s / it)".format( 116 | total_time_str, total_training_time / (max_iter) 117 | ) 118 | ) 119 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | import torch 3 | 4 | from .batch_norm import FrozenBatchNorm2d 5 | from .misc import Conv2d 6 | from .misc import ConvTranspose2d 7 | from .misc import interpolate 8 | from .nms import nms 9 | from .roi_align import ROIAlign 10 | from .roi_align import roi_align 11 | from .roi_pool import ROIPool 12 | from .roi_pool import roi_pool 13 | from .smooth_l1_loss import smooth_l1_loss 14 | 15 | __all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool", "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", "FrozenBatchNorm2d"] 16 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import glob 3 | import os.path 4 | 5 | import torch 6 | 7 | try: 8 | from torch.utils.cpp_extension import load as load_ext 9 | from torch.utils.cpp_extension import CUDA_HOME 10 | except ImportError: 11 | raise ImportError("The cpp layer extensions require PyTorch 0.4 or higher") 12 | 13 | 14 | def _load_C_extensions(): 15 | this_dir = os.path.dirname(os.path.abspath(__file__)) 16 | this_dir = os.path.dirname(this_dir) 17 | this_dir = os.path.join(this_dir, "csrc") 18 | 19 | main_file = glob.glob(os.path.join(this_dir, "*.cpp")) 20 | source_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) 21 | source_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) 22 | 23 | source = main_file + source_cpu 24 | 25 | extra_cflags = [] 26 | if torch.cuda.is_available() and CUDA_HOME is not None: 27 | source.extend(source_cuda) 28 | extra_cflags = ["-DWITH_CUDA"] 29 | source = [os.path.join(this_dir, s) for s in source] 30 | extra_include_paths = [this_dir] 31 | return load_ext( 32 | "torchvision", 33 | source, 34 | extra_cflags=extra_cflags, 35 | extra_include_paths=extra_include_paths, 36 | ) 37 | 38 | 39 | _C = _load_C_extensions() 40 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/batch_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | 12 | def __init__(self, n): 13 | super(FrozenBatchNorm2d, self).__init__() 14 | self.register_buffer("weight", torch.ones(n)) 15 | self.register_buffer("bias", torch.zeros(n)) 16 | self.register_buffer("running_mean", torch.zeros(n)) 17 | self.register_buffer("running_var", torch.ones(n)) 18 | 19 | def forward(self, x): 20 | scale = self.weight * self.running_var.rsqrt() 21 | bias = self.bias - self.running_mean * scale 22 | scale = scale.reshape(1, -1, 1, 1) 23 | bias = bias.reshape(1, -1, 1, 1) 24 | return x * scale + bias 25 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | helpers that support empty tensors on some nn functions. 4 | 5 | Ideally, add support for empty tensors directly in PyTorch in 6 | those functions.
7 | 8 | This can be removed once https://github.com/pytorch/pytorch/issues/12013 9 | is implemented 10 | """ 11 | 12 | import math 13 | import torch 14 | from torch.nn.modules.utils import _ntuple 15 | 16 | 17 | class _NewEmptyTensorOp(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, x, new_shape): 20 | ctx.shape = x.shape 21 | return x.new_empty(new_shape) 22 | 23 | @staticmethod 24 | def backward(ctx, grad): 25 | shape = ctx.shape 26 | return _NewEmptyTensorOp.apply(grad, shape), None 27 | 28 | 29 | 30 | class Conv2d(torch.nn.Conv2d): 31 | def forward(self, x): 32 | if x.numel() > 0: 33 | return super(Conv2d, self).forward(x) 34 | # get output shape 35 | 36 | output_shape = [ 37 | (i + 2 * p - (di * (k - 1) + 1)) // d + 1 38 | for i, p, di, k, d in zip( 39 | x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride 40 | ) 41 | ] 42 | output_shape = [x.shape[0], self.weight.shape[0]] + output_shape 43 | return _NewEmptyTensorOp.apply(x, output_shape) 44 | 45 | 46 | class ConvTranspose2d(torch.nn.ConvTranspose2d): 47 | def forward(self, x): 48 | if x.numel() > 0: 49 | return super(ConvTranspose2d, self).forward(x) 50 | # get output shape 51 | 52 | output_shape = [ 53 | (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op 54 | for i, p, di, k, d, op in zip( 55 | x.shape[-2:], 56 | self.padding, 57 | self.dilation, 58 | self.kernel_size, 59 | self.stride, 60 | self.output_padding, 61 | ) 62 | ] 63 | output_shape = [x.shape[0], self.bias.shape[0]] + output_shape 64 | return _NewEmptyTensorOp.apply(x, output_shape) 65 | 66 | 67 | def interpolate( 68 | input, size=None, scale_factor=None, mode="nearest", align_corners=None 69 | ): 70 | if input.numel() > 0: 71 | return torch.nn.functional.interpolate( 72 | input, size, scale_factor, mode, align_corners 73 | ) 74 | 75 | def _check_size_scale_factor(dim): 76 | if size is None and scale_factor is None: 77 | raise ValueError("either size or scale_factor should be defined") 78 | if size is not None and scale_factor is not None: 79 | raise ValueError("only one of size or scale_factor should be defined") 80 | if ( 81 | scale_factor is not None 82 | and isinstance(scale_factor, tuple) 83 | and len(scale_factor) != dim 84 | ): 85 | raise ValueError( 86 | "scale_factor shape must match input shape. " 87 | "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) 88 | ) 89 | 90 | def _output_size(dim): 91 | _check_size_scale_factor(dim) 92 | if size is not None: 93 | return size 94 | scale_factors = _ntuple(dim)(scale_factor) 95 | # math.floor might return float in py2.7 96 | return [ 97 | int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) 98 | ] 99 | 100 | output_shape = tuple(_output_size(2)) 101 | output_shape = input.shape[:-2] + output_shape 102 | return _NewEmptyTensorOp.apply(input, output_shape) 103 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/nms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # from ._utils import _C 3 | from maskrcnn_benchmark import _C 4 | 5 | nms = _C.nms 6 | # nms.__doc__ = """ 7 | # This function performs Non-maximum suppression""" 8 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/roi_align.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc.
and its affiliates. All Rights Reserved. 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Function 5 | from torch.autograd.function import once_differentiable 6 | from torch.nn.modules.utils import _pair 7 | 8 | from maskrcnn_benchmark import _C 9 | 10 | 11 | class _ROIAlign(Function): 12 | @staticmethod 13 | def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): 14 | ctx.save_for_backward(roi) 15 | ctx.output_size = _pair(output_size) 16 | ctx.spatial_scale = spatial_scale 17 | ctx.sampling_ratio = sampling_ratio 18 | ctx.input_shape = input.size() 19 | output = _C.roi_align_forward( 20 | input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio 21 | ) 22 | return output 23 | 24 | @staticmethod 25 | @once_differentiable 26 | def backward(ctx, grad_output): 27 | rois, = ctx.saved_tensors 28 | output_size = ctx.output_size 29 | spatial_scale = ctx.spatial_scale 30 | sampling_ratio = ctx.sampling_ratio 31 | bs, ch, h, w = ctx.input_shape 32 | grad_input = _C.roi_align_backward( 33 | grad_output, 34 | rois, 35 | spatial_scale, 36 | output_size[0], 37 | output_size[1], 38 | bs, 39 | ch, 40 | h, 41 | w, 42 | sampling_ratio, 43 | ) 44 | return grad_input, None, None, None, None 45 | 46 | 47 | roi_align = _ROIAlign.apply 48 | 49 | 50 | class ROIAlign(nn.Module): 51 | def __init__(self, output_size, spatial_scale, sampling_ratio): 52 | super(ROIAlign, self).__init__() 53 | self.output_size = output_size 54 | self.spatial_scale = spatial_scale 55 | self.sampling_ratio = sampling_ratio 56 | 57 | def forward(self, input, rois): 58 | return roi_align( 59 | input, rois, self.output_size, self.spatial_scale, self.sampling_ratio 60 | ) 61 | 62 | def __repr__(self): 63 | tmpstr = self.__class__.__name__ + "(" 64 | tmpstr += "output_size=" + str(self.output_size) 65 | tmpstr += ", spatial_scale=" + str(self.spatial_scale) 66 | tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) 67 | tmpstr += ")" 68 | return tmpstr 69 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/roi_pool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
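"""
Usage sketch (editorial addition; the sizes are illustrative): `rois` follow
the [batch_index, x1, y1, x2, y2] layout produced by
Pooler.convert_to_roi_format. The tensors are moved to the GPU here on the
assumption that this repo ships only a CUDA kernel for ROIPool:

    import torch
    from maskrcnn_benchmark.layers import ROIPool

    pool = ROIPool(output_size=(7, 7), spatial_scale=1.0 / 16)
    feats = torch.randn(1, 256, 50, 50).cuda()
    rois = torch.tensor([[0., 64., 64., 256., 256.]]).cuda()  # one RoI on image 0
    out = pool(feats, rois)  # -> shape (1, 256, 7, 7)
"""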
2 | import torch 3 | from torch import nn 4 | from torch.autograd import Function 5 | from torch.autograd.function import once_differentiable 6 | from torch.nn.modules.utils import _pair 7 | 8 | from maskrcnn_benchmark import _C 9 | 10 | 11 | class _ROIPool(Function): 12 | @staticmethod 13 | def forward(ctx, input, roi, output_size, spatial_scale): 14 | ctx.output_size = _pair(output_size) 15 | ctx.spatial_scale = spatial_scale 16 | ctx.input_shape = input.size() 17 | output, argmax = _C.roi_pool_forward( 18 | input, roi, spatial_scale, output_size[0], output_size[1] 19 | ) 20 | ctx.save_for_backward(input, roi, argmax) 21 | return output 22 | 23 | @staticmethod 24 | @once_differentiable 25 | def backward(ctx, grad_output): 26 | input, rois, argmax = ctx.saved_tensors 27 | output_size = ctx.output_size 28 | spatial_scale = ctx.spatial_scale 29 | bs, ch, h, w = ctx.input_shape 30 | grad_input = _C.roi_pool_backward( 31 | grad_output, 32 | input, 33 | rois, 34 | argmax, 35 | spatial_scale, 36 | output_size[0], 37 | output_size[1], 38 | bs, 39 | ch, 40 | h, 41 | w, 42 | ) 43 | return grad_input, None, None, None 44 | 45 | 46 | roi_pool = _ROIPool.apply 47 | 48 | 49 | class ROIPool(nn.Module): 50 | def __init__(self, output_size, spatial_scale): 51 | super(ROIPool, self).__init__() 52 | self.output_size = output_size 53 | self.spatial_scale = spatial_scale 54 | 55 | def forward(self, input, rois): 56 | return roi_pool(input, rois, self.output_size, self.spatial_scale) 57 | 58 | def __repr__(self): 59 | tmpstr = self.__class__.__name__ + "(" 60 | tmpstr += "output_size=" + str(self.output_size) 61 | tmpstr += ", spatial_scale=" + str(self.spatial_scale) 62 | tmpstr += ")" 63 | return tmpstr 64 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/layers/smooth_l1_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | 5 | # TODO maybe push this to nn? 6 | def smooth_l1_loss(input, target, beta=1. / 9, size_average=True): 7 | """ 8 | very similar to the smooth_l1_loss from pytorch, but with 9 | the extra beta parameter 10 | """ 11 | n = torch.abs(input - target) 12 | cond = n < beta 13 | loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) 14 | if size_average: 15 | return loss.mean() 16 | return loss.sum() 17 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/maskrcnn_benchmark/modeling/__init__.py -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .backbone import build_backbone 3 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import OrderedDict 3 | 4 | from torch import nn 5 | 6 | from . import fpn as fpn_module 7 | from . 
import resnet 8 | 9 | 10 | def build_resnet_backbone(cfg): 11 | body = resnet.ResNet(cfg) 12 | model = nn.Sequential(OrderedDict([("body", body)])) 13 | return model 14 | 15 | 16 | def build_resnet_fpn_backbone(cfg): 17 | body = resnet.ResNet(cfg) 18 | in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS 19 | out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS 20 | fpn = fpn_module.FPN( 21 | in_channels_list=[ 22 | in_channels_stage2, 23 | in_channels_stage2 * 2, 24 | in_channels_stage2 * 4, 25 | in_channels_stage2 * 8, 26 | ], 27 | out_channels=out_channels, 28 | top_blocks=fpn_module.LastLevelMaxPool(), 29 | ) 30 | model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) 31 | return model 32 | 33 | 34 | _BACKBONES = {"resnet": build_resnet_backbone, "resnet-fpn": build_resnet_fpn_backbone} 35 | 36 | 37 | def build_backbone(cfg): 38 | assert cfg.MODEL.BACKBONE.CONV_BODY.startswith( 39 | "R-" 40 | ), "Only ResNet and ResNeXt models are currently implemented" 41 | # Models using FPN end with "-FPN" 42 | if cfg.MODEL.BACKBONE.CONV_BODY.endswith("-FPN"): 43 | return build_resnet_fpn_backbone(cfg) 44 | return build_resnet_backbone(cfg) 45 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/backbone/fpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class FPN(nn.Module): 8 | """ 9 | Module that adds FPN on top of a list of feature maps. 10 | The feature maps are currently supposed to be in increasing depth 11 | order, and must be consecutive 12 | """ 13 | 14 | def __init__(self, in_channels_list, out_channels, top_blocks=None): 15 | """ 16 | Arguments: 17 | in_channels_list (list[int]): number of channels for each feature map that 18 | will be fed 19 | out_channels (int): number of channels of the FPN representation 20 | top_blocks (nn.Module or None): if provided, an extra operation will 21 | be performed on the output of the last (smallest resolution) 22 | FPN output, and the result will extend the result list 23 | """ 24 | super(FPN, self).__init__() 25 | self.inner_blocks = [] 26 | self.layer_blocks = [] 27 | for idx, in_channels in enumerate(in_channels_list, 1): 28 | inner_block = "fpn_inner{}".format(idx) 29 | layer_block = "fpn_layer{}".format(idx) 30 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1) 31 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, 1, 1) 32 | for module in [inner_block_module, layer_block_module]: 33 | # Caffe2 implementation uses XavierFill, which in fact 34 | # corresponds to kaiming_uniform_ in PyTorch 35 | nn.init.kaiming_uniform_(module.weight, a=1) 36 | nn.init.constant_(module.bias, 0) 37 | self.add_module(inner_block, inner_block_module) 38 | self.add_module(layer_block, layer_block_module) 39 | self.inner_blocks.append(inner_block) 40 | self.layer_blocks.append(layer_block) 41 | self.top_blocks = top_blocks 42 | 43 | def forward(self, x): 44 | """ 45 | Arguments: 46 | x (list[Tensor]): feature maps for each feature level. 47 | Returns: 48 | results (tuple[Tensor]): feature maps after FPN layers. 49 | They are ordered from highest resolution first. 
50 | """ 51 | last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) 52 | results = [] 53 | results.append(getattr(self, self.layer_blocks[-1])(last_inner)) 54 | for feature, inner_block, layer_block in zip( 55 | x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] 56 | ): 57 | inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest") 58 | inner_lateral = getattr(self, inner_block)(feature) 59 | # TODO use size instead of scale to make it robust to different sizes 60 | # inner_top_down = F.upsample(last_inner, size=inner_lateral.shape[-2:], 61 | # mode='bilinear', align_corners=False) 62 | last_inner = inner_lateral + inner_top_down 63 | results.insert(0, getattr(self, layer_block)(last_inner)) 64 | 65 | if self.top_blocks is not None: 66 | last_results = self.top_blocks(results[-1]) 67 | results.extend(last_results) 68 | 69 | return tuple(results) 70 | 71 | 72 | class LastLevelMaxPool(nn.Module): 73 | def forward(self, x): 74 | return [F.max_pool2d(x, 1, 2, 0)] 75 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | 5 | class BalancedPositiveNegativeSampler(object): 6 | """ 7 | This class samples batches, ensuring that they contain a fixed proportion of positives 8 | """ 9 | 10 | def __init__(self, batch_size_per_image, positive_fraction): 11 | """ 12 | Arguments: 13 | batch_size_per_image (int): number of elements to be selected per image 14 | positive_fraction (float): percentace of positive elements per batch 15 | """ 16 | self.batch_size_per_image = batch_size_per_image 17 | self.positive_fraction = positive_fraction 18 | 19 | def __call__(self, matched_idxs): 20 | """ 21 | Arguments: 22 | matched idxs: list of tensors containing -1, 0 or positive values. 23 | Each tensor corresponds to a specific image. 24 | -1 values are ignored, 0 are considered as negatives and > 0 as 25 | positives. 26 | 27 | Returns: 28 | pos_idx (list[tensor]) 29 | neg_idx (list[tensor]) 30 | 31 | Returns two lists of binary masks for each image. 32 | The first list contains the positive elements that were selected, 33 | and the second list the negative example. 
34 | """ 35 | pos_idx = [] 36 | neg_idx = [] 37 | for matched_idxs_per_image in matched_idxs: 38 | positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) 39 | negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) 40 | 41 | num_pos = int(self.batch_size_per_image * self.positive_fraction) 42 | # protect against not enough positive examples 43 | num_pos = min(positive.numel(), num_pos) 44 | num_neg = self.batch_size_per_image - num_pos 45 | # protect against not enough negative examples 46 | num_neg = min(negative.numel(), num_neg) 47 | 48 | # randomly select positive and negative examples 49 | perm1 = torch.randperm(positive.numel())[:num_pos] 50 | perm2 = torch.randperm(negative.numel())[:num_neg] 51 | 52 | pos_idx_per_image = positive[perm1] 53 | neg_idx_per_image = negative[perm2] 54 | 55 | # create binary mask from indices 56 | pos_idx_per_image_mask = torch.zeros_like( 57 | matched_idxs_per_image, dtype=torch.bool 58 | ) 59 | neg_idx_per_image_mask = torch.zeros_like( 60 | matched_idxs_per_image, dtype=torch.bool 61 | ) 62 | pos_idx_per_image_mask[pos_idx_per_image] = 1 63 | neg_idx_per_image_mask[neg_idx_per_image] = 1 64 | 65 | pos_idx.append(pos_idx_per_image_mask) 66 | neg_idx.append(neg_idx_per_image_mask) 67 | 68 | return pos_idx, neg_idx 69 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import math 3 | 4 | import torch 5 | 6 | 7 | class BoxCoder(object): 8 | """ 9 | This class encodes and decodes a set of bounding boxes into 10 | the representation used for training the regressors. 11 | """ 12 | 13 | def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)): 14 | """ 15 | Arguments: 16 | weights (4-element tuple) 17 | bbox_xform_clip (float) 18 | """ 19 | self.weights = weights 20 | self.bbox_xform_clip = bbox_xform_clip 21 | 22 | def encode(self, reference_boxes, proposals): 23 | """ 24 | Encode a set of proposals with respect to some 25 | reference boxes 26 | 27 | Arguments: 28 | reference_boxes (Tensor): reference boxes 29 | proposals (Tensor): boxes to be encoded 30 | """ 31 | 32 | TO_REMOVE = 1 # TODO remove 33 | ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE 34 | ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE 35 | ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths 36 | ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights 37 | 38 | gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE 39 | gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE 40 | gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths 41 | gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights 42 | 43 | wx, wy, ww, wh = self.weights 44 | targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths 45 | targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights 46 | targets_dw = ww * torch.log(gt_widths / ex_widths) 47 | targets_dh = wh * torch.log(gt_heights / ex_heights) 48 | 49 | targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) 50 | return targets 51 | 52 | def decode(self, rel_codes, boxes): 53 | """ 54 | From a set of original boxes and encoded relative box offsets, 55 | get the decoded boxes. 56 | 57 | Arguments: 58 | rel_codes (Tensor): encoded boxes 59 | boxes (Tensor): reference boxes. 
60 | """ 61 | 62 | boxes = boxes.to(rel_codes.dtype) 63 | 64 | TO_REMOVE = 1 # TODO remove 65 | widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE 66 | heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE 67 | ctr_x = boxes[:, 0] + 0.5 * widths 68 | ctr_y = boxes[:, 1] + 0.5 * heights 69 | 70 | wx, wy, ww, wh = self.weights 71 | dx = rel_codes[:, 0::4] / wx 72 | dy = rel_codes[:, 1::4] / wy 73 | dw = rel_codes[:, 2::4] / ww 74 | dh = rel_codes[:, 3::4] / wh 75 | 76 | # Prevent sending too large values into torch.exp() 77 | dw = torch.clamp(dw, max=self.bbox_xform_clip) 78 | dh = torch.clamp(dh, max=self.bbox_xform_clip) 79 | 80 | pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] 81 | pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] 82 | pred_w = torch.exp(dw) * widths[:, None] 83 | pred_h = torch.exp(dh) * heights[:, None] 84 | 85 | pred_boxes = torch.zeros_like(rel_codes) 86 | # x1 87 | pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w 88 | # y1 89 | pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h 90 | # x2 (note: "- 1" is correct; don't be fooled by the asymmetry) 91 | pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1 92 | # y2 (note: "- 1" is correct; don't be fooled by the asymmetry) 93 | pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1 94 | 95 | return pred_boxes 96 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/detector/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .detectors import build_detection_model 3 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/detector/detectors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .generalized_rcnn import GeneralizedRCNN 3 | 4 | 5 | _DETECTION_META_ARCHITECTURES = {"GeneralizedRCNN": GeneralizedRCNN} 6 | 7 | 8 | def build_detection_model(cfg): 9 | meta_arch = _DETECTION_META_ARCHITECTURES[cfg.MODEL.META_ARCHITECTURE] 10 | return meta_arch(cfg) 11 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/detector/generalized_rcnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | Implements the Generalized R-CNN framework 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from maskrcnn_benchmark.structures.image_list import to_image_list 10 | 11 | from ..backbone import build_backbone 12 | from ..rpn.rpn import build_rpn 13 | from ..roi_heads.roi_heads import build_roi_heads 14 | 15 | 16 | class GeneralizedRCNN(nn.Module): 17 | """ 18 | Main class for Generalized R-CNN. Currently supports boxes and masks. 19 | It consists of three main parts: 20 | - backbone 21 | = rpn 22 | - heads: takes the features + the proposals from the RPN and computes 23 | detections / masks from it. 
24 | """ 25 | 26 | def __init__(self, cfg): 27 | super(GeneralizedRCNN, self).__init__() 28 | 29 | self.backbone = build_backbone(cfg) 30 | self.rpn = build_rpn(cfg) 31 | self.roi_heads = build_roi_heads(cfg) 32 | 33 | def forward(self, images, targets=None): 34 | """ 35 | Arguments: 36 | images (list[Tensor] or ImageList): images to be processed 37 | targets (list[BoxList]): ground-truth boxes present in the image (optional) 38 | 39 | Returns: 40 | result (list[BoxList] or dict[Tensor]): the output from the model. 41 | During training, it returns a dict[Tensor] which contains the losses. 42 | During testing, it returns list[BoxList] contains additional fields 43 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 44 | 45 | """ 46 | if self.training and targets is None: 47 | raise ValueError("In training mode, targets should be passed") 48 | images = to_image_list(images) 49 | features = self.backbone(images.tensors) 50 | proposals, proposal_losses = self.rpn(images, features, targets) 51 | if self.roi_heads: 52 | x, result, detector_losses = self.roi_heads(features, proposals, targets) 53 | else: 54 | # RPN-only models don't have roi_heads 55 | x = features 56 | result = proposals 57 | detector_losses = {} 58 | 59 | if self.training: 60 | losses = {} 61 | losses.update(detector_losses) 62 | losses.update(proposal_losses) 63 | return losses 64 | 65 | return result 66 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | 5 | class Matcher(object): 6 | """ 7 | This class assigns to each predicted "element" (e.g., a box) a ground-truth 8 | element. Each predicted element will have exactly zero or one matches; each 9 | ground-truth element may be assigned to zero or more predicted elements. 10 | 11 | Matching is based on the MxN match_quality_matrix, that characterizes how well 12 | each (ground-truth, predicted)-pair match. For example, if the elements are 13 | boxes, the matrix may contain box IoU overlap values. 14 | 15 | The matcher returns a tensor of size N containing the index of the ground-truth 16 | element m that matches to prediction n. If there is no match, a negative value 17 | is returned. 18 | """ 19 | 20 | BELOW_LOW_THRESHOLD = -1 21 | BETWEEN_THRESHOLDS = -2 22 | 23 | def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): 24 | """ 25 | Args: 26 | high_threshold (float): quality values greater than or equal to 27 | this value are candidate matches. 28 | low_threshold (float): a lower quality threshold used to stratify 29 | matches into three levels: 30 | 1) matches >= high_threshold 31 | 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) 32 | 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) 33 | allow_low_quality_matches (bool): if True, produce additional matches 34 | for predictions that have only low-quality match candidates. See 35 | set_low_quality_matches_ for more details. 
36 | """ 37 | assert low_threshold <= high_threshold 38 | self.high_threshold = high_threshold 39 | self.low_threshold = low_threshold 40 | self.allow_low_quality_matches = allow_low_quality_matches 41 | 42 | def __call__(self, match_quality_matrix): 43 | """ 44 | Args: 45 | match_quality_matrix (Tensor[float]): an MxN tensor, containing the 46 | pairwise quality between M ground-truth elements and N predicted elements. 47 | 48 | Returns: 49 | matches (Tensor[int64]): an N tensor where N[i] is a matched gt in 50 | [0, M - 1] or a negative value indicating that prediction i could not 51 | be matched. 52 | """ 53 | if match_quality_matrix.numel() == 0: 54 | # handle empty case 55 | device = match_quality_matrix.device 56 | return torch.empty((0,), dtype=torch.int64, device=device) 57 | 58 | # match_quality_matrix is M (gt) x N (predicted) 59 | # Max over gt elements (dim 0) to find best gt candidate for each prediction 60 | matched_vals, matches = match_quality_matrix.max(dim=0) 61 | if self.allow_low_quality_matches: 62 | all_matches = matches.clone() 63 | 64 | # Assign candidate matches with low quality to negative (unassigned) values 65 | below_low_threshold = matched_vals < self.low_threshold 66 | between_thresholds = (matched_vals >= self.low_threshold) & ( 67 | matched_vals < self.high_threshold 68 | ) 69 | matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD 70 | matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS 71 | 72 | if self.allow_low_quality_matches: 73 | self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) 74 | 75 | return matches 76 | 77 | def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): 78 | """ 79 | Produce additional matches for predictions that have only low-quality matches. 80 | Specifically, for each ground-truth find the set of predictions that have 81 | maximum overlap with it (including ties); for each prediction in that set, if 82 | it is unmatched, then match it to the ground-truth with which it has the highest 83 | quality value. 84 | """ 85 | # For each gt, find the prediction with which it has highest quality 86 | highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) 87 | # Find highest quality match available, even if it is low, including ties 88 | gt_pred_pairs_of_highest_quality = torch.nonzero( 89 | match_quality_matrix == highest_quality_foreach_gt[:, None] 90 | ) 91 | # Example gt_pred_pairs_of_highest_quality: 92 | # tensor([[ 0, 39796], 93 | # [ 1, 32055], 94 | # [ 1, 32070], 95 | # [ 2, 39190], 96 | # [ 2, 40255], 97 | # [ 3, 40390], 98 | # [ 3, 41455], 99 | # [ 4, 45470], 100 | # [ 5, 45325], 101 | # [ 5, 46390]]) 102 | # Each row is a (gt index, prediction index) 103 | # Note how gt items 1, 2, 3, and 5 each have two ties 104 | 105 | pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1] 106 | matches[pred_inds_to_update] = all_matches[pred_inds_to_update] 107 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/poolers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from maskrcnn_benchmark.layers import ROIAlign 8 | 9 | from .utils import cat 10 | 11 | 12 | class LevelMapper(object): 13 | """Determine which FPN level each RoI in a set of RoIs should map to based 14 | on the heuristic in the FPN paper. 
15 | """ 16 | 17 | def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): 18 | """ 19 | Arguments: 20 | k_min (int) 21 | k_max (int) 22 | canonical_scale (int) 23 | canonical_level (int) 24 | eps (float) 25 | """ 26 | self.k_min = k_min 27 | self.k_max = k_max 28 | self.s0 = canonical_scale 29 | self.lvl0 = canonical_level 30 | self.eps = eps 31 | 32 | def __call__(self, boxlists): 33 | """ 34 | Arguments: 35 | boxlists (list[BoxList]) 36 | """ 37 | # Compute level ids 38 | s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists])) 39 | 40 | # Eqn.(1) in FPN paper 41 | target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps)) 42 | target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max) 43 | return target_lvls.to(torch.int64) - self.k_min 44 | 45 | 46 | class Pooler(nn.Module): 47 | """ 48 | Pooler for Detection with or without FPN. 49 | It currently hard-code ROIAlign in the implementation, 50 | but that can be made more generic later on. 51 | Also, the requirement of passing the scales is not strictly necessary, as they 52 | can be inferred from the size of the feature map / size of original image, 53 | which is available thanks to the BoxList. 54 | """ 55 | 56 | def __init__(self, output_size, scales, sampling_ratio): 57 | """ 58 | Arguments: 59 | output_size (list[tuple[int]] or list[int]): output size for the pooled region 60 | scales (list[flaot]): scales for each Pooler 61 | sampling_ratio (int): sampling ratio for ROIAlign 62 | """ 63 | super(Pooler, self).__init__() 64 | poolers = [] 65 | for scale in scales: 66 | poolers.append( 67 | ROIAlign( 68 | output_size, spatial_scale=scale, sampling_ratio=sampling_ratio 69 | ) 70 | ) 71 | self.poolers = nn.ModuleList(poolers) 72 | self.output_size = output_size 73 | # get the levels in the feature map by leveraging the fact that the network always 74 | # downsamples by a factor of 2 at each level. 75 | lvl_min = -math.log2(scales[0]) 76 | lvl_max = -math.log2(scales[-1]) 77 | self.map_levels = LevelMapper(lvl_min, lvl_max) 78 | 79 | def convert_to_roi_format(self, boxes): 80 | concat_boxes = cat([b.bbox for b in boxes], dim=0) 81 | device, dtype = concat_boxes.device, concat_boxes.dtype 82 | ids = cat( 83 | [ 84 | torch.full((len(b), 1), i, dtype=dtype, device=device) 85 | for i, b in enumerate(boxes) 86 | ], 87 | dim=0, 88 | ) 89 | rois = torch.cat([ids, concat_boxes], dim=1) 90 | return rois 91 | 92 | def forward(self, x, boxes): 93 | """ 94 | Arguments: 95 | x (list[Tensor]): feature maps for each level 96 | boxes (list[BoxList]): boxes to be used to perform the pooling operation. 
97 | Returns: 98 | result (Tensor) 99 | """ 100 | num_levels = len(self.poolers) 101 | rois = self.convert_to_roi_format(boxes) 102 | if num_levels == 1: 103 | return self.poolers[0](x[0], rois) 104 | 105 | levels = self.map_levels(boxes) 106 | 107 | num_rois = len(rois) 108 | num_channels = x[0].shape[1] 109 | output_size_h = self.output_size[0] 110 | output_size_w = self.output_size[1] 111 | 112 | dtype, device = x[0].dtype, x[0].device 113 | result = torch.zeros( 114 | (num_rois, num_channels, output_size_h, output_size_w), 115 | dtype=dtype, 116 | device=device, 117 | ) 118 | for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): 119 | idx_in_level = torch.nonzero(levels == level).squeeze(1) 120 | rois_per_level = rois[idx_in_level] 121 | result[idx_in_level] = pooler(per_level_feature, rois_per_level) 122 | 123 | return result 124 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/maskrcnn_benchmark/modeling/roi_heads/__init__.py -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | from torch import nn 4 | 5 | from .roi_box_feature_extractors import make_roi_box_feature_extractor 6 | from .roi_box_predictors import make_roi_box_predictor 7 | from .inference import make_roi_box_post_processor 8 | from .loss import make_roi_box_loss_evaluator 9 | 10 | 11 | class ROIBoxHead(torch.nn.Module): 12 | """ 13 | Generic Box Head class. 14 | """ 15 | 16 | def __init__(self, cfg): 17 | super(ROIBoxHead, self).__init__() 18 | self.feature_extractor = make_roi_box_feature_extractor(cfg) 19 | self.predictor = make_roi_box_predictor(cfg) 20 | self.post_processor = make_roi_box_post_processor(cfg) 21 | self.loss_evaluator = make_roi_box_loss_evaluator(cfg) 22 | 23 | def forward(self, features, proposals, targets=None): 24 | """ 25 | Arguments: 26 | features (list[Tensor]): feature-maps from possibly several levels 27 | proposals (list[BoxList]): proposal boxes 28 | targets (list[BoxList], optional): the ground-truth targets. 29 | 30 | Returns: 31 | x (Tensor): the result of the feature extractor 32 | proposals (list[BoxList]): during training, the subsampled proposals 33 | are returned. During testing, the predicted boxlists are returned 34 | losses (dict[Tensor]): During training, returns the losses for the 35 | head. During testing, returns an empty dict. 36 | """ 37 | 38 | if self.training: 39 | # Faster R-CNN subsamples the proposals during training with a fixed 40 | # positive / negative ratio 41 | with torch.no_grad(): 42 | proposals = self.loss_evaluator.subsample(proposals, targets) 43 | 44 | # extract features that will be fed to the final classifier.
The 45 | # feature_extractor generally corresponds to the pooler + heads 46 | x = self.feature_extractor(features, proposals) 47 | # final classifier that converts the features into predictions 48 | class_logits, box_regression = self.predictor(x) 49 | 50 | if not self.training: 51 | result = self.post_processor((class_logits, box_regression), proposals) 52 | return x, result, {} 53 | 54 | loss_classifier, loss_box_reg = self.loss_evaluator( 55 | [class_logits], [box_regression] 56 | ) 57 | return ( 58 | x, 59 | proposals, 60 | dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg), 61 | ) 62 | 63 | 64 | def build_roi_box_head(cfg): 65 | """ 66 | Constructs a new box head. 67 | By default, uses ROIBoxHead, but if it turns out not to be enough, just register a new class 68 | and make it a parameter in the config 69 | """ 70 | return ROIBoxHead(cfg) 71 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from maskrcnn_benchmark.structures.bounding_box import BoxList 7 | from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms 8 | from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist 9 | from maskrcnn_benchmark.modeling.box_coder import BoxCoder 10 | 11 | 12 | class PostProcessor(nn.Module): 13 | """ 14 | From a set of classification scores, box regression and proposals, 15 | computes the post-processed boxes, and applies NMS to obtain the 16 | final results 17 | """ 18 | 19 | def __init__( 20 | self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None 21 | ): 22 | """ 23 | Arguments: 24 | score_thresh (float) 25 | nms (float) 26 | detections_per_img (int) 27 | box_coder (BoxCoder) 28 | """ 29 | super(PostProcessor, self).__init__() 30 | self.score_thresh = score_thresh 31 | self.nms = nms 32 | self.detections_per_img = detections_per_img 33 | if box_coder is None: 34 | box_coder = BoxCoder(weights=(10., 10., 5., 5.)) 35 | self.box_coder = box_coder 36 | 37 | def forward(self, x, boxes): 38 | """ 39 | Arguments: 40 | x (tuple[tensor, tensor]): x contains the class logits 41 | and the box_regression from the model. 
42 | boxes (list[BoxList]): bounding boxes that are used as 43 | reference, one for each image 44 | 45 | Returns: 46 | results (list[BoxList]): one BoxList for each image, containing 47 | the extra fields labels and scores 48 | """ 49 | class_logits, box_regression = x 50 | class_prob = F.softmax(class_logits, -1) 51 | 52 | # TODO think about a representation of batch of boxes 53 | image_shapes = [box.size for box in boxes] 54 | boxes_per_image = [len(box) for box in boxes] 55 | concat_boxes = torch.cat([a.bbox for a in boxes], dim=0) 56 | 57 | proposals = self.box_coder.decode( 58 | box_regression.view(sum(boxes_per_image), -1), concat_boxes 59 | ) 60 | 61 | num_classes = class_prob.shape[1] 62 | 63 | proposals = proposals.split(boxes_per_image, dim=0) 64 | class_prob = class_prob.split(boxes_per_image, dim=0) 65 | 66 | results = [] 67 | for prob, boxes_per_img, image_shape in zip( 68 | class_prob, proposals, image_shapes 69 | ): 70 | boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape) 71 | boxlist = boxlist.clip_to_image(remove_empty=False) 72 | boxlist = self.filter_results(boxlist, num_classes) 73 | results.append(boxlist) 74 | return results 75 | 76 | def prepare_boxlist(self, boxes, scores, image_shape): 77 | """ 78 | Returns BoxList from `boxes` and adds probability scores information 79 | as an extra field 80 | `boxes` has shape (#detections, 4 * #classes), where each row represents 81 | a list of predicted bounding boxes for each of the object classes in the 82 | dataset (including the background class). The detections in each row 83 | originate from the same object proposal. 84 | `scores` has shape (#detections, #classes), where each row represents a list 85 | of object detection confidence scores for each of the object classes in the 86 | dataset (including the background class). `scores[i, j]` corresponds to the 87 | box at `boxes[i, j * 4:(j + 1) * 4]`. 88 | """ 89 | boxes = boxes.reshape(-1, 4) 90 | scores = scores.reshape(-1) 91 | boxlist = BoxList(boxes, image_shape, mode="xyxy") 92 | boxlist.add_field("scores", scores) 93 | return boxlist 94 | 95 | def filter_results(self, boxlist, num_classes): 96 | """Returns bounding-box detection results by thresholding on scores and 97 | applying non-maximum suppression (NMS). 98 | """ 99 | # unwrap the boxlist to avoid additional overhead.
100 | # if we had multi-class NMS, we could perform this directly on the boxlist 101 | boxes = boxlist.bbox.reshape(-1, num_classes * 4) 102 | scores = boxlist.get_field("scores").reshape(-1, num_classes) 103 | 104 | device = scores.device 105 | result = [] 106 | # Apply threshold on detection probabilities and apply NMS 107 | # Skip j = 0, because it's the background class 108 | inds_all = scores > self.score_thresh 109 | for j in range(1, num_classes): 110 | inds = inds_all[:, j].nonzero().squeeze(1) 111 | scores_j = scores[inds, j] 112 | boxes_j = boxes[inds, j * 4 : (j + 1) * 4] 113 | boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") 114 | boxlist_for_class.add_field("scores", scores_j) 115 | boxlist_for_class = boxlist_nms( 116 | boxlist_for_class, self.nms, score_field="scores" 117 | ) 118 | num_labels = len(boxlist_for_class) 119 | boxlist_for_class.add_field( 120 | "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device) 121 | ) 122 | result.append(boxlist_for_class) 123 | 124 | result = cat_boxlist(result) 125 | number_of_detections = len(result) 126 | 127 | # Limit to max_per_image detections **over all classes** 128 | if number_of_detections > self.detections_per_img > 0: 129 | cls_scores = result.get_field("scores") 130 | image_thresh, _ = torch.kthvalue( 131 | cls_scores.cpu(), number_of_detections - self.detections_per_img + 1 132 | ) 133 | keep = cls_scores >= image_thresh.item() 134 | keep = torch.nonzero(keep).squeeze(1) 135 | result = result[keep] 136 | return result 137 | 138 | 139 | def make_roi_box_post_processor(cfg): 140 | use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN 141 | 142 | bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS 143 | box_coder = BoxCoder(weights=bbox_reg_weights) 144 | 145 | score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH 146 | nms_thresh = cfg.MODEL.ROI_HEADS.NMS 147 | detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG 148 | 149 | postprocessor = PostProcessor( 150 | score_thresh, nms_thresh, detections_per_img, box_coder 151 | ) 152 | return postprocessor 153 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from maskrcnn_benchmark.layers import smooth_l1_loss 6 | from maskrcnn_benchmark.modeling.box_coder import BoxCoder 7 | from maskrcnn_benchmark.modeling.matcher import Matcher 8 | from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou 9 | from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import ( 10 | BalancedPositiveNegativeSampler 11 | ) 12 | from maskrcnn_benchmark.modeling.utils import cat 13 | 14 | 15 | class FastRCNNLossComputation(object): 16 | """ 17 | Computes the loss for Faster R-CNN. 
18 | Also supports FPN 19 | """ 20 | 21 | def __init__(self, proposal_matcher, fg_bg_sampler, box_coder): 22 | """ 23 | Arguments: 24 | proposal_matcher (Matcher) 25 | fg_bg_sampler (BalancedPositiveNegativeSampler) 26 | box_coder (BoxCoder) 27 | """ 28 | self.proposal_matcher = proposal_matcher 29 | self.fg_bg_sampler = fg_bg_sampler 30 | self.box_coder = box_coder 31 | 32 | def match_targets_to_proposals(self, proposal, target): 33 | match_quality_matrix = boxlist_iou(target, proposal) 34 | matched_idxs = self.proposal_matcher(match_quality_matrix) 35 | # Fast R-CNN only needs the "labels" field for selecting the targets 36 | target = target.copy_with_fields("labels") 37 | # get the corresponding GT for each proposal 38 | # NB: need to clamp the indices because we can have a single 39 | # GT in the image, and matched_idxs can be -2, which goes 40 | # out of bounds 41 | matched_targets = target[matched_idxs.clamp(min=0)] 42 | matched_targets.add_field("matched_idxs", matched_idxs) 43 | return matched_targets 44 | 45 | def prepare_targets(self, proposals, targets): 46 | labels = [] 47 | regression_targets = [] 48 | for proposals_per_image, targets_per_image in zip(proposals, targets): 49 | matched_targets = self.match_targets_to_proposals( 50 | proposals_per_image, targets_per_image 51 | ) 52 | matched_idxs = matched_targets.get_field("matched_idxs") 53 | 54 | labels_per_image = matched_targets.get_field("labels") 55 | labels_per_image = labels_per_image.to(dtype=torch.int64) 56 | 57 | # Label background (below the low threshold) 58 | bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD 59 | labels_per_image[bg_inds] = 0 60 | 61 | # Label ignore proposals (between low and high thresholds) 62 | ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDS 63 | labels_per_image[ignore_inds] = -1  # -1 is ignored by sampler 64 | 65 | # compute regression targets 66 | regression_targets_per_image = self.box_coder.encode( 67 | matched_targets.bbox, proposals_per_image.bbox 68 | ) 69 | 70 | labels.append(labels_per_image) 71 | regression_targets.append(regression_targets_per_image) 72 | 73 | return labels, regression_targets 74 | 75 | def subsample(self, proposals, targets): 76 | """ 77 | This method performs the positive/negative sampling, and returns 78 | the sampled proposals. 79 | Note: this function keeps a state.
80 | 81 | Arguments: 82 | proposals (list[BoxList]) 83 | targets (list[BoxList]) 84 | """ 85 | 86 | labels, regression_targets = self.prepare_targets(proposals, targets) 87 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 88 | 89 | proposals = list(proposals) 90 | # add corresponding label and regression_targets information to the bounding boxes 91 | for labels_per_image, regression_targets_per_image, proposals_per_image in zip( 92 | labels, regression_targets, proposals 93 | ): 94 | proposals_per_image.add_field("labels", labels_per_image) 95 | proposals_per_image.add_field( 96 | "regression_targets", regression_targets_per_image 97 | ) 98 | 99 | # distribute the sampled proposals, that were obtained on all feature maps 100 | # concatenated via the fg_bg_sampler, into individual feature map levels 101 | for img_idx, (pos_inds_img, neg_inds_img) in enumerate( 102 | zip(sampled_pos_inds, sampled_neg_inds) 103 | ): 104 | img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) 105 | proposals_per_image = proposals[img_idx][img_sampled_inds] 106 | proposals[img_idx] = proposals_per_image 107 | 108 | self._proposals = proposals 109 | return proposals 110 | 111 | def __call__(self, class_logits, box_regression): 112 | """ 113 | Computes the loss for Faster R-CNN. 114 | This requires that the subsample method has been called beforehand. 115 | 116 | Arguments: 117 | class_logits (list[Tensor]) 118 | box_regression (list[Tensor]) 119 | 120 | Returns: 121 | classification_loss (Tensor) 122 | box_loss (Tensor) 123 | """ 124 | 125 | class_logits = cat(class_logits, dim=0) 126 | box_regression = cat(box_regression, dim=0) 127 | device = class_logits.device 128 | 129 | if not hasattr(self, "_proposals"): 130 | raise RuntimeError("subsample needs to be called before") 131 | 132 | proposals = self._proposals 133 | 134 | labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0) 135 | regression_targets = cat( 136 | [proposal.get_field("regression_targets") for proposal in proposals], dim=0 137 | ) 138 | 139 | classification_loss = F.cross_entropy(class_logits, labels) 140 | 141 | # get indices that correspond to the regression targets for 142 | # the corresponding ground truth labels, to be used with 143 | # advanced indexing 144 | sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) 145 | labels_pos = labels[sampled_pos_inds_subset] 146 | map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device) 147 | 148 | box_loss = smooth_l1_loss( 149 | box_regression[sampled_pos_inds_subset[:, None], map_inds], 150 | regression_targets[sampled_pos_inds_subset], 151 | size_average=False, 152 | beta=1, 153 | ) 154 | box_loss = box_loss / labels.numel() 155 | 156 | return classification_loss, box_loss 157 | 158 | 159 | def make_roi_box_loss_evaluator(cfg): 160 | matcher = Matcher( 161 | cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, 162 | cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, 163 | allow_low_quality_matches=False, 164 | ) 165 | 166 | bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS 167 | box_coder = BoxCoder(weights=bbox_reg_weights) 168 | 169 | fg_bg_sampler = BalancedPositiveNegativeSampler( 170 | cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION 171 | ) 172 | 173 | loss_evaluator = FastRCNNLossComputation(matcher, fg_bg_sampler, box_coder) 174 | 175 | return loss_evaluator 176 | --------------------------------------------------------------------------------
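The trickiest step in FastRCNNLossComputation.__call__ above is the advanced-indexing construction that selects, for every sampled positive proposal, only the 4 regression values belonging to its own ground-truth class. The self-contained sketch below is an editorial illustration of that indexing outside the model; the three proposals and their class ids are made up:

```python
import torch

# three sampled positive proposals with class ids 1, 3 and 2 (0 = background)
labels_pos = torch.tensor([1, 3, 2])
# box_regression stores 4 deltas per class: [c0_dx, c0_dy, c0_dw, c0_dh, c1_dx, ...]
num_classes = 4
box_regression = torch.randn(3, num_classes * 4)

# same construction as in the loss: column indices of each row's own class slice
map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3])
rows = torch.arange(3)[:, None]
per_class_deltas = box_regression[rows, map_inds]  # shape (3, 4)

# row i now holds exactly box_regression[i, 4*labels_pos[i] : 4*labels_pos[i] + 4]
assert torch.equal(per_class_deltas[0], box_regression[0, 4:8])
```

Note also that the smooth L1 result is summed and then divided by `labels.numel()`, i.e. normalized by the total number of sampled proposals (positives and negatives), not just by the positives.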
/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from maskrcnn_benchmark.modeling.backbone import resnet 6 | from maskrcnn_benchmark.modeling.poolers import Pooler 7 | 8 | 9 | class ResNet50Conv5ROIFeatureExtractor(nn.Module): 10 | def __init__(self, config): 11 | super(ResNet50Conv5ROIFeatureExtractor, self).__init__() 12 | 13 | resolution = config.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION 14 | scales = config.MODEL.ROI_BOX_HEAD.POOLER_SCALES 15 | sampling_ratio = config.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO 16 | pooler = Pooler( 17 | output_size=(resolution, resolution), 18 | scales=scales, 19 | sampling_ratio=sampling_ratio, 20 | ) 21 | 22 | stage = resnet.StageSpec(index=4, block_count=3, return_features=False) 23 | head = resnet.ResNetHead( 24 | block_module=config.MODEL.RESNETS.TRANS_FUNC, 25 | stages=(stage,), 26 | num_groups=config.MODEL.RESNETS.NUM_GROUPS, 27 | width_per_group=config.MODEL.RESNETS.WIDTH_PER_GROUP, 28 | stride_in_1x1=config.MODEL.RESNETS.STRIDE_IN_1X1, 29 | stride_init=None, 30 | res2_out_channels=config.MODEL.RESNETS.RES2_OUT_CHANNELS, 31 | ) 32 | 33 | self.pooler = pooler 34 | self.head = head 35 | 36 | def forward(self, x, proposals): 37 | x = self.pooler(x, proposals) 38 | x = self.head(x) 39 | return x 40 | 41 | 42 | class FPN2MLPFeatureExtractor(nn.Module): 43 | """ 44 | Heads for FPN for classification 45 | """ 46 | 47 | def __init__(self, cfg): 48 | super(FPN2MLPFeatureExtractor, self).__init__() 49 | 50 | resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION 51 | scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES 52 | sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO 53 | pooler = Pooler( 54 | output_size=(resolution, resolution), 55 | scales=scales, 56 | sampling_ratio=sampling_ratio, 57 | ) 58 | input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2 59 | representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM 60 | self.pooler = pooler 61 | self.fc6 = nn.Linear(input_size, representation_size) 62 | self.fc7 = nn.Linear(representation_size, representation_size) 63 | 64 | for l in [self.fc6, self.fc7]: 65 | # Caffe2 implementation uses XavierFill, which in fact 66 | # corresponds to kaiming_uniform_ in PyTorch 67 | nn.init.kaiming_uniform_(l.weight, a=1) 68 | nn.init.constant_(l.bias, 0) 69 | 70 | def forward(self, x, proposals): 71 | x = self.pooler(x, proposals) 72 | x = x.view(x.size(0), -1) 73 | 74 | x = F.relu(self.fc6(x)) 75 | x = F.relu(self.fc7(x)) 76 | 77 | return x 78 | 79 | 80 | _ROI_BOX_FEATURE_EXTRACTORS = { 81 | "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor, 82 | "FPN2MLPFeatureExtractor": FPN2MLPFeatureExtractor, 83 | } 84 | 85 | 86 | def make_roi_box_feature_extractor(cfg): 87 | func = _ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR] 88 | return func(cfg) 89 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
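"""
Editorial note (not part of the upstream file): FPN2MLPFeatureExtractor above
flattens the pooled region before fc6, so the fc6 input width is
OUT_CHANNELS * resolution ** 2. A quick sanity check with commonly used
values (256 channels and 7x7 pooling, assumed here purely for illustration):

    out_channels, resolution = 256, 7
    input_size = out_channels * resolution ** 2
    assert input_size == 12544  # width of fc6 for these settings
"""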
2 | from torch import nn 3 | 4 | 5 | class FastRCNNPredictor(nn.Module): 6 | def __init__(self, config, pretrained=None): 7 | super(FastRCNNPredictor, self).__init__() 8 | 9 | stage_index = 4 10 | stage2_relative_factor = 2 ** (stage_index - 1) 11 | res2_out_channels = config.MODEL.RESNETS.RES2_OUT_CHANNELS 12 | num_inputs = res2_out_channels * stage2_relative_factor 13 | 14 | num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES 15 | self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7) 16 | self.cls_score = nn.Linear(num_inputs, num_classes) 17 | self.bbox_pred = nn.Linear(num_inputs, num_classes * 4) 18 | 19 | nn.init.normal_(self.cls_score.weight, mean=0, std=0.01) 20 | nn.init.constant_(self.cls_score.bias, 0) 21 | 22 | nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001) 23 | nn.init.constant_(self.bbox_pred.bias, 0) 24 | 25 | def forward(self, x): 26 | x = self.avgpool(x) 27 | x = x.view(x.size(0), -1) 28 | cls_logit = self.cls_score(x) 29 | bbox_pred = self.bbox_pred(x) 30 | return cls_logit, bbox_pred 31 | 32 | 33 | class FPNPredictor(nn.Module): 34 | def __init__(self, cfg): 35 | super(FPNPredictor, self).__init__() 36 | num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES 37 | representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM 38 | 39 | self.cls_score = nn.Linear(representation_size, num_classes) 40 | self.bbox_pred = nn.Linear(representation_size, num_classes * 4) 41 | 42 | nn.init.normal_(self.cls_score.weight, std=0.01) 43 | nn.init.normal_(self.bbox_pred.weight, std=0.001) 44 | for l in [self.cls_score, self.bbox_pred]: 45 | nn.init.constant_(l.bias, 0) 46 | 47 | def forward(self, x): 48 | scores = self.cls_score(x) 49 | bbox_deltas = self.bbox_pred(x) 50 | 51 | return scores, bbox_deltas 52 | 53 | 54 | _ROI_BOX_PREDICTOR = { 55 | "FastRCNNPredictor": FastRCNNPredictor, 56 | "FPNPredictor": FPNPredictor, 57 | } 58 | 59 | 60 | def make_roi_box_predictor(cfg): 61 | func = _ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR] 62 | return func(cfg) 63 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor 6 | from maskrcnn_benchmark.modeling.poolers import Pooler 7 | from maskrcnn_benchmark.layers import Conv2d 8 | 9 | 10 | class MaskRCNNFPNFeatureExtractor(nn.Module): 11 | """ 12 | FPN feature extractor for the mask head 13 | """ 14 | 15 | def __init__(self, cfg): 16 | """ 17 | Arguments: 18 | cfg: global configuration node; the pooler resolution, scales, 19 | sampling ratio, and conv layer widths are read from 20 | cfg.MODEL.ROI_MASK_HEAD 21 | """ 22 | super(MaskRCNNFPNFeatureExtractor, self).__init__() 23 | 24 | # resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION 25 | if cfg.MODEL.CHAR_MASK_ON: 26 | resolution_h = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H 27 | resolution_w = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W 28 | else: 29 | resolution_h = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION 30 | resolution_w = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION 31 | scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES 32 | sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO 33 | pooler = Pooler( 34 | output_size=(resolution_h, resolution_w), 35 | scales=scales, 36 | sampling_ratio=sampling_ratio, 37 | ) 38 | input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS 39 | self.pooler = pooler 40 | 41 | layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS 42 | 43 | next_feature = input_size 44 | self.blocks = [] 45 | for layer_idx, layer_features in enumerate(layers, 1): 46 | layer_name = "mask_fcn{}".format(layer_idx) 47 | module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1) 48 | # Caffe2 implementation uses MSRAFill, which in fact 49 | # corresponds to kaiming_normal_ in PyTorch 50 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 51 | nn.init.constant_(module.bias, 0) 52 | self.add_module(layer_name, module) 53 | next_feature = layer_features 54 | self.blocks.append(layer_name) 55 | 56 | def forward(self, x, proposals): 57 | x = self.pooler(x, proposals) 58 | 59 | for layer_name in self.blocks: 60 | x = F.relu(getattr(self, layer_name)(x)) 61 | 62 | return x 63 | 64 | 65 | _ROI_MASK_FEATURE_EXTRACTORS = { 66 | "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor, 67 | "MaskRCNNFPNFeatureExtractor": MaskRCNNFPNFeatureExtractor, 68 | } 69 | 70 | 71 | def make_roi_mask_feature_extractor(cfg): 72 | func = _ROI_MASK_FEATURE_EXTRACTORS[cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR] 73 | return func(cfg) 74 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from maskrcnn_benchmark.layers import Conv2d 6 | from maskrcnn_benchmark.layers import ConvTranspose2d 7 | from .roi_seq_predictors import make_roi_seq_predictor 8 | 9 | class MaskRCNNC4Predictor(nn.Module): 10 | def __init__(self, cfg): 11 | super(MaskRCNNC4Predictor, self).__init__() 12 | num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES 13 | dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] 14 | 15 | if cfg.MODEL.ROI_HEADS.USE_FPN: 16 | num_inputs = dim_reduced 17 | else: 18 | stage_index = 4 19 | stage2_relative_factor = 2 ** (stage_index - 1) 20 | res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS 21 | num_inputs = res2_out_channels * stage2_relative_factor 22 | 23 | self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) 24 | self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) 25 | 26 | for name, param in self.named_parameters(): 27 | if "bias" in name: 28 | nn.init.constant_(param, 0) 29 | elif "weight" in name: 30 | # Caffe2 implementation uses MSRAFill, which in fact 31 | # corresponds to kaiming_normal_ in PyTorch 32 | nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") 33 | 34 | def forward(self, x): 35 | x = F.relu(self.conv5_mask(x)) 36 | return self.mask_fcn_logits(x) 37 | 38 | class CharMaskRCNNC4Predictor(nn.Module): 39 | def __init__(self, cfg): 40 | super(CharMaskRCNNC4Predictor, self).__init__() 41 | # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES 42 | num_classes = 1 43 | char_num_classes = cfg.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES 44 | dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] 45 | 46 | if cfg.MODEL.ROI_HEADS.USE_FPN: 47 | num_inputs = dim_reduced 48 | else: 49 | stage_index = 4 50 | stage2_relative_factor = 2 ** (stage_index - 1) 51 | res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS 52 | num_inputs = res2_out_channels * stage2_relative_factor 53 | 54 | self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) 55 | if cfg.MODEL.CHAR_MASK_ON: 56 | self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) 57 | self.char_mask_fcn_logits = Conv2d(dim_reduced, char_num_classes, 1, 1, 0) 58 | else: 59 | self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) 60 | 61 | for name, param in self.named_parameters(): 62 | if "bias" in name: 63 | nn.init.constant_(param, 0) 64 | elif "weight" in name: 65 | # Caffe2 implementation uses MSRAFill, which in fact 66 | # corresponds to kaiming_normal_ in PyTorch 67 | nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") 68 | 69 | def forward(self, x): 70 | x = F.relu(self.conv5_mask(x)) 71 | return self.mask_fcn_logits(x), self.char_mask_fcn_logits(x) 72 | 73 | class SeqCharMaskRCNNC4Predictor(nn.Module): 74 | def __init__(self, cfg): 75 | super(SeqCharMaskRCNNC4Predictor, self).__init__() 76 | # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES 77 | num_classes = 1 78 | char_num_classes = cfg.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES 79 | dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] 80 | 81 | if cfg.MODEL.ROI_HEADS.USE_FPN: 82 | num_inputs = dim_reduced 83 | else: 84 | stage_index = 4 85 | stage2_relative_factor = 2 ** (stage_index - 1) 86 | res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS 87 | num_inputs = res2_out_channels * stage2_relative_factor 88 | 89 | self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) 90 | if cfg.MODEL.CHAR_MASK_ON: 91 | self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) 92 | 
self.char_mask_fcn_logits = Conv2d(dim_reduced, char_num_classes, 1, 1, 0) 93 | self.seq = make_roi_seq_predictor(cfg, dim_reduced) 94 | else: 95 | self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) 96 | 97 | for name, param in self.named_parameters(): 98 | if "bias" in name: 99 | nn.init.constant_(param, 0) 100 | elif "weight" in name: 101 | # Caffe2 implementation uses MSRAFill, which in fact 102 | # corresponds to kaiming_normal_ in PyTorch 103 | nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") 104 | 105 | def forward(self, x, decoder_targets=None, word_targets=None): 106 | x = F.relu(self.conv5_mask(x)) 107 | if self.training: 108 | loss_seq_decoder = self.seq(x, decoder_targets=decoder_targets, word_targets=word_targets) 109 | return self.mask_fcn_logits(x), self.char_mask_fcn_logits(x), loss_seq_decoder 110 | else: 111 | decoded_chars, decoded_scores, detailed_decoded_scores = self.seq(x, use_beam_search=True) 112 | return self.mask_fcn_logits(x), self.char_mask_fcn_logits(x), decoded_chars, decoded_scores, detailed_decoded_scores 113 | 114 | _ROI_MASK_PREDICTOR = {"MaskRCNNC4Predictor": MaskRCNNC4Predictor, "CharMaskRCNNC4Predictor": CharMaskRCNNC4Predictor, "SeqCharMaskRCNNC4Predictor": SeqCharMaskRCNNC4Predictor} 115 | 116 | 117 | def make_roi_mask_predictor(cfg): 118 | func = _ROI_MASK_PREDICTOR[cfg.MODEL.ROI_MASK_HEAD.PREDICTOR] 119 | return func(cfg) 120 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/roi_heads/roi_heads.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | from .box_head.box_head import build_roi_box_head 5 | from .mask_head.mask_head import build_roi_mask_head 6 | 7 | 8 | class CombinedROIHeads(torch.nn.ModuleDict): 9 | """ 10 | Combines a set of individual heads (for box prediction or masks) into a single 11 | head. 
12 | """ 13 | 14 | def __init__(self, cfg, heads): 15 | super(CombinedROIHeads, self).__init__(heads) 16 | self.cfg = cfg.clone() 17 | if cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR: 18 | self.mask.feature_extractor = self.box.feature_extractor 19 | 20 | def forward(self, features, proposals, targets=None): 21 | losses = {} 22 | # TODO rename x to roi_box_features, if it doesn't increase memory consumption 23 | x, detections, loss_box = self.box(features, proposals, targets) 24 | losses.update(loss_box) 25 | if self.cfg.MODEL.MASK_ON: 26 | mask_features = features 27 | # optimization: during training, if we share the feature extractor between 28 | # the box and the mask heads, then we can reuse the features already computed 29 | if ( 30 | self.training 31 | and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR 32 | ): 33 | mask_features = x 34 | # During training, self.box() will return the unaltered proposals as "detections" 35 | # this makes the API consistent during training and testing 36 | x, detections, loss_mask = self.mask(mask_features, detections, targets) 37 | losses.update(loss_mask) 38 | return x, detections, losses 39 | 40 | 41 | def build_roi_heads(cfg): 42 | # individually create the heads, that will be combined together 43 | # afterwards 44 | roi_heads = [] 45 | if not cfg.MODEL.RPN_ONLY: 46 | roi_heads.append(("box", build_roi_box_head(cfg))) 47 | if cfg.MODEL.MASK_ON: 48 | roi_heads.append(("mask", build_roi_mask_head(cfg))) 49 | 50 | # combine individual heads in a single module 51 | if roi_heads: 52 | roi_heads = CombinedROIHeads(cfg, roi_heads) 53 | 54 | return roi_heads 55 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/rpn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # from .rpn import build_rpn 3 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/rpn/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | This file contains specific functions for computing losses on the RPN 4 | file 5 | """ 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler 11 | from ..utils import cat 12 | 13 | from maskrcnn_benchmark.layers import smooth_l1_loss 14 | from maskrcnn_benchmark.modeling.matcher import Matcher 15 | from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou 16 | from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist 17 | 18 | 19 | class RPNLossComputation(object): 20 | """ 21 | This class computes the RPN loss. 
22 | """ 23 | 24 | def __init__(self, proposal_matcher, fg_bg_sampler, box_coder): 25 | """ 26 | Arguments: 27 | proposal_matcher (Matcher) 28 | fg_bg_sampler (BalancedPositiveNegativeSampler) 29 | box_coder (BoxCoder) 30 | """ 31 | # self.target_preparator = target_preparator 32 | self.proposal_matcher = proposal_matcher 33 | self.fg_bg_sampler = fg_bg_sampler 34 | self.box_coder = box_coder 35 | 36 | def match_targets_to_anchors(self, anchor, target): 37 | match_quality_matrix = boxlist_iou(target, anchor) 38 | matched_idxs = self.proposal_matcher(match_quality_matrix) 39 | # RPN doesn't need any fields from target 40 | # for creating the labels, so clear them all 41 | target = target.copy_with_fields([]) 42 | # get the targets corresponding GT for each anchor 43 | # NB: need to clamp the indices because we can have a single 44 | # GT in the image, and matched_idxs can be -2, which goes 45 | # out of bounds 46 | matched_targets = target[matched_idxs.clamp(min=0)] 47 | matched_targets.add_field("matched_idxs", matched_idxs) 48 | return matched_targets 49 | 50 | def prepare_targets(self, anchors, targets): 51 | labels = [] 52 | regression_targets = [] 53 | for anchors_per_image, targets_per_image in zip(anchors, targets): 54 | matched_targets = self.match_targets_to_anchors( 55 | anchors_per_image, targets_per_image 56 | ) 57 | 58 | matched_idxs = matched_targets.get_field("matched_idxs") 59 | labels_per_image = matched_idxs >= 0 60 | labels_per_image = labels_per_image.to(dtype=torch.float32) 61 | # discard anchors that go out of the boundaries of the image 62 | labels_per_image[~anchors_per_image.get_field("visibility")] = -1 63 | 64 | # discard indices that are between thresholds 65 | inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS 66 | labels_per_image[inds_to_discard] = -1 67 | 68 | # compute regression targets 69 | regression_targets_per_image = self.box_coder.encode( 70 | matched_targets.bbox, anchors_per_image.bbox 71 | ) 72 | 73 | labels.append(labels_per_image) 74 | regression_targets.append(regression_targets_per_image) 75 | 76 | return labels, regression_targets 77 | 78 | def __call__(self, anchors, objectness, box_regression, targets): 79 | """ 80 | Arguments: 81 | anchors (list[BoxList]) 82 | objectness (list[Tensor]) 83 | box_regression (list[Tensor]) 84 | targets (list[BoxList]) 85 | 86 | Returns: 87 | objectness_loss (Tensor) 88 | box_loss (Tensor 89 | """ 90 | anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors] 91 | labels, regression_targets = self.prepare_targets(anchors, targets) 92 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 93 | sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1) 94 | sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1) 95 | 96 | sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) 97 | 98 | objectness_flattened = [] 99 | box_regression_flattened = [] 100 | # for each feature level, permute the outputs to make them be in the 101 | # same format as the labels. 
Note that the labels are computed for 102 | # all feature levels concatenated, so we keep the same representation 103 | # for the objectness and the box_regression 104 | for objectness_per_level, box_regression_per_level in zip( 105 | objectness, box_regression 106 | ): 107 | N, A, H, W = objectness_per_level.shape 108 | objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape( 109 | N, -1 110 | ) 111 | box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W) 112 | box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2) 113 | box_regression_per_level = box_regression_per_level.reshape(N, -1, 4) 114 | objectness_flattened.append(objectness_per_level) 115 | box_regression_flattened.append(box_regression_per_level) 116 | # concatenate on the first dimension (representing the feature levels), to 117 | # take into account the way the labels were generated (with all feature maps 118 | # being concatenated as well) 119 | objectness = cat(objectness_flattened, dim=1).reshape(-1) 120 | box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4) 121 | 122 | labels = torch.cat(labels, dim=0) 123 | regression_targets = torch.cat(regression_targets, dim=0) 124 | 125 | box_loss = smooth_l1_loss( 126 | box_regression[sampled_pos_inds], 127 | regression_targets[sampled_pos_inds], 128 | beta=1.0 / 9, 129 | size_average=False, 130 | ) / (sampled_inds.numel()) 131 | 132 | objectness_loss = F.binary_cross_entropy_with_logits( 133 | objectness[sampled_inds], labels[sampled_inds] 134 | ) 135 | 136 | return objectness_loss, box_loss 137 | 138 | 139 | def make_rpn_loss_evaluator(cfg, box_coder): 140 | matcher = Matcher( 141 | cfg.MODEL.RPN.FG_IOU_THRESHOLD, 142 | cfg.MODEL.RPN.BG_IOU_THRESHOLD, 143 | allow_low_quality_matches=True, 144 | ) 145 | 146 | fg_bg_sampler = BalancedPositiveNegativeSampler( 147 | cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION 148 | ) 149 | 150 | loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder) 151 | return loss_evaluator 152 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/rpn/rpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from maskrcnn_benchmark.modeling.box_coder import BoxCoder 7 | from .loss import make_rpn_loss_evaluator 8 | from .anchor_generator import make_anchor_generator 9 | from .inference import make_rpn_postprocessor 10 | 11 | 12 | class RPNHead(nn.Module): 13 | """ 14 | Adds a simple RPN Head with classification and regression heads 15 | """ 16 | 17 | def __init__(self, in_channels, num_anchors): 18 | """ 19 | Arguments: 20 | in_channels (int): number of channels of the input feature 21 | num_anchors (int): number of anchors to be predicted 22 | """ 23 | super(RPNHead, self).__init__() 24 | self.conv = nn.Conv2d( 25 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 26 | ) 27 | self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) 28 | self.bbox_pred = nn.Conv2d( 29 | in_channels, num_anchors * 4, kernel_size=1, stride=1 30 | ) 31 | 32 | for l in [self.conv, self.cls_logits, self.bbox_pred]: 33 | torch.nn.init.normal_(l.weight, std=0.01) 34 | torch.nn.init.constant_(l.bias, 0) 35 | 36 | def forward(self, x): 37 | logits = [] 38 | bbox_reg = [] 39 | for feature in x: 40 | t = F.relu(self.conv(feature)) 41 | logits.append(self.cls_logits(t)) 42 | bbox_reg.append(self.bbox_pred(t)) 43 | return logits, bbox_reg 44 | 45 | 46 | class RPNModule(torch.nn.Module): 47 | """ 48 | Module for RPN computation. Takes feature maps from the backbone and outputs 49 | RPN proposals and losses. Works for both FPN and non-FPN. 50 | """ 51 | 52 | def __init__(self, cfg): 53 | super(RPNModule, self).__init__() 54 | 55 | self.cfg = cfg.clone() 56 | 57 | anchor_generator = make_anchor_generator(cfg) 58 | 59 | in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS 60 | head = RPNHead(in_channels, anchor_generator.num_anchors_per_location()[0]) 61 | 62 | rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) 63 | 64 | box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True) 65 | box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False) 66 | 67 | loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder) 68 | 69 | self.anchor_generator = anchor_generator 70 | self.head = head 71 | self.box_selector_train = box_selector_train 72 | self.box_selector_test = box_selector_test 73 | self.loss_evaluator = loss_evaluator 74 | 75 | def forward(self, images, features, targets=None): 76 | """ 77 | Arguments: 78 | images (ImageList): images for which we want to compute the predictions 79 | features (list[Tensor]): features computed from the images that are 80 | used for computing the predictions. Each tensor in the list 81 | corresponds to a different feature level 82 | targets (list[BoxList]): ground-truth boxes present in the image (optional) 83 | 84 | Returns: 85 | boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per 86 | image. 87 | losses (dict[Tensor]): the losses for the model during training. During 88 | testing, it is an empty dict.
89 | """ 90 | objectness, rpn_box_regression = self.head(features) 91 | anchors = self.anchor_generator(images, features) 92 | 93 | if self.training: 94 | return self._forward_train(anchors, objectness, rpn_box_regression, targets) 95 | else: 96 | return self._forward_test(anchors, objectness, rpn_box_regression) 97 | 98 | def _forward_train(self, anchors, objectness, rpn_box_regression, targets): 99 | if self.cfg.MODEL.RPN_ONLY: 100 | # When training an RPN-only model, the loss is determined by the 101 | # predicted objectness and rpn_box_regression values and there is 102 | # no need to transform the anchors into predicted boxes; this is an 103 | # optimization that avoids the unnecessary transformation. 104 | boxes = anchors 105 | else: 106 | # For end-to-end models, anchors must be transformed into boxes and 107 | # sampled into a training batch. 108 | with torch.no_grad(): 109 | boxes = self.box_selector_train( 110 | anchors, objectness, rpn_box_regression, targets 111 | ) 112 | loss_objectness, loss_rpn_box_reg = self.loss_evaluator( 113 | anchors, objectness, rpn_box_regression, targets 114 | ) 115 | losses = { 116 | "loss_objectness": loss_objectness, 117 | "loss_rpn_box_reg": loss_rpn_box_reg, 118 | } 119 | return boxes, losses 120 | 121 | def _forward_test(self, anchors, objectness, rpn_box_regression): 122 | boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) 123 | if self.cfg.MODEL.RPN_ONLY: 124 | # For end-to-end models, the RPN proposals are an intermediate state 125 | # and don't bother to sort them in decreasing score order. For RPN-only 126 | # models, the proposals are the final output and we return them in 127 | # high-to-low confidence order. 128 | inds = [ 129 | box.get_field("objectness").sort(descending=True)[1] for box in boxes 130 | ] 131 | boxes = [box[ind] for box, ind in zip(boxes, inds)] 132 | return boxes, {} 133 | 134 | 135 | def build_rpn(cfg): 136 | """ 137 | This gives the gist of it. Not super important because it doesn't change as much 138 | """ 139 | return RPNModule(cfg) 140 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/modeling/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | Miscellaneous utility functions 4 | """ 5 | 6 | import torch 7 | 8 | 9 | def cat(tensors, dim=0): 10 | """ 11 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 12 | """ 13 | assert isinstance(tensors, (list, tuple)) 14 | if len(tensors) == 1: 15 | return tensors[0] 16 | return torch.cat(tensors, dim) 17 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .build import make_optimizer 3 | from .build import make_lr_scheduler 4 | from .lr_scheduler import WarmupMultiStepLR 5 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/solver/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | import torch 3 | 4 | from .lr_scheduler import WarmupMultiStepLR 5 | 6 | 7 | def make_optimizer(cfg, model): 8 | params = [] 9 | for key, value in model.named_parameters(): 10 | if not value.requires_grad: 11 | continue 12 | lr = cfg.SOLVER.BASE_LR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 14 | if "bias" in key: 15 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 17 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 18 | 19 | if cfg.SOLVER.USE_ADAM: 20 | optimizer = torch.optim.Adam(params, lr=cfg.SOLVER.BASE_LR)  # per-parameter groups above override this default lr 21 | else: 22 | optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM) 23 | 24 | return optimizer 25 | 26 | 27 | def make_lr_scheduler(cfg, optimizer): 28 | return WarmupMultiStepLR( 29 | optimizer, 30 | cfg.SOLVER.STEPS, 31 | cfg.SOLVER.GAMMA, 32 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 33 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 34 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 35 | ) 36 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from bisect import bisect_right 3 | 4 | import torch 5 | 6 | 7 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 8 | # separating MultiStepLR with WarmupLR 9 | # but the current LRScheduler design doesn't allow it 10 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 11 | def __init__( 12 | self, 13 | optimizer, 14 | milestones, 15 | gamma=0.1, 16 | warmup_factor=1.0 / 3, 17 | warmup_iters=500, 18 | warmup_method="linear", 19 | last_epoch=-1, 20 | ): 21 | if not list(milestones) == sorted(milestones): 22 | raise ValueError( 23 | "Milestones should be a list of" 24 | " increasing integers. Got {}".format(milestones) 25 | ) 26 | 27 | if warmup_method not in ("constant", "linear"): 28 | raise ValueError( 29 | "Only 'constant' or 'linear' warmup_method accepted, " 30 | "got {}".format(warmup_method) 31 | ) 32 | self.milestones = milestones 33 | self.gamma = gamma 34 | self.warmup_factor = warmup_factor 35 | self.warmup_iters = warmup_iters 36 | self.warmup_method = warmup_method 37 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | warmup_factor = 1 41 | if self.last_epoch < self.warmup_iters: 42 | if self.warmup_method == "constant": 43 | warmup_factor = self.warmup_factor 44 | elif self.warmup_method == "linear": 45 | alpha = self.last_epoch / self.warmup_iters 46 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 47 | return [ 48 | base_lr 49 | * warmup_factor 50 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 51 | for base_lr in self.base_lrs 52 | ] 53 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/structures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/maskrcnn_benchmark/structures/__init__.py -------------------------------------------------------------------------------- /maskrcnn_benchmark/structures/boxlist_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import torch 3 | 4 | from .bounding_box import BoxList 5 | 6 | from maskrcnn_benchmark.layers import nms as _box_nms 7 | 8 | 9 | def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"): 10 | """ 11 | Performs non-maximum suppression on a boxlist, with scores specified 12 | in a boxlist field via score_field. 13 | 14 | Arguments: 15 | boxlist (BoxList) 16 | nms_thresh (float) 17 | max_proposals (int): if > 0, then only the top max_proposals are kept 18 | after non-maximum suppression 19 | score_field (str) 20 | """ 21 | if nms_thresh <= 0: 22 | return boxlist 23 | mode = boxlist.mode 24 | boxlist = boxlist.convert("xyxy") 25 | boxes = boxlist.bbox 26 | score = boxlist.get_field(score_field) 27 | keep = _box_nms(boxes, score, nms_thresh) 28 | if max_proposals > 0: 29 | keep = keep[: max_proposals] 30 | boxlist = boxlist[keep] 31 | return boxlist.convert(mode) 32 | 33 | 34 | def remove_small_boxes(boxlist, min_size): 35 | """ 36 | Only keep boxes with both sides >= min_size 37 | 38 | Arguments: 39 | boxlist (BoxList) 40 | min_size (int) 41 | """ 42 | # TODO maybe add an API for querying the ws / hs 43 | xywh_boxes = boxlist.convert("xywh").bbox 44 | _, _, ws, hs = xywh_boxes.unbind(dim=1) 45 | keep = ( 46 | (ws >= min_size) & (hs >= min_size) 47 | ).nonzero().squeeze(1) 48 | return boxlist[keep] 49 | 50 | 51 | # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py 52 | # with slight modifications 53 | def boxlist_iou(boxlist1, boxlist2): 54 | """Compute the intersection over union of two sets of boxes. 55 | The box order must be (xmin, ymin, xmax, ymax). 56 | 57 | Arguments: 58 | boxlist1: (BoxList) bounding boxes, sized [N,4]. 59 | boxlist2: (BoxList) bounding boxes, sized [M,4]. 60 | 61 | Returns: 62 | (tensor) iou, sized [N,M].
63 | 64 | Reference: 65 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py 66 | """ 67 | if boxlist1.size != boxlist2.size: 68 | raise RuntimeError( 69 | "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2)) 70 | 71 | N = len(boxlist1) 72 | M = len(boxlist2) 73 | 74 | area1 = boxlist1.area() 75 | area2 = boxlist2.area() 76 | 77 | box1, box2 = boxlist1.bbox, boxlist2.bbox 78 | 79 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] 80 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] 81 | 82 | TO_REMOVE = 1 83 | 84 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] 85 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 86 | 87 | iou = inter / (area1[:, None] + area2 - inter) 88 | return iou 89 | 90 | 91 | # TODO redundant, remove 92 | def _cat(tensors, dim=0): 93 | """ 94 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 95 | """ 96 | assert isinstance(tensors, (list, tuple)) 97 | if len(tensors) == 1: 98 | return tensors[0] 99 | return torch.cat(tensors, dim) 100 | 101 | 102 | def cat_boxlist(bboxes): 103 | """ 104 | Concatenates a list of BoxList (having the same image size) into a 105 | single BoxList 106 | 107 | Arguments: 108 | bboxes (list[BoxList]) 109 | """ 110 | assert isinstance(bboxes, (list, tuple)) 111 | assert all(isinstance(bbox, BoxList) for bbox in bboxes) 112 | 113 | size = bboxes[0].size 114 | assert all(bbox.size == size for bbox in bboxes) 115 | 116 | mode = bboxes[0].mode 117 | assert all(bbox.mode == mode for bbox in bboxes) 118 | 119 | fields = set(bboxes[0].fields()) 120 | assert all(set(bbox.fields()) == fields for bbox in bboxes) 121 | 122 | cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode) 123 | 124 | for field in fields: 125 | data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0) 126 | cat_boxes.add_field(field, data) 127 | 128 | return cat_boxes 129 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/structures/image_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | 5 | class ImageList(object): 6 | """ 7 | Structure that holds a list of images (of possibly 8 | varying sizes) as a single tensor. 9 | This works by padding the images to the same size, 10 | and storing in a field the original sizes of each image 11 | """ 12 | 13 | def __init__(self, tensors, image_sizes): 14 | """ 15 | Arguments: 16 | tensors (tensor) 17 | image_sizes (list[tuple[int, int]]) 18 | """ 19 | self.tensors = tensors 20 | self.image_sizes = image_sizes 21 | 22 | def to(self, *args, **kwargs): 23 | cast_tensor = self.tensors.to(*args, **kwargs) 24 | return ImageList(cast_tensor, self.image_sizes) 25 | 26 | 27 | def to_image_list(tensors, size_divisible=0): 28 | """ 29 | tensors can be an ImageList, a torch.Tensor or 30 | an iterable of Tensors. It can't be a numpy array. 
31 | When tensors is an iterable of Tensors, it pads 32 | the Tensors with zeros so that they have the same 33 | shape 34 | """ 35 | if isinstance(tensors, torch.Tensor) and size_divisible > 0: 36 | tensors = [tensors] 37 | 38 | if isinstance(tensors, ImageList): 39 | return tensors 40 | elif isinstance(tensors, torch.Tensor): 41 | # single tensor shape can be inferred 42 | assert tensors.dim() == 4 43 | image_sizes = [tensor.shape[-2:] for tensor in tensors] 44 | return ImageList(tensors, image_sizes) 45 | elif isinstance(tensors, (tuple, list)): 46 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) 47 | 48 | # TODO Ideally, just remove this and let the model handle arbitrary 49 | # input sizes 50 | if size_divisible > 0: 51 | import math 52 | 53 | stride = size_divisible 54 | max_size = list(max_size) 55 | max_size[1] = int(math.ceil(max_size[1] / stride) * stride) 56 | max_size[2] = int(math.ceil(max_size[2] / stride) * stride) 57 | max_size = tuple(max_size) 58 | 59 | batch_shape = (len(tensors),) + max_size 60 | batched_imgs = tensors[0].new(*batch_shape).zero_() 61 | for img, pad_img in zip(tensors, batched_imgs): 62 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 63 | 64 | image_sizes = [im.shape[-2:] for im in tensors] 65 | 66 | return ImageList(batched_imgs, image_sizes) 67 | else: 68 | raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) 69 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/README.md: -------------------------------------------------------------------------------- 1 | # Utility functions 2 | 3 | This folder contains utility functions that are not used in the 4 | core library, but are useful for building models or training 5 | code using the config system. 6 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/MaskTextSpotter/7109132ff0ffbe77832e4b067b174d70dcc37df4/maskrcnn_benchmark/utils/__init__.py -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/c2_model_loading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging 3 | import pickle 4 | from collections import OrderedDict 5 | 6 | import torch 7 | 8 | from maskrcnn_benchmark.utils.model_serialization import load_state_dict 9 | 10 | 11 | def _rename_basic_resnet_weights(layer_keys): 12 | layer_keys = [k.replace("_", ".") for k in layer_keys] 13 | layer_keys = [k.replace(".w", ".weight") for k in layer_keys] 14 | layer_keys = [k.replace(".bn", "_bn") for k in layer_keys] 15 | layer_keys = [k.replace(".b", ".bias") for k in layer_keys] 16 | layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys] 17 | layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys] 18 | layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys] 19 | layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys] 20 | layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys] 21 | 22 | # RPN / Faster RCNN 23 | layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys] 24 | layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys] 25 | layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys] 26 | layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys] 27 | 28 | # Affine-Channel -> BatchNorm renaming 29 | layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys] 30 | 31 | # Make torchvision-compatible 32 | layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys] 33 | 34 | layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys] 35 | layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys] 36 | layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys] 37 | layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys] 38 | 39 | layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] 40 | layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys] 41 | layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] 42 | layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys] 43 | layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] 44 | layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys] 45 | 46 | layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys] 47 | layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys] 48 | 49 | return layer_keys 50 | 51 | def _rename_fpn_weights(layer_keys, stage_names): 52 | for mapped_idx, stage_name in enumerate(stage_names, 1): 53 | suffix = "" 54 | if mapped_idx < 4: 55 | suffix = ".lateral" 56 | layer_keys = [ 57 | k.replace("fpn.inner.layer{}.sum{}".format(stage_name, suffix), "fpn_inner{}".format(mapped_idx)) for k in layer_keys 58 | ] 59 | layer_keys = [k.replace("fpn.layer{}.sum".format(stage_name), "fpn_layer{}".format(mapped_idx)) for k in layer_keys] 60 | 61 | 62 | layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys] 63 | layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys] 64 | layer_keys = [ 65 | k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys 66 | ] 67 | 68 | return layer_keys 69 | 70 | 71 | def _rename_weights_for_resnet(weights, stage_names): 72 | original_keys = sorted(weights.keys()) 73 | layer_keys = sorted(weights.keys()) 74 | 75 | # for X-101, rename output to fc1000 to avoid conflicts afterwards 76 | layer_keys = [k if k != "pred_b" else "fc1000_b" for k in layer_keys] 77 | layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys] 78 | 79 | #
performs basic renaming: _ -> . , etc 80 | layer_keys = _rename_basic_resnet_weights(layer_keys) 81 | 82 | # FPN 83 | layer_keys = _rename_fpn_weights(layer_keys, stage_names) 84 | 85 | # Mask R-CNN 86 | layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys] 87 | layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys] 88 | layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys] 89 | 90 | # Keypoint R-CNN 91 | layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys] 92 | layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys] 93 | layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys] 94 | 95 | # Rename for our RPN structure 96 | layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys] 97 | 98 | key_map = {k: v for k, v in zip(original_keys, layer_keys)} 99 | 100 | logger = logging.getLogger(__name__) 101 | logger.info("Remapping C2 weights") 102 | max_c2_key_size = max([len(k) for k in original_keys if "_momentum" not in k]) 103 | 104 | new_weights = OrderedDict() 105 | for k in original_keys: 106 | v = weights[k] 107 | if "_momentum" in k: 108 | continue 109 | # if 'fc1000' in k: 110 | # continue 111 | w = torch.from_numpy(v) 112 | # if "bn" in k: 113 | # w = w.view(1, -1, 1, 1) 114 | logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k])) 115 | new_weights[key_map[k]] = w 116 | 117 | return new_weights 118 | 119 | 120 | def _load_c2_pickled_weights(file_path): 121 | with open(file_path, "rb") as f: 122 | data = pickle.load(f, encoding="latin1") 123 | if "blobs" in data: 124 | weights = data["blobs"] 125 | else: 126 | weights = data 127 | return weights 128 | 129 | 130 | _C2_STAGE_NAMES = { 131 | "R-50": ["1.2", "2.3", "3.5", "4.2"], 132 | "R-101": ["1.2", "2.3", "3.22", "4.2"], 133 | } 134 | 135 | def load_c2_format(cfg, f): 136 | # TODO make it support other architectures 137 | state_dict = _load_c2_pickled_weights(f) 138 | conv_body = cfg.MODEL.BACKBONE.CONV_BODY 139 | arch = conv_body.replace("-C4", "").replace("-FPN", "") 140 | stages = _C2_STAGE_NAMES[arch] 141 | state_dict = _rename_weights_for_resnet(state_dict, stages) 142 | return dict(model=state_dict) 143 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/chars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | 5 | def char2num(char): 6 | if char in '0123456789': 7 | num = ord(char) - ord('0') + 1 8 | elif char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': 9 | num = ord(char.lower()) - ord('a') + 11 10 | else: 11 | num = 0 12 | return num 13 | 14 | def num2char(num): 15 | chars = '_0123456789abcdefghijklmnopqrstuvwxyz' 16 | char = chars[num] 17 | # if num >=1 and num <=10: 18 | # char = chr(ord('0') + num - 1) 19 | # elif num > 10 and num <= 36: 20 | # char = chr(ord('a') + num - 11) 21 | # else: 22 | # print('error number:%d'%(num)) 23 | # exit() 24 | return char 25 | 26 | def getstr_grid(seg, box, threshold=192): 27 | pos = 255 - (seg[0]*255).astype(np.uint8) 28 | mask_index = np.argmax(seg, axis=0) 29 | mask_index = mask_index.astype(np.uint8) 30 | pos = pos.astype(np.uint8) 31 | string, score, rec_scores, char_polygons = seg2text(pos, mask_index, seg, box, threshold=threshold) 32 | return string, score, rec_scores, char_polygons 33 | 34 | def seg2text(gray, mask, seg, box, threshold=192): 35 | ## input numpy 
36 | img_h, img_w = gray.shape 37 | box_w = box[2] - box[0] 38 | box_h = box[3] - box[1] 39 | ratio_h = float(box_h) / img_h 40 | ratio_w = float(box_w) / img_w 41 | # SE1=cv2.getStructuringElement(cv2.MORPH_RECT,(3,3)) 42 | # gray = cv2.erode(gray,SE1) 43 | # gray = cv2.dilate(gray,SE1) 44 | # gray = cv2.morphologyEx(gray,cv2.MORPH_CLOSE,SE1) 45 | ret, thresh = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY) 46 | try: 47 | _, contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)  # OpenCV 3.x returns three values 48 | except ValueError:  # OpenCV 2.x/4.x returns (contours, hierarchy) 49 | contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 50 | chars = [] 51 | scores = [] 52 | char_polygons = [] 53 | for i in range(len(contours)): 54 | char = {} 55 | temp = np.zeros((img_h, img_w)).astype(np.uint8) 56 | cv2.drawContours(temp, [contours[i]], 0, (255), -1) 57 | x, y, w, h = cv2.boundingRect(contours[i]) 58 | c_x, c_y = x + w/2, y + h/2 59 | perimeter = cv2.arcLength(contours[i],True) 60 | epsilon = 0.01*cv2.arcLength(contours[i],True) 61 | approx = cv2.approxPolyDP(contours[i],epsilon,True) 62 | pts = approx.reshape((-1,2)) 63 | pts[:,0] = pts[:,0] * ratio_w + box[0] 64 | pts[:,1] = pts[:,1] * ratio_h + box[1] 65 | polygon = list(pts.reshape((-1,))) 66 | polygon = list(map(int, polygon)) 67 | if len(polygon)>=6: 68 | char_polygons.append(polygon) 69 | # x1 = x * ratio_w + box[0] 70 | # y1 = y * ratio_h + box[1] 71 | # x3 = (x + w) * ratio_w + box[0] 72 | # y3 = (y + h) * ratio_h + box[1] 73 | # polygon = [x1, y1, x3, y1, x3, y3, x1, y3] 74 | regions = seg[1:, temp == 255].reshape((36, -1)) 75 | cs = np.mean(regions, axis=1) 76 | sym = num2char(np.argmax(cs.reshape((-1))) + 1) 77 | char['x'] = c_x 78 | char['y'] = c_y 79 | char['s'] = sym 80 | char['cs'] = cs.reshape((-1, 1)) 81 | scores.append(np.max(char['cs'], axis=0)[0]) 82 | 83 | chars.append(char) 84 | chars = sorted(chars, key=lambda x: x['x']) 85 | string = '' 86 | css = [] 87 | for char in chars: 88 | string = string + char['s'] 89 | css.append(char['cs']) 90 | if len(scores)>0: 91 | score = sum(scores) / len(scores) 92 | else: 93 | score = 0.00 94 | if not css: 95 | css = [0.]
96 | return string, score, np.hstack(css), char_polygons 97 | 98 | def get_tight_rect(points, start_x, start_y, image_height, image_width, scale): 99 | points = list(points) 100 | ps = sorted(points,key = lambda x:x[0]) 101 | 102 | if ps[1][1] > ps[0][1]: 103 | px1 = ps[0][0] * scale + start_x 104 | py1 = ps[0][1] * scale + start_y 105 | px4 = ps[1][0] * scale + start_x 106 | py4 = ps[1][1] * scale + start_y 107 | else: 108 | px1 = ps[1][0] * scale + start_x 109 | py1 = ps[1][1] * scale + start_y 110 | px4 = ps[0][0] * scale + start_x 111 | py4 = ps[0][1] * scale + start_y 112 | if ps[3][1] > ps[2][1]: 113 | px2 = ps[2][0] * scale + start_x 114 | py2 = ps[2][1] * scale + start_y 115 | px3 = ps[3][0] * scale + start_x 116 | py3 = ps[3][1] * scale + start_y 117 | else: 118 | px2 = ps[3][0] * scale + start_x 119 | py2 = ps[3][1] * scale + start_y 120 | px3 = ps[2][0] * scale + start_x 121 | py3 = ps[2][1] * scale + start_y 122 | 123 | if px1<0: 124 | px1=1 125 | if px1>image_width: 126 | px1 = image_width - 1 127 | if px2<0: 128 | px2=1 129 | if px2>image_width: 130 | px2 = image_width - 1 131 | if px3<0: 132 | px3=1 133 | if px3>image_width: 134 | px3 = image_width - 1 135 | if px4<0: 136 | px4=1 137 | if px4>image_width: 138 | px4 = image_width - 1 139 | 140 | if py1<0: 141 | py1=1 142 | if py1>image_height: 143 | py1 = image_height - 1 144 | if py2<0: 145 | py2=1 146 | if py2>image_height: 147 | py2 = image_height - 1 148 | if py3<0: 149 | py3=1 150 | if py3>image_height: 151 | py3 = image_height - 1 152 | if py4<0: 153 | py4=1 154 | if py4>image_height: 155 | py4 = image_height - 1 156 | return [px1, py1, px2, py2, px3, py3, px4, py4] 157 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | import logging 3 | import os 4 | 5 | import torch 6 | 7 | from maskrcnn_benchmark.utils.model_serialization import load_state_dict 8 | from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format 9 | from maskrcnn_benchmark.utils.imports import import_file 10 | from maskrcnn_benchmark.utils.model_zoo import cache_url 11 | 12 | 13 | class Checkpointer(object): 14 | def __init__( 15 | self, 16 | model, 17 | optimizer=None, 18 | scheduler=None, 19 | save_dir="", 20 | save_to_disk=None, 21 | logger=None, 22 | ): 23 | self.model = model 24 | self.optimizer = optimizer 25 | self.scheduler = scheduler 26 | self.save_dir = save_dir 27 | self.save_to_disk = save_to_disk 28 | if logger is None: 29 | logger = logging.getLogger(__name__) 30 | self.logger = logger 31 | 32 | def save(self, name, **kwargs): 33 | if not self.save_dir: 34 | return 35 | 36 | if not self.save_to_disk: 37 | return 38 | 39 | data = {} 40 | data["model"] = self.model.state_dict() 41 | if self.optimizer is not None: 42 | data["optimizer"] = self.optimizer.state_dict() 43 | if self.scheduler is not None: 44 | data["scheduler"] = self.scheduler.state_dict() 45 | data.update(kwargs) 46 | 47 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 48 | self.logger.info("Saving checkpoint to {}".format(save_file)) 49 | torch.save(data, save_file) 50 | self.tag_last_checkpoint(save_file) 51 | 52 | def load(self, f=None, resume=False): 53 | if self.has_checkpoint(): 54 | # override argument with existing checkpoint 55 | f = self.get_checkpoint_file() 56 | if not f: 57 | # no checkpoint could be found 58 | self.logger.info("No checkpoint found. Initializing model from scratch") 59 | return {} 60 | self.logger.info("Loading checkpoint from {}".format(f)) 61 | checkpoint = self._load_file(f) 62 | self._load_model(checkpoint) 63 | if resume: 64 | if "optimizer" in checkpoint and self.optimizer: 65 | self.logger.info("Loading optimizer from {}".format(f)) 66 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 67 | if "scheduler" in checkpoint and self.scheduler: 68 | self.logger.info("Loading scheduler from {}".format(f)) 69 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 70 | 71 | # return any further checkpoint data 72 | return checkpoint 73 | 74 | def has_checkpoint(self): 75 | save_file = os.path.join(self.save_dir, "last_checkpoint") 76 | return os.path.exists(save_file) 77 | 78 | def get_checkpoint_file(self): 79 | save_file = os.path.join(self.save_dir, "last_checkpoint") 80 | try: 81 | with open(save_file, "r") as f: 82 | last_saved = f.read() 83 | except IOError: 84 | # if file doesn't exist, maybe because it has just been 85 | # deleted by a separate process 86 | last_saved = "" 87 | return last_saved 88 | 89 | def tag_last_checkpoint(self, last_filename): 90 | save_file = os.path.join(self.save_dir, "last_checkpoint") 91 | with open(save_file, "w") as f: 92 | f.write(last_filename) 93 | 94 | def _load_file(self, f): 95 | return torch.load(f, map_location=torch.device("cpu")) 96 | 97 | def _load_model(self, checkpoint): 98 | load_state_dict(self.model, checkpoint.pop("model")) 99 | 100 | 101 | class DetectronCheckpointer(Checkpointer): 102 | def __init__( 103 | self, 104 | cfg, 105 | model, 106 | optimizer=None, 107 | scheduler=None, 108 | save_dir="", 109 | save_to_disk=None, 110 | logger=None, 111 | ): 112 | super(DetectronCheckpointer, self).__init__( 113 | model, optimizer, scheduler, save_dir, save_to_disk, logger 114 | ) 115 | self.cfg = cfg.clone() 116 | 117 | def _load_file(self, 
f): 118 | # catalog lookup 119 | if f.startswith("catalog://"): 120 | paths_catalog = import_file( 121 | "maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True 122 | ) 123 | catalog_f = paths_catalog.ModelCatalog.get(f[len("catalog://") :]) 124 | self.logger.info("{} points to {}".format(f, catalog_f)) 125 | f = catalog_f 126 | # download url files 127 | if f.startswith("http"): 128 | # if the file is a url path, download it and cache it 129 | cached_f = cache_url(f) 130 | self.logger.info("url {} cached in {}".format(f, cached_f)) 131 | f = cached_f 132 | # convert Caffe2 checkpoint from pkl 133 | if f.endswith(".pkl"): 134 | return load_c2_format(self.cfg, f) 135 | # load native detectron.pytorch checkpoint 136 | loaded = super(DetectronCheckpointer, self)._load_file(f) 137 | if "model" not in loaded: 138 | loaded = dict(model=loaded) 139 | return loaded 140 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import PIL 3 | 4 | from torch.utils.collect_env import get_pretty_env_info 5 | 6 | 7 | def get_pil_version(): 8 | return "\n Pillow ({})".format(PIL.__version__) 9 | 10 | 11 | def collect_env_info(): 12 | env_str = get_pretty_env_info() 13 | env_str += get_pil_version() 14 | return env_str 15 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 
5 | """ 6 | 7 | import os 8 | import pickle 9 | import tempfile 10 | import time 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | 16 | def get_world_size(): 17 | if not dist.is_initialized(): 18 | return 1 19 | return dist.get_world_size() 20 | 21 | 22 | def is_main_process(): 23 | if not dist.is_initialized(): 24 | return True 25 | return dist.get_rank() == 0 26 | 27 | def get_rank(): 28 | if not dist.is_initialized(): 29 | return 0 30 | return dist.get_rank() 31 | 32 | def synchronize(): 33 | """ 34 | Helper function to synchronize between multiple processes when 35 | using distributed training 36 | """ 37 | if not dist.is_initialized(): 38 | return 39 | world_size = dist.get_world_size() 40 | rank = dist.get_rank() 41 | if world_size == 1: 42 | return 43 | 44 | def _send_and_wait(r): 45 | if rank == r: 46 | tensor = torch.tensor(0, device="cuda") 47 | else: 48 | tensor = torch.tensor(1, device="cuda") 49 | dist.broadcast(tensor, r) 50 | while tensor.item() == 1: 51 | time.sleep(1) 52 | 53 | _send_and_wait(0) 54 | # now sync on the main process 55 | _send_and_wait(1) 56 | 57 | 58 | def _encode(encoded_data, data): 59 | # gets a byte representation for the data 60 | encoded_bytes = pickle.dumps(data) 61 | # convert this byte string into a byte tensor 62 | storage = torch.ByteStorage.from_buffer(encoded_bytes) 63 | tensor = torch.ByteTensor(storage).to("cuda") 64 | # encoding: first byte is the size and then rest is the data 65 | s = tensor.numel() 66 | assert s <= 255, "Can't encode data greater than 255 bytes" 67 | # put the encoded data in encoded_data 68 | encoded_data[0] = s 69 | encoded_data[1 : (s + 1)] = tensor 70 | 71 | 72 | def _decode(encoded_data): 73 | size = encoded_data[0] 74 | encoded_tensor = encoded_data[1 : (size + 1)].to("cpu") 75 | return pickle.loads(bytearray(encoded_tensor.tolist())) 76 | 77 | 78 | # TODO try to use tensor in shared-memory instead of serializing to disk 79 | # this involves getting the all_gather to work 80 | def scatter_gather(data): 81 | """ 82 | This function gathers data from multiple processes, and returns them 83 | in a list, as they were obtained from each process. 84 | 85 | This function is useful for retrieving data from multiple processes, 86 | when launching the code with torch.distributed.launch 87 | 88 | Note: this function is slow and should not be used in tight loops, i.e., 89 | do not use it in the training loop. 90 | 91 | Arguments: 92 | data: the object to be gathered from multiple processes. 93 | It must be serializable 94 | 95 | Returns: 96 | result (list): a list with as many elements as there are processes, 97 | where each element i in the list corresponds to the data that was 98 | gathered from the process of rank i. 99 | """ 100 | # strategy: the main process creates a temporary directory, and communicates 101 | # the location of the temporary directory to all other processes. 
102 | # each process will then serialize the data to the folder defined by 103 | # the main process, and then the main process reads all of the serialized 104 | # files and returns them in a list 105 | if not dist.is_initialized(): 106 | return [data] 107 | synchronize() 108 | # get rank of the current process 109 | rank = dist.get_rank() 110 | 111 | # the data to communicate should be small 112 | data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda") 113 | if rank == 0: 114 | # manually creates a temporary directory, that needs to be cleaned 115 | # afterwards 116 | tmp_dir = tempfile.mkdtemp() 117 | _encode(data_to_communicate, tmp_dir) 118 | 119 | synchronize() 120 | # the main process (rank=0) communicates the data to all processes 121 | dist.broadcast(data_to_communicate, 0) 122 | 123 | # get the data that was communicated 124 | tmp_dir = _decode(data_to_communicate) 125 | 126 | # each process serializes to a different file 127 | file_template = "file{}.pth" 128 | tmp_file = os.path.join(tmp_dir, file_template.format(rank)) 129 | torch.save(data, tmp_file) 130 | 131 | # synchronize before loading the data 132 | synchronize() 133 | 134 | # only the master process returns the data 135 | if rank == 0: 136 | data_list = [] 137 | world_size = dist.get_world_size() 138 | for r in range(world_size): 139 | file_path = os.path.join(tmp_dir, file_template.format(r)) 140 | d = torch.load(file_path) 141 | data_list.append(d) 142 | # cleanup 143 | os.remove(file_path) 144 | # cleanup 145 | os.rmdir(tmp_dir) 146 | return data_list 147 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import os 3 | 4 | from maskrcnn_benchmark.utils.imports import import_file 5 | 6 | 7 | def setup_environment(): 8 | """Perform environment setup work. The default setup is a no-op, but this 9 | function allows the user to specify a Python source file that performs 10 | custom setup work that may be necessary to their computing environment. 11 | """ 12 | custom_module_path = os.environ.get("TORCH_DETECTRON_ENV_MODULE") 13 | if custom_module_path: 14 | setup_custom_environment(custom_module_path) 15 | else: 16 | # The default setup is a no-op 17 | pass 18 | 19 | 20 | def setup_custom_environment(custom_module_path): 21 | """Load custom environment setup from a Python source file and run the setup 22 | function. 23 | """ 24 | module = import_file("maskrcnn_benchmark.utils.env.custom_module", custom_module_path) 25 | assert hasattr(module, "setup_environment") and callable( 26 | module.setup_environment 27 | ), ( 28 | "Custom environment module defined in {} does not have the " 29 | "required callable attribute 'setup_environment'." 30 | ).format( 31 | custom_module_path 32 | ) 33 | module.setup_environment() 34 | 35 | 36 | # Force environment setup when this module is imported 37 | setup_environment() 38 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/imports.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
import importlib
import importlib.util
import sys


# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
def import_file(module_name, file_path, make_importable=False):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if make_importable:
        sys.modules[module_name] = module
    return module
--------------------------------------------------------------------------------
/maskrcnn_benchmark/utils/logging.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# modified by Minghui Liao
import logging
import os
import sys

from tensorboardX import SummaryWriter


def setup_logger(name, save_dir, distributed_rank=0):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, "log.txt"))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


class Logger(object):
    def __init__(self, log_dir, distributed_rank=0):
        """Create a summary writer logging to log_dir."""
        self.distributed_rank = distributed_rank
        if distributed_rank == 0:
            self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        if self.distributed_rank == 0:
            self.writer.add_scalar(tag, value, step)
--------------------------------------------------------------------------------
/maskrcnn_benchmark/utils/metric_logger.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import defaultdict
from collections import deque

import torch


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
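
    Illustrative behaviour (a doctest-style sketch):

        >>> v = SmoothedValue(window_size=2)
        >>> for x in [1.0, 2.0, 3.0]:
        ...     v.update(x)
        >>> v.avg         # mean over the current window [2.0, 3.0]
        2.5
        >>> v.global_avg  # mean over every value seen so far
        2.0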
11 | """ 12 | 13 | def __init__(self, window_size=20): 14 | self.deque = deque(maxlen=window_size) 15 | self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | class MetricLogger(object): 41 | def __init__(self, delimiter="\t"): 42 | self.meters = defaultdict(SmoothedValue) 43 | self.delimiter = delimiter 44 | 45 | def update(self, **kwargs): 46 | for k, v in kwargs.items(): 47 | if isinstance(v, torch.Tensor): 48 | v = v.item() 49 | assert isinstance(v, (float, int)) 50 | self.meters[k].update(v) 51 | 52 | def __getattr__(self, attr): 53 | if attr in self.meters: 54 | return self.meters[attr] 55 | return object.__getattr__(self, attr) 56 | 57 | def __str__(self): 58 | loss_str = [] 59 | for name, meter in self.meters.items(): 60 | loss_str.append( 61 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 62 | ) 63 | return self.delimiter.join(loss_str) 64 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/miscellaneous.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import errno 3 | import os 4 | 5 | 6 | def mkdir(path): 7 | try: 8 | os.makedirs(path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/model_serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import OrderedDict 3 | import logging 4 | 5 | import torch 6 | 7 | from maskrcnn_benchmark.utils.imports import import_file 8 | 9 | 10 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict): 11 | """ 12 | Strategy: suppose that the models that we will create will have prefixes appended 13 | to each of its keys, for example due to an extra level of nesting that the original 14 | pre-trained weights from ImageNet won't contain. For example, model.state_dict() 15 | might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains 16 | res2.conv1.weight. We thus want to match both parameters together. 17 | For that, we look for each model weight, look among all loaded keys if there is one 18 | that is a suffix of the current weight name, and use it if that's the case. 19 | If multiple matches exist, take the one with longest size 20 | of the corresponding name. For example, for the same model as before, the pretrained 21 | weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, 22 | we want to match backbone[0].body.conv1.weight to conv1.weight, and 23 | backbone[0].body.res2.conv1.weight to res2.conv1.weight. 
24 | """ 25 | current_keys = sorted(list(model_state_dict.keys())) 26 | loaded_keys = sorted(list(loaded_state_dict.keys())) 27 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 28 | # loaded_key string, if it matches 29 | match_matrix = [ 30 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 31 | ] 32 | match_matrix = torch.as_tensor(match_matrix).view( 33 | len(current_keys), len(loaded_keys) 34 | ) 35 | max_match_size, idxs = match_matrix.max(1) 36 | # remove indices that correspond to no-match 37 | idxs[max_match_size == 0] = -1 38 | 39 | # used for logging 40 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 41 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 42 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 43 | logger = logging.getLogger(__name__) 44 | for idx_new, idx_old in enumerate(idxs.tolist()): 45 | if idx_old == -1: 46 | continue 47 | key = current_keys[idx_new] 48 | key_old = loaded_keys[idx_old] 49 | model_state_dict[key] = loaded_state_dict[key_old] 50 | logger.info( 51 | log_str_template.format( 52 | key, 53 | max_size, 54 | key_old, 55 | max_size_loaded, 56 | tuple(loaded_state_dict[key_old].shape), 57 | ) 58 | ) 59 | 60 | 61 | def strip_prefix_if_present(state_dict, prefix): 62 | keys = sorted(state_dict.keys()) 63 | if not all(key.startswith(prefix) for key in keys): 64 | return state_dict 65 | stripped_state_dict = OrderedDict() 66 | for key, value in state_dict.items(): 67 | stripped_state_dict[key.replace(prefix, "")] = value 68 | return stripped_state_dict 69 | 70 | 71 | def load_state_dict(model, loaded_state_dict): 72 | model_state_dict = model.state_dict() 73 | # if the state_dict comes from a model that was wrapped in a 74 | # DataParallel or DistributedDataParallel during serialization, 75 | # remove the "module" prefix before performing the matching 76 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 77 | align_and_update_state_dicts(model_state_dict, loaded_state_dict) 78 | 79 | # use strict loading 80 | model.load_state_dict(model_state_dict) 81 | -------------------------------------------------------------------------------- /maskrcnn_benchmark/utils/model_zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import os 3 | import sys 4 | 5 | try: 6 | from torch.hub import _download_url_to_file 7 | from torch.hub import urlparse 8 | from torch.hub import HASH_REGEX 9 | except ImportError: 10 | from torch.utils.model_zoo import _download_url_to_file 11 | from torch.utils.model_zoo import urlparse 12 | from torch.utils.model_zoo import HASH_REGEX 13 | 14 | from maskrcnn_benchmark.utils.comm import is_main_process 15 | from maskrcnn_benchmark.utils.comm import synchronize 16 | 17 | 18 | # very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py 19 | # but with a few improvements and modifications 20 | def cache_url(url, model_dir=None, progress=True): 21 | r"""Loads the Torch serialized object at the given URL. 22 | If the object is already present in `model_dir`, it's deserialized and 23 | returned. The filename part of the URL should follow the naming convention 24 | ``filename-.ext`` where ```` is the first eight or more 25 | digits of the SHA256 hash of the contents of the file. 
    The hash is used to ensure unique names and to verify the contents of the file.
    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr
    Example:
        >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
        # so make the full path the filename by replacing / with _
        filename = parts.path.replace("/", "_")
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file) and is_main_process():
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
            # which matches the hash PyTorch uses. So we skip the hash matching
            # if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    synchronize()
    return cached_file
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
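# Building this package (see the README for the install steps) compiles the
# C++/CUDA sources under maskrcnn_benchmark/csrc into the maskrcnn_benchmark._C
# extension via get_extensions() below. An illustrative sanity check that the
# extension was built and is importable (assumes the build succeeded):
#
#   python -c "from maskrcnn_benchmark import _C; print(_C.nms)"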
#!/usr/bin/env python

import glob
import os

import torch
from setuptools import find_packages
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension

requirements = ["torch", "torchvision"]


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "maskrcnn_benchmark", "csrc")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    sources = main_file + source_cpu
    extension = CppExtension

    extra_compile_args = {"cxx": []}
    define_macros = []

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]

    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "maskrcnn_benchmark._C",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules


setup(
    name="maskrcnn_benchmark",
    version="0.1",
    author="fmassa",
    url="https://github.com/facebookresearch/maskrcnn-benchmark",
    description="object detection in pytorch",
    # packages=find_packages(exclude=("configs", "examples", "test",)),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
python3 tools/test_net.py --config-file configs/finetune.yaml #
--------------------------------------------------------------------------------
/tests/checkpoint.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
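# A minimal sketch of the Checkpointer API exercised by the tests below
# (the save directory name is illustrative):
#
#   checkpointer = Checkpointer(model, save_dir="outputs", save_to_disk=True)
#   checkpointer.save("checkpoint_file")   # writes outputs/checkpoint_file.pth
#   extra_data = checkpointer.load()       # reloads the latest checkpoint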
from collections import OrderedDict
import os
from tempfile import TemporaryDirectory
import unittest

import torch
from torch import nn

from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.checkpoint import Checkpointer


class TestCheckpointer(unittest.TestCase):
    def create_model(self):
        return nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))

    def create_complex_model(self):
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict

    def test_from_last_checkpoint_model(self):
        # test that loading works even if they differ by a prefix
        for trained_model, fresh_model in [
            (self.create_model(), self.create_model()),
            (nn.DataParallel(self.create_model()), self.create_model()),
            (self.create_model(), nn.DataParallel(self.create_model())),
            (
                nn.DataParallel(self.create_model()),
                nn.DataParallel(self.create_model()),
            ),
        ]:

            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(
                    trained_model, save_dir=f, save_to_disk=True
                )
                checkpointer.save("checkpoint_file")

                # in the same folder
                fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
                self.assertTrue(fresh_checkpointer.has_checkpoint())
                self.assertEqual(
                    fresh_checkpointer.get_checkpoint_file(),
                    os.path.join(f, "checkpoint_file.pth"),
                )
                _ = fresh_checkpointer.load()

                for trained_p, loaded_p in zip(
                    trained_model.parameters(), fresh_model.parameters()
                ):
                    # different tensor references
                    self.assertFalse(id(trained_p) == id(loaded_p))
                    # same content
                    self.assertTrue(trained_p.equal(loaded_p))

    def test_from_name_file_model(self):
        # test that loading works even if they differ by a prefix
        for trained_model, fresh_model in [
            (self.create_model(), self.create_model()),
            (nn.DataParallel(self.create_model()), self.create_model()),
            (self.create_model(), nn.DataParallel(self.create_model())),
            (
                nn.DataParallel(self.create_model()),
                nn.DataParallel(self.create_model()),
            ),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(
                    trained_model, save_dir=f, save_to_disk=True
                )
                checkpointer.save("checkpoint_file")

                # on different folders
                with TemporaryDirectory() as g:
                    fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                    self.assertFalse(fresh_checkpointer.has_checkpoint())
                    self.assertEqual(fresh_checkpointer.get_checkpoint_file(), "")
                    _ = fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth"))

                    for trained_p, loaded_p in zip(
                        trained_model.parameters(), fresh_model.parameters()
                    ):
                        # different tensor references
                        self.assertFalse(id(trained_p) == id(loaded_p))
                        # same content
                        self.assertTrue(trained_p.equal(loaded_p))

    def test_complex_model_loaded(self):
        for add_data_parallel in [False, True]:
            model, state_dict = self.create_complex_model()
            if add_data_parallel:
                model = nn.DataParallel(model)

            load_state_dict(model, state_dict)
            for loaded, stored in zip(model.state_dict().values(), state_dict.values()):
                # different tensor references
                self.assertFalse(id(loaded) == id(stored))
                # same content
                self.assertTrue(loaded.equal(stored))


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/tests/test_data_samplers.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools
import random
import unittest

from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data.sampler import RandomSampler

from maskrcnn_benchmark.data.samplers import GroupedBatchSampler
from maskrcnn_benchmark.data.samplers import IterationBasedBatchSampler


class SubsetSampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class TestGroupedBatchSampler(unittest.TestCase):
    def test_respect_order_simple(self):
        drop_uneven = False
        dataset = [i for i in range(40)]
        group_ids = [i // 10 for i in dataset]
        sampler = SequentialSampler(dataset)
        for batch_size in [1, 3, 5, 6]:
            batch_sampler = GroupedBatchSampler(
                sampler, group_ids, batch_size, drop_uneven
            )
            result = list(batch_sampler)
            merged_result = list(itertools.chain.from_iterable(result))
            self.assertEqual(merged_result, dataset)

    def test_respect_order(self):
        drop_uneven = False
        dataset = [i for i in range(10)]
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SequentialSampler(dataset)

        expected = [
            [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
            [[0, 1, 3], [2, 4, 5], [6, 9], [7, 8]],
            [[0, 1, 3, 6], [2, 4, 5, 7], [8], [9]],
        ]

        for idx, batch_size in enumerate([1, 3, 4]):
            batch_sampler = GroupedBatchSampler(
                sampler, group_ids, batch_size, drop_uneven
            )
            result = list(batch_sampler)
            self.assertEqual(result, expected[idx])

    def test_respect_order_drop_uneven(self):
        batch_size = 3
        drop_uneven = True
        dataset = [i for i in range(10)]
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SequentialSampler(dataset)
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)

        result = list(batch_sampler)

        expected = [[0, 1, 3], [2, 4, 5]]
        self.assertEqual(result, expected)

    def test_subset_sampler(self):
        batch_size = 3
        drop_uneven = False
        dataset = [i for i in range(10)]
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([0, 3, 5, 6, 7, 8])

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)

        expected = [[0, 3, 6], [5, 7, 8]]
        self.assertEqual(result, expected)

    def test_permute_subset_sampler(self):
        batch_size = 3
        drop_uneven = False
        dataset = [i for i in range(10)]
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([5, 0, 6, 1, 3, 8])

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)

        expected = [[5, 8], [0, 6, 1], [3]]
        self.assertEqual(result, expected)

    def test_permute_subset_sampler_drop_uneven(self):
        batch_size = 3
        drop_uneven = True
        dataset = [i for i in range(10)]
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([5, 0, 6, 1, 3, 8])

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)

        expected = [[0, 6, 1]]
        self.assertEqual(result, expected)

    def test_len(self):
        batch_size = 3
        drop_uneven = True
        dataset = [i for i in range(10)]
        group_ids = [random.randint(0, 1) for _ in dataset]
        sampler = RandomSampler(dataset)

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)
        self.assertEqual(len(result), len(batch_sampler))

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        batch_sampler_len = len(batch_sampler)
        result = list(batch_sampler)
        self.assertEqual(len(result), batch_sampler_len)
        self.assertEqual(len(result), len(batch_sampler))


class TestIterationBasedBatchSampler(unittest.TestCase):
    def test_number_of_iters_and_elements(self):
        for batch_size in [2, 3, 4]:
            for num_iterations in [4, 10, 20]:
                for drop_last in [False, True]:
                    dataset = [i for i in range(10)]
                    sampler = SequentialSampler(dataset)
                    batch_sampler = BatchSampler(
                        sampler, batch_size, drop_last=drop_last
                    )

                    iter_sampler = IterationBasedBatchSampler(
                        batch_sampler, num_iterations
                    )
                    assert len(iter_sampler) == num_iterations
                    for i, batch in enumerate(iter_sampler):
                        start = (i % len(batch_sampler)) * batch_size
                        end = min(start + batch_size, len(dataset))
                        expected = [x for x in range(start, end)]
                        self.assertEqual(batch, expected)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/tools/test_net.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
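# Single-GPU evaluation is launched as in test.sh:
#
#   python3 tools/test_net.py --config-file configs/finetune.yaml
#
# Multi-GPU evaluation should also work via torch.distributed.launch, which
# sets the WORLD_SIZE and local rank that main() reads below (illustrative,
# assuming 8 GPUs):
#
#   python3 -m torch.distributed.launch --nproc_per_node=8 tools/test_net.py --config-file configs/finetune.yaml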
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import torch
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.text_inference import inference
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.logging import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="./configs/finetune.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        # torch.distributed.deprecated was removed in recent PyTorch releases;
        # use the stable API, as train_net.py does
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    checkpointer = DetectronCheckpointer(cfg, model)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    output_folders = [None] * len(cfg.DATASETS.TEST)
    if cfg.OUTPUT_DIR:
        dataset_names = cfg.DATASETS.TEST
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    model_name = cfg.MODEL.WEIGHT.split('/')[-1]
    for output_folder, data_loader_val in zip(output_folders, data_loaders_val):
        inference(
            model,
            data_loader_val,
            iou_types=iou_types,
            box_only=cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
            model_name=model_name,
            cfg=cfg,
        )
        synchronize()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/tools/train_net.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
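# Config values can be overridden from the command line through the trailing
# "opts" argument (merged via cfg.merge_from_list), e.g. to resume training
# from the last checkpoint (illustrative):
#
#   python3 tools/train_net.py --config-file configs/finetune.yaml SOLVER.RESUME True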
2 | r""" 3 | Basic training script for PyTorch 4 | """ 5 | 6 | # Set up custom environment before nearly anything else is imported 7 | # NOTE: this should be the first import (no not reorder) 8 | from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip 9 | 10 | import argparse 11 | import os 12 | 13 | import torch 14 | from maskrcnn_benchmark.config import cfg 15 | from maskrcnn_benchmark.data import make_data_loader 16 | from maskrcnn_benchmark.solver import make_lr_scheduler 17 | from maskrcnn_benchmark.solver import make_optimizer 18 | from maskrcnn_benchmark.engine.inference import inference 19 | from maskrcnn_benchmark.engine.trainer import do_train 20 | from maskrcnn_benchmark.modeling.detector import build_detection_model 21 | from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer 22 | from maskrcnn_benchmark.utils.collect_env import collect_env_info 23 | from maskrcnn_benchmark.utils.comm import synchronize, get_rank 24 | from maskrcnn_benchmark.utils.imports import import_file 25 | from maskrcnn_benchmark.utils.logging import setup_logger, Logger 26 | from maskrcnn_benchmark.utils.miscellaneous import mkdir 27 | 28 | 29 | def train(cfg, local_rank, distributed): 30 | model = build_detection_model(cfg) 31 | device = torch.device(cfg.MODEL.DEVICE) 32 | model.to(device) 33 | 34 | optimizer = make_optimizer(cfg, model) 35 | scheduler = make_lr_scheduler(cfg, optimizer) 36 | 37 | if distributed: 38 | model = torch.nn.parallel.DistributedDataParallel( 39 | model, device_ids=[local_rank], output_device=local_rank, 40 | # this should be removed if we update BatchNorm stats 41 | broadcast_buffers=False, 42 | ) 43 | 44 | arguments = {} 45 | arguments["iteration"] = 0 46 | 47 | output_dir = cfg.OUTPUT_DIR 48 | 49 | save_to_disk = get_rank() == 0 50 | checkpointer = DetectronCheckpointer( 51 | cfg, model, optimizer, scheduler, output_dir, save_to_disk 52 | ) 53 | extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT, resume=cfg.SOLVER.RESUME) 54 | if cfg.SOLVER.RESUME: 55 | arguments.update(extra_checkpoint_data) 56 | 57 | data_loader = make_data_loader( 58 | cfg, 59 | is_train=True, 60 | is_distributed=distributed, 61 | start_iter=arguments["iteration"], 62 | ) 63 | 64 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 65 | tb_logger = Logger(cfg.OUTPUT_DIR) 66 | do_train( 67 | model, 68 | data_loader, 69 | optimizer, 70 | scheduler, 71 | checkpointer, 72 | device, 73 | checkpoint_period, 74 | arguments, 75 | tb_logger, 76 | cfg, 77 | ) 78 | 79 | return model 80 | 81 | 82 | def test(cfg, model, distributed): 83 | if distributed: 84 | model = model.module 85 | torch.cuda.empty_cache() # TODO check if it helps 86 | iou_types = ("bbox",) 87 | if cfg.MODEL.MASK_ON: 88 | iou_types = iou_types + ("segm",) 89 | output_folders = [None] * len(cfg.DATASETS.TEST) 90 | if cfg.OUTPUT_DIR: 91 | dataset_names = cfg.DATASETS.TEST 92 | for idx, dataset_name in enumerate(dataset_names): 93 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) 94 | mkdir(output_folder) 95 | output_folders[idx] = output_folder 96 | data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed) 97 | for output_folder, data_loader_val in zip(output_folders, data_loaders_val): 98 | inference( 99 | model, 100 | data_loader_val, 101 | iou_types=iou_types, 102 | box_only=cfg.MODEL.RPN_ONLY, 103 | device=cfg.MODEL.DEVICE, 104 | expected_results=cfg.TEST.EXPECTED_RESULTS, 105 | expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, 106 | 
            output_folder=output_folder,
        )
        synchronize()


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        test(cfg, model, args.distributed)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
python3 -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/pretrain.yaml #
# python3 -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/finetune.yaml #
--------------------------------------------------------------------------------