├── .dev_scripts ├── batch_test.py ├── batch_test.sh ├── benchmark_filter.py ├── gather_models.py └── linter.sh ├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── error-report.md │ ├── feature_request.md │ ├── general_questions.md │ └── reimplementation_questions.md └── workflows │ ├── build.yml │ └── deploy.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── LICENSE ├── README.md ├── configs └── wsod │ ├── base.py │ ├── oicr_bbox_vgg16.py │ ├── oicr_vgg16.py │ ├── wsddn_vgg16.py │ └── wsod2_vgg16.py ├── demo ├── MMDet_Tutorial.ipynb ├── demo.jpg ├── image_demo.py ├── inference_demo.ipynb └── webcam_demo.py ├── docker └── Dockerfile ├── docs ├── 1_exist_data_model.md ├── 2_new_data_model.md ├── 3_exist_data_new_model.md ├── Makefile ├── api.rst ├── changelog.md ├── compatibility.md ├── conf.py ├── conventions.md ├── faq.md ├── get_started.md ├── index.rst ├── make.bat ├── model_zoo.md ├── projects.md ├── robustness_benchmarking.md ├── stat.py ├── tutorials │ ├── config.md │ ├── customize_dataset.md │ ├── customize_losses.md │ ├── customize_models.md │ ├── customize_runtime.md │ ├── data_pipeline.md │ ├── finetune.md │ └── index.rst └── useful_tools.md ├── mmdet ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ ├── anchor │ │ ├── __init__.py │ │ ├── anchor_generator.py │ │ ├── builder.py │ │ ├── point_generator.py │ │ └── utils.py │ ├── bbox │ │ ├── __init__.py │ │ ├── assigners │ │ │ ├── __init__.py │ │ │ ├── approx_max_iou_assigner.py │ │ │ ├── assign_result.py │ │ │ ├── atss_assigner.py │ │ │ ├── base_assigner.py │ │ │ ├── center_region_assigner.py │ │ │ ├── grid_assigner.py │ │ │ ├── max_iou_assigner.py │ │ │ └── point_assigner.py │ │ ├── builder.py │ │ ├── coder │ │ │ ├── __init__.py │ │ │ ├── base_bbox_coder.py │ │ │ ├── bucketing_bbox_coder.py │ │ │ ├── delta_xywh_bbox_coder.py │ │ │ ├── legacy_delta_xywh_bbox_coder.py │ │ │ ├── pseudo_bbox_coder.py │ │ │ ├── tblr_bbox_coder.py │ │ │ └── yolo_bbox_coder.py │ │ ├── demodata.py │ │ ├── iou_calculators │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ └── iou2d_calculator.py │ │ ├── samplers │ │ │ ├── __init__.py │ │ │ ├── base_sampler.py │ │ │ ├── combined_sampler.py │ │ │ ├── instance_balanced_pos_sampler.py │ │ │ ├── iou_balanced_neg_sampler.py │ │ │ ├── ohem_sampler.py │ │ │ ├── pseudo_sampler.py │ │ │ ├── random_sampler.py │ │ │ ├── sampling_result.py │ │ │ └── score_hlr_sampler.py │ │ └── transforms.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── bbox_overlaps.py │ │ ├── class_names.py │ │ ├── eval_hooks.py │ │ ├── mean_ap.py │ │ └── recall.py │ ├── export │ │ ├── __init__.py │ │ └── pytorch2onnx.py │ ├── fp16 │ │ ├── __init__.py │ │ └── deprecated_fp16_utils.py │ ├── mask │ │ ├── __init__.py │ │ ├── mask_target.py │ │ ├── structures.py │ │ └── utils.py │ ├── post_processing │ │ ├── __init__.py │ │ ├── bbox_nms.py │ │ └── merge_augs.py │ └── utils │ │ ├── __init__.py │ │ ├── dist_utils.py │ │ └── misc.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── cityscapes.py │ ├── coco.py │ ├── custom.py │ ├── dataset_wrappers.py │ ├── deepfashion.py │ ├── lvis.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── auto_augment.py │ │ ├── compose.py │ │ ├── formating.py │ │ ├── instaboost.py │ │ ├── loading.py │ │ ├── test_time_aug.py │ │ └── transforms.py │ ├── samplers │ │ ├── __init__.py │ │ ├── distributed_sampler.py │ │ └── group_sampler.py │ ├── utils.py │ ├── voc.py │ ├── voc_ss.py │ ├── wider_face.py │ 
└── xml_style.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── darknet.py │ │ ├── detectors_resnet.py │ │ ├── detectors_resnext.py │ │ ├── hourglass.py │ │ ├── hrnet.py │ │ ├── regnet.py │ │ ├── res2net.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ ├── resnext.py │ │ ├── ssd_vgg.py │ │ └── vgg.py │ ├── builder.py │ ├── dense_heads │ │ ├── __init__.py │ │ ├── anchor_free_head.py │ │ ├── anchor_head.py │ │ ├── atss_head.py │ │ ├── base_dense_head.py │ │ ├── centripetal_head.py │ │ ├── corner_head.py │ │ ├── dense_test_mixins.py │ │ ├── fcos_head.py │ │ ├── fovea_head.py │ │ ├── free_anchor_retina_head.py │ │ ├── fsaf_head.py │ │ ├── ga_retina_head.py │ │ ├── ga_rpn_head.py │ │ ├── gfl_head.py │ │ ├── guided_anchor_head.py │ │ ├── nasfcos_head.py │ │ ├── paa_head.py │ │ ├── pisa_retinanet_head.py │ │ ├── pisa_ssd_head.py │ │ ├── reppoints_head.py │ │ ├── retina_head.py │ │ ├── retina_sepbn_head.py │ │ ├── rpn_head.py │ │ ├── rpn_test_mixin.py │ │ ├── sabl_retina_head.py │ │ ├── ssd_head.py │ │ ├── vfnet_head.py │ │ ├── yolact_head.py │ │ └── yolo_head.py │ ├── detectors │ │ ├── __init__.py │ │ ├── atss.py │ │ ├── base.py │ │ ├── cascade_rcnn.py │ │ ├── cornernet.py │ │ ├── fast_rcnn.py │ │ ├── faster_rcnn.py │ │ ├── fcos.py │ │ ├── fovea.py │ │ ├── fsaf.py │ │ ├── gfl.py │ │ ├── grid_rcnn.py │ │ ├── htc.py │ │ ├── mask_rcnn.py │ │ ├── mask_scoring_rcnn.py │ │ ├── nasfcos.py │ │ ├── paa.py │ │ ├── point_rend.py │ │ ├── reppoints_detector.py │ │ ├── retinanet.py │ │ ├── rpn.py │ │ ├── single_stage.py │ │ ├── two_stage.py │ │ ├── vfnet.py │ │ ├── weak_rcnn.py │ │ ├── yolact.py │ │ └── yolo.py │ ├── losses │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── ae_loss.py │ │ ├── balanced_l1_loss.py │ │ ├── cross_entropy_loss.py │ │ ├── focal_loss.py │ │ ├── gaussian_focal_loss.py │ │ ├── gfocal_loss.py │ │ ├── ghm_loss.py │ │ ├── iou_loss.py │ │ ├── mse_loss.py │ │ ├── pisa_loss.py │ │ ├── smooth_l1_loss.py │ │ ├── utils.py │ │ └── varifocal_loss.py │ ├── necks │ │ ├── __init__.py │ │ ├── bfp.py │ │ ├── channel_mapper.py │ │ ├── fpn.py │ │ ├── fpn_carafe.py │ │ ├── hrfpn.py │ │ ├── nas_fpn.py │ │ ├── nasfcos_fpn.py │ │ ├── pafpn.py │ │ ├── rfp.py │ │ └── yolo_neck.py │ ├── roi_heads │ │ ├── __init__.py │ │ ├── base_roi_head.py │ │ ├── bbox_heads │ │ │ ├── __init__.py │ │ │ ├── bbox_head.py │ │ │ ├── convfc_bbox_head.py │ │ │ ├── double_bbox_head.py │ │ │ ├── oicr_head.py │ │ │ ├── sabl_head.py │ │ │ └── wsddn_head.py │ │ ├── cascade_roi_head.py │ │ ├── double_roi_head.py │ │ ├── dynamic_roi_head.py │ │ ├── grid_roi_head.py │ │ ├── htc_roi_head.py │ │ ├── mask_heads │ │ │ ├── __init__.py │ │ │ ├── coarse_mask_head.py │ │ │ ├── fcn_mask_head.py │ │ │ ├── fused_semantic_head.py │ │ │ ├── grid_head.py │ │ │ ├── htc_mask_head.py │ │ │ ├── mask_point_head.py │ │ │ └── maskiou_head.py │ │ ├── mask_scoring_roi_head.py │ │ ├── oicr_roi_head.py │ │ ├── pisa_roi_head.py │ │ ├── point_rend_roi_head.py │ │ ├── roi_extractors │ │ │ ├── __init__.py │ │ │ ├── base_roi_extractor.py │ │ │ ├── generic_roi_extractor.py │ │ │ └── single_level_roi_extractor.py │ │ ├── shared_heads │ │ │ ├── __init__.py │ │ │ └── res_layer.py │ │ ├── standard_roi_head.py │ │ ├── test_mixins.py │ │ ├── wsddn_roi_head.py │ │ └── wsod2_roi_head.py │ └── utils │ │ ├── __init__.py │ │ ├── gaussian_target.py │ │ └── res_layer.py ├── ops │ └── __init__.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── contextmanagers.py │ ├── logger.py │ ├── profiling.py │ └── util_mixins.py └── version.py ├── pytest.ini ├── 
requirements.txt ├── requirements ├── build.txt ├── docs.txt ├── optional.txt ├── readthedocs.txt ├── runtime.txt └── tests.txt ├── resources └── architecture.png ├── setup.cfg ├── setup.py ├── tests ├── async_benchmark.py ├── data │ ├── coco_sample.json │ ├── color.jpg │ └── gray.jpg ├── test_anchor.py ├── test_assigner.py ├── test_async.py ├── test_coder.py ├── test_config.py ├── test_data │ ├── test_dataset.py │ ├── test_formatting.py │ ├── test_img_augment.py │ ├── test_loading.py │ ├── test_models_aug_test.py │ ├── test_rotate.py │ ├── test_sampler.py │ ├── test_shear.py │ ├── test_transform.py │ ├── test_translate.py │ └── test_utils.py ├── test_eval_hook.py ├── test_fp16.py ├── test_iou2d_calculator.py ├── test_masks.py ├── test_models │ ├── test_backbones.py │ ├── test_forward.py │ ├── test_heads.py │ ├── test_losses.py │ ├── test_necks.py │ ├── test_pisa_heads.py │ └── test_roi_extractor.py └── test_version.py └── tools ├── analyze_logs.py ├── benchmark.py ├── browse_dataset.py ├── coco_error_analysis.py ├── convert_datasets ├── cityscapes.py └── pascal_voc.py ├── detectron2pytorch.py ├── dist_test.sh ├── dist_train.sh ├── eval_metric.py ├── get_flops.py ├── prepare.sh ├── print_config.py ├── publish_model.py ├── pytorch2onnx.py ├── regnet2mmdet.py ├── robustness_eval.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py ├── test_robustness.py ├── train.py └── upgrade_model_version.py /.dev_scripts/batch_test.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=${PWD} 2 | 3 | partition=$1 4 | model_dir=$2 5 | json_out=$3 6 | job_name=batch_test 7 | gpus=8 8 | gpu_per_node=8 9 | 10 | touch $json_out 11 | lastLine=$(tail -n 1 $json_out) 12 | while [ "$lastLine" != "finished" ] 13 | do 14 | srun -p ${partition} --gres=gpu:${gpu_per_node} -n${gpus} --ntasks-per-node=${gpu_per_node} \ 15 | --job-name=${job_name} --kill-on-bad-exit=1 \ 16 | python .dev_scripts/batch_test.py $model_dir $json_out --launcher='slurm' 17 | lastLine=$(tail -n 1 $json_out) 18 | echo $lastLine 19 | done 20 | -------------------------------------------------------------------------------- /.dev_scripts/linter.sh: -------------------------------------------------------------------------------- 1 | yapf -r -i mmdet/ configs/ tests/ tools/ 2 | isort -rc mmdet/ configs/ tests/ tools/ 3 | flake8 . 4 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at chenkaidev@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mmdetection 2 | 3 | All kinds of contributions are welcome, including but not limited to the following. 
4 |
5 | - Fixes (typo, bugs)
6 | - New features and components
7 |
8 | ## Workflow
9 |
10 | 1. fork and pull the latest mmdetection
11 | 2. checkout a new branch (do not use master branch for PRs)
12 | 3. commit your changes
13 | 4. create a PR
14 |
15 | Note
16 | - If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
17 | - If you are the author of some papers and would like to include your method in mmdetection,
18 | please let us know (open an issue or contact the maintainers). We would much appreciate your contribution.
19 | - For new features and new modules, unit tests are required to improve the code's robustness.
20 |
21 | ## Code style
22 |
23 | ### Python
24 | We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
25 |
26 | We use the following tools for linting and formatting:
27 | - [flake8](http://flake8.pycqa.org/en/latest/): linter
28 | - [yapf](https://github.com/google/yapf): formatter
29 | - [isort](https://github.com/timothycrosley/isort): sort imports
30 |
31 | Style configurations of yapf and isort can be found in [setup.cfg](../setup.cfg).
32 |
33 | We use a [pre-commit hook](https://pre-commit.com/) that checks and formats code with `flake8`, `yapf` and `isort`, trims `trailing whitespaces`,
34 | fixes `end-of-files`, and sorts `requirements.txt` automatically on every commit.
35 | The config for the pre-commit hook is stored in [.pre-commit-config](../.pre-commit-config.yaml).
36 |
37 | After you clone the repository, you will need to install and initialize the pre-commit hook.
38 |
39 | ```
40 | pip install -U pre-commit
41 | ```
42 |
43 | Then, from the repository folder
44 | ```
45 | pre-commit install
46 | ```
47 |
48 | After this, the code linters and formatter will be enforced on every commit.
49 |
50 |
51 | >Before you create a PR, make sure that your code lints and is formatted by yapf.
52 |
53 | ### C++ and CUDA
54 | We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
55 |
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: --------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
3 | contact_links:
4 |   - name: Common Issues
5 |     url: https://mmdetection.readthedocs.io/en/latest/faq.html
6 |     about: Check if your issue already has solutions
7 |   - name: MMDetection Documentation
8 |     url: https://mmdetection.readthedocs.io/en/latest/
9 |     about: Check if your question is answered in docs
10 |
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/error-report.md: --------------------------------------------------------------------------------
1 | ---
2 | name: Error report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | Thanks for your error report and we appreciate it a lot.
11 |
12 | **Checklist**
13 | 1. I have searched related issues but cannot get the expected help.
14 | 2. The bug has not been fixed in the latest version.
15 |
16 | **Describe the bug**
17 | A clear and concise description of what the bug is.
18 |
19 | **Reproduction**
20 | 1. What command or script did you run?
21 | ```
22 | A placeholder for the command.
23 | ```
24 | 2. Did you make any modifications to the code or config? Did you understand what you have modified?
25 | 3. What dataset did you use?
26 |
27 | **Environment**
28 |
29 | 1.
Please run `python mmdet/utils/collect_env.py` to collect necessary environment information and paste it here.
30 | 2. You may add additional information that may be helpful for locating the problem, such as
31 |     - How you installed PyTorch [e.g., pip, conda, source]
32 |     - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
33 |
34 | **Error traceback**
35 | If applicable, paste the error traceback here.
36 | ```
37 | A placeholder for the traceback.
38 | ```
39 |
40 | **Bug fix**
41 | If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
42 |
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: --------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the feature**
11 |
12 | **Motivation**
13 | A clear and concise description of the motivation of the feature.
14 | Ex1. It is inconvenient when [....].
15 | Ex2. There is a recent paper [....], which is very helpful for [....].
16 |
17 | **Related resources**
18 | If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
19 |
20 | **Additional context**
21 | Add any other context or screenshots about the feature request here.
22 | If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
23 |
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/general_questions.md: --------------------------------------------------------------------------------
1 | ---
2 | name: General questions
3 | about: Ask general questions to get help
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/reimplementation_questions.md: --------------------------------------------------------------------------------
1 | ---
2 | name: Reimplementation Questions
3 | about: Ask questions about model reimplementation
4 | title: ''
5 | labels: 'reimplementation'
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Notice**
11 |
12 | There are several common situations in the reimplementation issues as below:
13 | 1. Reimplement a model in the model zoo using the provided configs
14 | 2. Reimplement a model in the model zoo on another dataset (e.g., custom datasets)
15 | 3. Reimplement a custom model but all the components are implemented in MMDetection
16 | 4. Reimplement a custom model with new modules implemented by yourself
17 |
18 | There are several things to do for the different cases as below.
19 | - For case 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.
20 | - For case 2 & 4, please understand that we are not able to help much here because we usually do not know the full code, and users should be responsible for the code they write.
21 | - One suggestion for case 2 & 4 is that users should first check whether the bug lies in the self-implemented code or the original code. For example, users can first make sure that the same model runs well on supported datasets.
If you still need help, please describe in the issue what you have done and what you obtained, follow the steps in the following sections, and be as clear as possible so that we can better help you.
22 |
23 | **Checklist**
24 | 1. I have searched related issues but cannot get the expected help.
25 | 2. The issue has not been fixed in the latest version.
26 |
27 | **Describe the issue**
28 |
29 | A clear and concise description of the problem you met and what you have done.
30 |
31 | **Reproduction**
32 | 1. What command or script did you run?
33 | ```
34 | A placeholder for the command.
35 | ```
36 | 2. What config did you run?
37 | ```
38 | A placeholder for the config.
39 | ```
40 | 3. Did you make any modifications to the code or config? Did you understand what you have modified?
41 | 4. What dataset did you use?
42 |
43 | **Environment**
44 |
45 | 1. Please run `python mmdet/utils/collect_env.py` to collect necessary environment information and paste it here.
46 | 2. You may add additional information that may be helpful for locating the problem, such as
47 |     - How you installed PyTorch [e.g., pip, conda, source]
48 |     - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
49 |
50 | **Results**
51 |
52 | If applicable, paste the related results here, e.g., what you expect and what you get.
53 | ```
54 | A placeholder for results comparison
55 | ```
56 |
57 | **Issue fix**
58 |
59 | If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
60 |
-------------------------------------------------------------------------------- /.github/workflows/deploy.yml: --------------------------------------------------------------------------------
1 | name: deploy
2 |
3 | on: push
4 |
5 | jobs:
6 |   build-n-publish:
7 |     runs-on: ubuntu-latest
8 |     if: startsWith(github.event.ref, 'refs/tags')
9 |     steps:
10 |     - uses: actions/checkout@v2
11 |     - name: Set up Python 3.7
12 |       uses: actions/setup-python@v2
13 |       with:
14 |         python-version: 3.7
15 |     - name: Install torch
16 |       run: pip install torch
17 |     - name: Install wheel
18 |       run: pip install wheel
19 |     - name: Build MMDetection
20 |       run: python setup.py sdist bdist_wheel
21 |     - name: Publish distribution to PyPI
22 |       run: |
23 |         pip install twine
24 |         twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
25 |
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/ 107 | data 108 | .vscode 109 | .idea 110 | .DS_Store 111 | 112 | # custom 113 | *.pkl 114 | *.pkl.json 115 | *.log.json 116 | work_dirs/ 117 | 118 | # Pytorch 119 | *.pth 120 | *.py~ 121 | *.sh~ 122 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitlab.com/pycqa/flake8.git 3 | rev: 3.8.3 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/asottile/seed-isort-config 7 | rev: v2.2.0 8 | hooks: 9 | - id: seed-isort-config 10 | - repo: https://github.com/timothycrosley/isort 11 | rev: 4.3.21 12 | hooks: 13 | - id: isort 14 | - repo: https://github.com/pre-commit/mirrors-yapf 15 | rev: v0.30.0 16 | hooks: 17 | - id: yapf 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v3.1.0 20 | hooks: 21 | - id: trailing-whitespace 22 | - id: check-yaml 23 | - id: end-of-file-fixer 24 | - id: requirements-txt-fixer 25 | - id: double-quote-string-fixer 26 | - id: check-merge-conflict 27 | - id: fix-encoding-pragma 28 | args: ["--remove"] 29 | - id: mixed-line-ending 30 | args: ["--fix=lf"] 31 | - repo: https://github.com/myint/docformatter 32 | rev: v1.3.1 33 | hooks: 34 | - id: docformatter 35 | args: ["--in-place", "--wrap-descriptions", "79"] 36 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | version: 3.7 5 | install: 6 | - requirements: requirements/docs.txt 7 | - requirements: requirements/readthedocs.txt 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # WSOD^2: Learning Bottom-up and Top-down Objectness Distillation for Weakly-supervised Object Detection
2 | By Zhaoyang Zeng, Bei Liu, Jianlong Fu, Hongyang Chao, and Lei Zhang
3 |
4 | ### Introduction
5 | This repo is a toolkit for weakly supervised object detection based on [mmdetection](https://github.com/open-mmlab/mmdetection), including implementations of [WSDDN](https://arxiv.org/abs/1511.02853), [OICR](https://arxiv.org/abs/1704.00138) and [WSOD^2](https://arxiv.org/abs/1909.04972). The implementations differ slightly from the original papers, in aspects including but not limited to
6 | * optimizer
7 | * training epochs
8 | * learning rate
9 | * input resolution
10 | * pseudo GT mining
11 | * loss weight assignment
12 |
13 | The baselines in this repo can easily achieve 48+ mAP on the Pascal VOC 2007 dataset. Some hyperparameters are still being tuned; they should bring further performance gains.
14 |
15 | ### Architecture
16 |
17 |
![WSOD^2 architecture](resources/architecture.png)
20 |
21 | ### Results
22 |
23 | | Method | VOC2007 test *mAP* | VOC2007 trainval *CorLoc* | VOC2012 test *mAP* | VOC2012 trainval *CorLoc* |
24 | |:-------|:-----:|:-------:|:-------:|:-------:|
25 | | WSOD2 | 53.6 | 71.4 | 47.2 | 71.9 |
26 | | WSOD2\* | 56.0 | 71.4 | 52.7 | 72.2 |
27 |
28 | \* denotes training on VOC 07+12 *trainval* splits
29 |
30 | ### Installation
31 |
32 | Please refer to [here](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md) for installation.
33 |
34 | ### Getting Started
35 |
36 | 1. Download the training, validation and test data, and unzip them
37 | ```shell
38 | mkdir -p $WSOD_ROOT/data/voc
39 | cd $WSOD_ROOT/data/voc
40 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
41 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
42 | tar xf VOCtrainval_06-Nov-2007.tar
43 | tar xf VOCtest_06-Nov-2007.tar
44 | ```
45 |
46 | 2. Download the ImageNet pre-trained models, selective search boxes and superpixels
47 | ```shell
48 | bash $WSOD_ROOT/tools/prepare.sh
49 | ```
50 |
51 | If you cannot access Google Drive, you can also download the resources from [https://pan.baidu.com/s/1htyljhvYz5qwO-4oH8C3wg](https://pan.baidu.com/s/1htyljhvYz5qwO-4oH8C3wg) (password: u5r3) and unzip them; the directory structure should look like
52 | ```
53 | data
54 |   - VOCdevkit
55 |     - VOC2007
56 |     - voc_2007_trainval.pkl
57 |     - voc_2007_test.pkl
58 |     - SuperPixels
59 |     - VOC2012
60 |     - voc_2012_trainval.pkl
61 |     - voc_2012_test.pkl
62 |     - SuperPixels
63 | pretrain
64 |   - vgg16.pth
65 | ```
66 |
67 | 3. Train a WSOD model
68 | ```shell
69 | bash tools/dist_train.sh $config $num_gpus
70 | ```
71 |
72 | 4. Evaluate a WSOD model (a concrete end-to-end example is shown below)
73 | ```shell
74 | bash tools/dist_test.sh $config $checkpoint $num_gpus --eval mAP
75 | ```
76 |
77 | ### License
78 | WSOD2 is released under the MIT License.
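As a concrete instance of steps 3 and 4 above (the config path and `work_dirs/wsod2_vgg16/` come from this repo's configs; the GPU count and the `latest.pth` checkpoint name are illustrative defaults, not prescribed by this README):

```shell
# train WSOD^2 with a VGG16 backbone on 8 GPUs
bash tools/dist_train.sh configs/wsod/wsod2_vgg16.py 8
# evaluate the checkpoint written to the config's work_dir
bash tools/dist_test.sh configs/wsod/wsod2_vgg16.py work_dirs/wsod2_vgg16/latest.pth 8 --eval mAP
```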
79 |
80 | ### Citing WSOD2
81 |
82 | If you find this repo useful in your research, please consider citing:
83 |
84 | ```BibTex
85 | @inproceedings{zeng2019wsod2,
86 |   title={Wsod2: Learning bottom-up and top-down objectness distillation for weakly-supervised object detection},
87 |   author={Zeng, Zhaoyang and Liu, Bei and Fu, Jianlong and Chao, Hongyang and Zhang, Lei},
88 |   booktitle={Proceedings of the IEEE International Conference on Computer Vision},
89 |   pages={8292--8300},
90 |   year={2019}
91 | }
92 | ```
93 |
94 |
95 |
-------------------------------------------------------------------------------- /configs/wsod/base.py: --------------------------------------------------------------------------------
1 | # model training and testing settings
2 | train_cfg = dict(
3 |     rcnn=dict())
4 | test_cfg = dict(
5 |     rcnn=dict(
6 |         score_thr=0.0000,
7 |         nms=dict(type='nms', iou_threshold=0.3),
8 |         max_per_img=100))
9 |
10 | # dataset settings
11 | dataset_type = 'VOCDataset'
12 | data_root = '/datavoc/VOCdevkit/'
13 | img_norm_cfg = dict(
14 |     mean=[104., 117., 124.], std=[1., 1., 1.], to_rgb=False)
15 | train_pipeline = [
16 |     dict(type='LoadImageFromFile'),
17 |     dict(type='LoadWeakAnnotations'),
18 |     dict(type='LoadProposals'),
19 |     dict(type='Resize', img_scale=[(488, 2000), (576, 2000), (688, 2000), (864, 2000), (1200, 2000)], keep_ratio=True, multiscale_mode='range'),
20 |     dict(type='RandomFlip', flip_ratio=0.5),
21 |     dict(type='Normalize', **img_norm_cfg),
22 |     dict(type='Pad', size_divisor=32),
23 |     dict(type='DefaultFormatBundle'),
24 |     dict(type='Collect', keys=['img', 'gt_labels', 'proposals']),
25 | ]
26 | test_pipeline = [
27 |     dict(type='LoadImageFromFile'),
28 |     dict(type='LoadProposals'),
29 |     dict(
30 |         type='MultiScaleFlipAug',
31 |         img_scale=(688, 2000),
32 |         #img_scale=[(500, 2000), (600, 2000), (700, 2000), (800, 2000), (900, 2000)],
33 |         flip=False,
34 |         transforms=[
35 |             dict(type='Resize', keep_ratio=True),
36 |             dict(type='RandomFlip'),
37 |             dict(type='Normalize', **img_norm_cfg),
38 |             dict(type='Pad', size_divisor=32),
39 |             dict(type='ImageToTensor', keys=['img']),
40 |             dict(type='Collect', keys=['img', 'proposals']),
41 |         ])
42 | ]
43 | data = dict(
44 |     samples_per_gpu=1,
45 |     workers_per_gpu=2,
46 |     train=dict(
47 |         type=dataset_type,
48 |         ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt',
49 |         img_prefix=data_root + 'VOC2007/',
50 |         proposal_file='/datavoc/selective_search_data/voc_2007_trainval.pkl',
51 |         pipeline=train_pipeline),
52 |     val=dict(
53 |         type=dataset_type,
54 |         ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
55 |         img_prefix=data_root + 'VOC2007/',
56 |         proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl',
57 |         pipeline=test_pipeline),
58 |     test=dict(
59 |         type=dataset_type,
60 |         ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
61 |         img_prefix=data_root + 'VOC2007/',
62 |         proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl',
63 |         pipeline=test_pipeline))
64 | evaluation = dict(interval=1, metric='mAP')
65 |
66 | # optimizer
67 | optimizer = dict(
68 |     type='Adam',
69 |     lr=1e-5,
70 |     weight_decay=0.0005,
71 |     paramwise_cfg=dict(
72 |         bias_decay_mult=0.,
73 |         bias_lr_mult=2.,
74 |         custom_keys={
75 |             'refine': dict(lr_mult=10),
76 |         })
77 | )
78 |
79 | optimizer_config = dict(grad_clip=None)
80 | # learning policy
81 | lr_config = dict(
82 |     policy='step',
83 |     warmup='linear',
84 |     warmup_iters=500,
85 |     warmup_ratio=0.001,
86 |     step=[36])
87 | total_epochs = 64
88 |
89 | checkpoint_config = dict(interval=16)
90 | # yapf:disable 91 | log_config = dict( 92 | interval=100, 93 | hooks=[ 94 | dict(type='TextLoggerHook'), 95 | # dict(type='TensorboardLoggerHook') 96 | ]) 97 | # yapf:enable 98 | dist_params = dict(backend='nccl') 99 | log_level = 'INFO' 100 | load_from = 'pretrain/vgg16_v2.pth' 101 | resume_from = None 102 | workflow = [('train', 1)] 103 | -------------------------------------------------------------------------------- /configs/wsod/oicr_bbox_vgg16.py: -------------------------------------------------------------------------------- 1 | _base_ = './base.py' 2 | # model settings 3 | model = dict( 4 | type='WeakRCNN', 5 | pretrained=None, 6 | backbone=dict(type='VGG16'), 7 | neck=None, 8 | roi_head=dict( 9 | type='OICRRoIHead', 10 | bbox_roi_extractor=dict( 11 | type='SingleRoIExtractor', 12 | roi_layer=dict(type='RoIPool', output_size=7), 13 | out_channels=512, 14 | featmap_strides=[8]), 15 | bbox_head=dict( 16 | type='OICRHead', 17 | in_channels=512, 18 | hidden_channels=4096, 19 | roi_feat_size=7, 20 | bbox_coder=dict( 21 | type='DeltaXYWHBBoxCoder', 22 | target_means=[0., 0., 0., 0.], 23 | target_stds=[0.1, 0.1, 0.2, 0.2]), 24 | num_classes=20)) 25 | ) 26 | work_dir = 'work_dirs/oicr_bbox_vgg16/' 27 | -------------------------------------------------------------------------------- /configs/wsod/oicr_vgg16.py: -------------------------------------------------------------------------------- 1 | _base_ = './base.py' 2 | # model settings 3 | model = dict( 4 | type='WeakRCNN', 5 | pretrained=None, 6 | backbone=dict(type='VGG16'), 7 | neck=None, 8 | roi_head=dict( 9 | type='OICRRoIHead', 10 | bbox_roi_extractor=dict( 11 | type='SingleRoIExtractor', 12 | roi_layer=dict(type='RoIPool', output_size=7), 13 | out_channels=512, 14 | featmap_strides=[8]), 15 | bbox_head=dict( 16 | type='OICRHead', 17 | in_channels=512, 18 | hidden_channels=4096, 19 | roi_feat_size=7, 20 | num_classes=20)) 21 | ) 22 | work_dir = 'work_dirs/oicr_vgg16/' 23 | -------------------------------------------------------------------------------- /configs/wsod/wsddn_vgg16.py: -------------------------------------------------------------------------------- 1 | _base_ = './base.py' 2 | # model settings 3 | model = dict( 4 | type='WeakRCNN', 5 | pretrained=None, 6 | backbone=dict(type='VGG16'), 7 | neck=None, 8 | roi_head=dict( 9 | type='WSDDNRoIHead', 10 | bbox_roi_extractor=dict( 11 | type='SingleRoIExtractor', 12 | roi_layer=dict(type='RoIPool', output_size=7), 13 | out_channels=512, 14 | featmap_strides=[8]), 15 | bbox_head=dict( 16 | type='WSDDNHead', 17 | in_channels=512, 18 | hidden_channels=4096, 19 | roi_feat_size=7, 20 | num_classes=20)) 21 | ) 22 | work_dir = 'work_dirs/wsddn_vgg16/' 23 | -------------------------------------------------------------------------------- /configs/wsod/wsod2_vgg16.py: -------------------------------------------------------------------------------- 1 | _base_ = './base.py' 2 | # model settings 3 | model = dict( 4 | type='WeakRCNN', 5 | pretrained=None, 6 | backbone=dict(type='VGG16'), 7 | neck=None, 8 | roi_head=dict( 9 | type='WSOD2RoIHead', 10 | steps=40000, 11 | bbox_roi_extractor=dict( 12 | type='SingleRoIExtractor', 13 | roi_layer=dict(type='RoIPool', output_size=7), 14 | out_channels=512, 15 | featmap_strides=[8]), 16 | bbox_head=dict( 17 | type='OICRHead', 18 | in_channels=512, 19 | hidden_channels=4096, 20 | roi_feat_size=7, 21 | bbox_coder=dict( 22 | type='DeltaXYWHBBoxCoder', 23 | target_means=[0., 0., 0., 0.], 24 | target_stds=[0.1, 0.1, 0.2, 0.2]), 25 | num_classes=20)) 26 | 
) 27 | # dataset settings 28 | dataset_type = 'VOCSSDataset' 29 | data_root = '/datavoc/VOCdevkit/' 30 | img_norm_cfg = dict( 31 | mean=[104., 117., 124.], std=[1., 1., 1.], to_rgb=False) 32 | train_pipeline = [ 33 | dict(type='LoadImageFromFile'), 34 | dict(type='LoadSuperPixelFromFile'), 35 | dict(type='LoadWeakAnnotations'), 36 | dict(type='LoadProposals'), 37 | dict(type='Resize', img_scale=[(488, 2000), (576, 2000), (688, 2000), (864, 2000), (1200, 2000)], keep_ratio=True, multiscale_mode='value'), 38 | dict(type='RandomFlip', flip_ratio=0.5), 39 | dict(type='Normalize', **img_norm_cfg), 40 | dict(type='Pad', size_divisor=32), 41 | dict(type='DefaultFormatBundle'), 42 | dict(type='Collect', keys=['img', 'gt_labels', 'proposals', 'ss']), 43 | ] 44 | test_pipeline = [ 45 | dict(type='LoadImageFromFile'), 46 | dict(type='LoadProposals'), 47 | dict( 48 | type='MultiScaleFlipAug', 49 | img_scale=(688, 2000), 50 | #img_scale=[(500, 2000), (600, 2000), (700, 2000), (800, 2000), (900, 2000)], 51 | flip=False, 52 | transforms=[ 53 | dict(type='Resize', keep_ratio=True), 54 | dict(type='RandomFlip'), 55 | dict(type='Normalize', **img_norm_cfg), 56 | dict(type='Pad', size_divisor=32), 57 | dict(type='ImageToTensor', keys=['img']), 58 | dict(type='Collect', keys=['img', 'proposals']), 59 | ]) 60 | ] 61 | data = dict( 62 | samples_per_gpu=1, 63 | workers_per_gpu=2, 64 | train=dict( 65 | type=dataset_type, 66 | ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt', 67 | img_prefix=data_root + 'VOC2007/', 68 | proposal_file='/datavoc/selective_search_data/voc_2007_trainval.pkl', 69 | pipeline=train_pipeline), 70 | val=dict( 71 | type=dataset_type, 72 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt', 73 | img_prefix=data_root + 'VOC2007/', 74 | proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl', 75 | pipeline=test_pipeline), 76 | test=dict( 77 | type=dataset_type, 78 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt', 79 | img_prefix=data_root + 'VOC2007/', 80 | proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl', 81 | pipeline=test_pipeline)) 82 | 83 | work_dir = 'work_dirs/wsod2_vgg16/' 84 | -------------------------------------------------------------------------------- /demo/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/demo/demo.jpg -------------------------------------------------------------------------------- /demo/image_demo.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mmdet.apis import inference_detector, init_detector, show_result_pyplot 4 | 5 | 6 | def main(): 7 | parser = ArgumentParser() 8 | parser.add_argument('img', help='Image file') 9 | parser.add_argument('config', help='Config file') 10 | parser.add_argument('checkpoint', help='Checkpoint file') 11 | parser.add_argument( 12 | '--device', default='cuda:0', help='Device used for inference') 13 | parser.add_argument( 14 | '--score-thr', type=float, default=0.3, help='bbox score threshold') 15 | args = parser.parse_args() 16 | 17 | # build the model from a config file and a checkpoint file 18 | model = init_detector(args.config, args.checkpoint, device=args.device) 19 | # test a single image 20 | result = inference_detector(model, args.img) 21 | # show the results 22 | show_result_pyplot(model, args.img, result, score_thr=args.score_thr) 23 | 24 | 25 | if __name__ 
== '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /demo/webcam_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import torch 5 | 6 | from mmdet.apis import inference_detector, init_detector 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='MMDetection webcam demo') 11 | parser.add_argument('config', help='test config file path') 12 | parser.add_argument('checkpoint', help='checkpoint file') 13 | parser.add_argument( 14 | '--device', type=str, default='cuda:0', help='CPU/CUDA device option') 15 | parser.add_argument( 16 | '--camera-id', type=int, default=0, help='camera device id') 17 | parser.add_argument( 18 | '--score-thr', type=float, default=0.5, help='bbox score threshold') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | 26 | device = torch.device(args.device) 27 | 28 | model = init_detector(args.config, args.checkpoint, device=device) 29 | 30 | camera = cv2.VideoCapture(args.camera_id) 31 | 32 | print('Press "Esc", "q" or "Q" to exit.') 33 | while True: 34 | ret_val, img = camera.read() 35 | result = inference_detector(model, img) 36 | 37 | ch = cv2.waitKey(1) 38 | if ch == 27 or ch == ord('q') or ch == ord('Q'): 39 | break 40 | 41 | model.show_result( 42 | img, result, score_thr=args.score_thr, wait_time=1, show=True) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.6.0" 2 | ARG CUDA="10.1" 3 | ARG CUDNN="7" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 9 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" 10 | 11 | RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # Install MMCV 16 | RUN pip install mmcv-full==latest+torch1.6.0+cu101 -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html 17 | 18 | # Install MMDetection 19 | RUN conda clean --all 20 | RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection 21 | WORKDIR /mmdetection 22 | ENV FORCE_CUDA="1" 23 | RUN pip install -r requirements/build.txt 24 | RUN pip install --no-cache-dir -e . 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ================= 3 | 4 | mmdet.apis 5 | -------------- 6 | .. automodule:: mmdet.apis 7 | :members: 8 | 9 | mmdet.core 10 | -------------- 11 | 12 | anchor 13 | ^^^^^^^^^^ 14 | .. automodule:: mmdet.core.anchor 15 | :members: 16 | 17 | bbox 18 | ^^^^^^^^^^ 19 | .. automodule:: mmdet.core.bbox 20 | :members: 21 | 22 | export 23 | ^^^^^^^^^^ 24 | .. automodule:: mmdet.core.export 25 | :members: 26 | 27 | mask 28 | ^^^^^^^^^^ 29 | .. automodule:: mmdet.core.mask 30 | :members: 31 | 32 | evaluation 33 | ^^^^^^^^^^ 34 | .. automodule:: mmdet.core.evaluation 35 | :members: 36 | 37 | post_processing 38 | ^^^^^^^^^^^^^^^ 39 | .. automodule:: mmdet.core.post_processing 40 | :members: 41 | 42 | optimizer 43 | ^^^^^^^^^^ 44 | .. automodule:: mmdet.core.optimizer 45 | :members: 46 | 47 | utils 48 | ^^^^^^^^^^ 49 | .. automodule:: mmdet.core.utils 50 | :members: 51 | 52 | mmdet.datasets 53 | -------------- 54 | 55 | datasets 56 | ^^^^^^^^^^ 57 | .. automodule:: mmdet.datasets 58 | :members: 59 | 60 | pipelines 61 | ^^^^^^^^^^ 62 | .. automodule:: mmdet.datasets.pipelines 63 | :members: 64 | 65 | mmdet.models 66 | -------------- 67 | 68 | detectors 69 | ^^^^^^^^^^ 70 | .. automodule:: mmdet.models.detectors 71 | :members: 72 | 73 | backbones 74 | ^^^^^^^^^^ 75 | .. automodule:: mmdet.models.backbones 76 | :members: 77 | 78 | necks 79 | ^^^^^^^^^^^^ 80 | .. automodule:: mmdet.models.necks 81 | :members: 82 | 83 | dense_heads 84 | ^^^^^^^^^^^^ 85 | .. automodule:: mmdet.models.dense_heads 86 | :members: 87 | 88 | roi_heads 89 | ^^^^^^^^^^ 90 | .. automodule:: mmdet.models.roi_heads 91 | :members: 92 | 93 | losses 94 | ^^^^^^^^^^ 95 | .. automodule:: mmdet.models.losses 96 | :members: 97 | 98 | utils 99 | ^^^^^^^^^^ 100 | .. automodule:: mmdet.models.utils 101 | :members: 102 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import subprocess 15 | import sys 16 | 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'MMDetection' 22 | copyright = '2018-2020, OpenMMLab' 23 | author = 'MMDetection Authors' 24 | version_file = '../mmdet/version.py' 25 | 26 | 27 | def get_version(): 28 | with open(version_file, 'r') as f: 29 | exec(compile(f.read(), version_file, 'exec')) 30 | return locals()['__version__'] 31 | 32 | 33 | # The full version, including alpha/beta/rc tags 34 | release = get_version() 35 | 36 | # -- General configuration --------------------------------------------------- 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.napoleon', 44 | 'sphinx.ext.viewcode', 45 | 'recommonmark', 46 | 'sphinx_markdown_tables', 47 | ] 48 | 49 | autodoc_mock_imports = [ 50 | 'matplotlib', 'pycocotools', 'terminaltables', 'mmdet.version', 'mmcv.ops' 51 | ] 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ['_templates'] 55 | 56 | # The suffix(es) of source filenames. 57 | # You can specify multiple suffix as a list of string: 58 | # 59 | source_suffix = { 60 | '.rst': 'restructuredtext', 61 | '.md': 'markdown', 62 | } 63 | 64 | # The master toctree document. 65 | master_doc = 'index' 66 | 67 | # List of patterns, relative to source directory, that match files and 68 | # directories to ignore when looking for source files. 69 | # This pattern also affects html_static_path and html_extra_path. 70 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 71 | 72 | # -- Options for HTML output ------------------------------------------------- 73 | 74 | # The theme to use for HTML and HTML Help pages. See the documentation for 75 | # a list of builtin themes. 76 | # 77 | html_theme = 'sphinx_rtd_theme' 78 | 79 | # Add any paths that contain custom static files (such as style sheets) here, 80 | # relative to this directory. They are copied after the builtin static files, 81 | # so a file named "default.css" will overwrite the builtin "default.css". 82 | html_static_path = ['_static'] 83 | 84 | 85 | def builder_inited_handler(app): 86 | subprocess.run(['./stat.py']) 87 | 88 | 89 | def setup(app): 90 | app.connect('builder-inited', builder_inited_handler) 91 | -------------------------------------------------------------------------------- /docs/conventions.md: -------------------------------------------------------------------------------- 1 | # Conventions 2 | 3 | Please check the following conventions if you would like to modify MMDetection as your own project. 4 | 5 | ## Loss 6 | In MMDetection, a `dict` containing losses and metrics will be returned by `model(**data)`. 7 | 8 | For example, in bbox head, 9 | ```python 10 | class BBoxHead(nn.Module): 11 | ... 12 | def loss(self, ...): 13 | losses = dict() 14 | # classification loss 15 | losses['loss_cls'] = self.loss_cls(...) 16 | # classification accuracy 17 | losses['acc'] = accuracy(...) 18 | # bbox regression loss 19 | losses['loss_bbox'] = self.loss_bbox(...) 20 | return losses 21 | ``` 22 | `bbox_head.loss()` will be called during model forward. 23 | The returned dict contains `'loss_bbox'`, `'loss_cls'`, `'acc'` . 
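A minimal sketch of how such a dict is reduced for the optimizer (illustrative only: the standalone `parse_losses` helper below mirrors the default behavior described next, not mmdet's exact `BaseDetector._parse_losses` implementation):

```python
import torch


def parse_losses(losses):
    # Average every entry for logging; only keys containing 'loss'
    # are summed into the scalar that is backpropagated.
    log_vars = {name: value.mean() for name, value in losses.items()}
    total_loss = sum(value for name, value in log_vars.items() if 'loss' in name)
    return total_loss, log_vars


losses = dict(
    loss_cls=torch.tensor([0.8]),
    loss_bbox=torch.tensor([0.4]),
    acc=torch.tensor([91.0]))
total_loss, log_vars = parse_losses(losses)  # total_loss is about 1.2; 'acc' is logged only
```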
24 | Only `'loss_bbox'`, `'loss_cls'` will be used during back propagation, 25 | `'acc'` will only be used as a metric to monitor training process. 26 | 27 | By default, only values whose keys contain `'loss'` will be back propagated. 28 | This behavior could be changed by modifying `BaseDetector.train_step()`. 29 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to MMDetection's documentation! 2 | ======================================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Get Started 7 | 8 | get_started.md 9 | modelzoo_statistics.md 10 | model_zoo.md 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | :caption: Quick Run 15 | 16 | 1_exist_data_model.md 17 | 2_new_data_model.md 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: Tutorials 22 | 23 | tutorials/index.rst 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | :caption: Useful Tools and Scripts 28 | 29 | useful_tools.md 30 | 31 | .. toctree:: 32 | :maxdepth: 2 33 | :caption: Notes 34 | 35 | conventions.md 36 | compatibility.md 37 | projects.md 38 | changelog.md 39 | faq.md 40 | 41 | .. toctree:: 42 | :caption: API Reference 43 | 44 | api.rst 45 | 46 | Indices and tables 47 | ================== 48 | 49 | * :ref:`genindex` 50 | * :ref:`search` 51 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/stat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import glob 3 | import os.path as osp 4 | import re 5 | 6 | url_prefix = 'https://github.com/open-mmlab/mmdetection/blob/master/' 7 | 8 | files = sorted(glob.glob('../configs/*/README.md')) 9 | 10 | stats = [] 11 | titles = [] 12 | num_ckpts = 0 13 | 14 | for f in files: 15 | url = osp.dirname(f.replace('../', url_prefix)) 16 | 17 | with open(f, 'r') as content_file: 18 | content = content_file.read() 19 | 20 | title = content.split('\n')[0].replace('# ', '') 21 | 22 | titles.append(title) 23 | ckpts = set(x.lower().strip() 24 | for x in re.findall(r'https?://download.*\.pth', content) 25 | if 'mmdetection' in x) 26 | num_ckpts += len(ckpts) 27 | statsmsg = f""" 28 | \t* [{title}]({url}) ({len(ckpts)} ckpts) 29 | """ 30 | stats.append((title, ckpts, statsmsg)) 31 | 32 | msglist = '\n'.join(x for _, _, x in stats) 33 | 34 | modelzoo = f""" 35 | # Model Zoo Statistics 36 | 37 | * Number of papers: {len(titles)} 38 | * Number of checkpoints: {num_ckpts} 39 | {msglist} 40 | """ 41 | 42 | with open('modelzoo_statistics.md', 'w') as f: 43 | f.write(modelzoo) 44 | -------------------------------------------------------------------------------- /docs/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 2 3 | 4 | config.md 5 | customize_dataset.md 6 | data_pipeline.md 7 | customize_models.md 8 | customize_runtime.md 9 | customize_losses.md 10 | finetune.md 11 | -------------------------------------------------------------------------------- /mmdet/__init__.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from .version import __version__, short_version 4 | 5 | 6 | def digit_version(version_str): 7 | digit_version = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | digit_version.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | digit_version.append(int(patch_version[0]) - 1) 14 | digit_version.append(int(patch_version[1])) 15 | return digit_version 16 | 17 | 18 | mmcv_minimum_version = '1.1.5' 19 | mmcv_maximum_version = '1.3' 20 | mmcv_version = digit_version(mmcv.__version__) 21 | 22 | 23 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 24 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 25 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 26 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 
27 | 28 | __all__ = ['__version__', 'short_version'] 29 | -------------------------------------------------------------------------------- /mmdet/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import (async_inference_detector, inference_detector, 2 | init_detector, show_result_pyplot) 3 | from .test import multi_gpu_test, single_gpu_test 4 | from .train import get_root_logger, set_random_seed, train_detector 5 | 6 | __all__ = [ 7 | 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector', 8 | 'async_inference_detector', 'inference_detector', 'show_result_pyplot', 9 | 'multi_gpu_test', 'single_gpu_test' 10 | ] 11 | -------------------------------------------------------------------------------- /mmdet/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor import * # noqa: F401, F403 2 | from .bbox import * # noqa: F401, F403 3 | from .evaluation import * # noqa: F401, F403 4 | from .export import * # noqa: F401, F403 5 | from .fp16 import * # noqa: F401, F403 6 | from .mask import * # noqa: F401, F403 7 | from .post_processing import * # noqa: F401, F403 8 | from .utils import * # noqa: F401, F403 9 | -------------------------------------------------------------------------------- /mmdet/core/anchor/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator, 2 | YOLOAnchorGenerator) 3 | from .builder import ANCHOR_GENERATORS, build_anchor_generator 4 | from .point_generator import PointGenerator 5 | from .utils import anchor_inside_flags, calc_region, images_to_levels 6 | 7 | __all__ = [ 8 | 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags', 9 | 'PointGenerator', 'images_to_levels', 'calc_region', 10 | 'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator' 11 | ] 12 | -------------------------------------------------------------------------------- /mmdet/core/anchor/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | 3 | ANCHOR_GENERATORS = Registry('Anchor generator') 4 | 5 | 6 | def build_anchor_generator(cfg, default_args=None): 7 | return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args) 8 | -------------------------------------------------------------------------------- /mmdet/core/anchor/point_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .builder import ANCHOR_GENERATORS 4 | 5 | 6 | @ANCHOR_GENERATORS.register_module() 7 | class PointGenerator(object): 8 | 9 | def _meshgrid(self, x, y, row_major=True): 10 | xx = x.repeat(len(y)) 11 | yy = y.view(-1, 1).repeat(1, len(x)).view(-1) 12 | if row_major: 13 | return xx, yy 14 | else: 15 | return yy, xx 16 | 17 | def grid_points(self, featmap_size, stride=16, device='cuda'): 18 | feat_h, feat_w = featmap_size 19 | shift_x = torch.arange(0., feat_w, device=device) * stride 20 | shift_y = torch.arange(0., feat_h, device=device) * stride 21 | shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) 22 | stride = shift_x.new_full((shift_xx.shape[0], ), stride) 23 | shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1) 24 | all_points = shifts.to(device) 25 | return all_points 26 | 27 | def valid_flags(self, featmap_size, valid_size, device='cuda'): 28 | feat_h, feat_w = featmap_size 29 
| valid_h, valid_w = valid_size 30 | assert valid_h <= feat_h and valid_w <= feat_w 31 | valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) 32 | valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) 33 | valid_x[:valid_w] = 1 34 | valid_y[:valid_h] = 1 35 | valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) 36 | valid = valid_xx & valid_yy 37 | return valid 38 | -------------------------------------------------------------------------------- /mmdet/core/anchor/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def images_to_levels(target, num_levels): 5 | """Convert targets by image to targets by feature level. 6 | 7 | [target_img0, target_img1] -> [target_level0, target_level1, ...] 8 | """ 9 | target = torch.stack(target, 0) 10 | level_targets = [] 11 | start = 0 12 | for n in num_levels: 13 | end = start + n 14 | # level_targets.append(target[:, start:end].squeeze(0)) 15 | level_targets.append(target[:, start:end]) 16 | start = end 17 | return level_targets 18 | 19 | 20 | def anchor_inside_flags(flat_anchors, 21 | valid_flags, 22 | img_shape, 23 | allowed_border=0): 24 | """Check whether the anchors are inside the border. 25 | 26 | Args: 27 | flat_anchors (torch.Tensor): Flattened anchors, shape (n, 4). 28 | valid_flags (torch.Tensor): Existing valid flags of the anchors. 29 | img_shape (tuple[int]): Shape of the current image. 30 | allowed_border (int, optional): The margin by which anchors may extend 31 | beyond the image border and still count as inside. Defaults to 0. 32 | 33 | Returns: 34 | torch.Tensor: Flags indicating whether the anchors are inside a \ 35 | valid range. 36 | """ 37 | img_h, img_w = img_shape[:2] 38 | if allowed_border >= 0: 39 | inside_flags = valid_flags & \ 40 | (flat_anchors[:, 0] >= -allowed_border) & \ 41 | (flat_anchors[:, 1] >= -allowed_border) & \ 42 | (flat_anchors[:, 2] < img_w + allowed_border) & \ 43 | (flat_anchors[:, 3] < img_h + allowed_border) 44 | else: 45 | inside_flags = valid_flags 46 | return inside_flags 47 | 48 | 49 | def calc_region(bbox, ratio, featmap_size=None): 50 | """Calculate a proportional bbox region. 51 | 52 | The bbox center is fixed while each side is shrunk inward by ``ratio``, i.e. the new width and height are (1 - 2 * ratio) times the original ones. 53 | 54 | Args: 55 | bbox (Tensor): Bboxes to calculate regions, shape (n, 4). 56 | ratio (float): Ratio of the output region. 57 | featmap_size (tuple): Feature map size used for clipping the boundary.
58 | 59 | Returns: 60 | tuple: x1, y1, x2, y2 61 | """ 62 | x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long() 63 | y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long() 64 | x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long() 65 | y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long() 66 | if featmap_size is not None: 67 | x1 = x1.clamp(min=0, max=featmap_size[1]) 68 | y1 = y1.clamp(min=0, max=featmap_size[0]) 69 | x2 = x2.clamp(min=0, max=featmap_size[1]) 70 | y2 = y2.clamp(min=0, max=featmap_size[0]) 71 | return (x1, y1, x2, y2) 72 | -------------------------------------------------------------------------------- /mmdet/core/bbox/__init__.py: -------------------------------------------------------------------------------- 1 | from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner, 2 | MaxIoUAssigner) 3 | from .builder import build_assigner, build_bbox_coder, build_sampler 4 | from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder, 5 | TBLRBBoxCoder) 6 | from .iou_calculators import BboxOverlaps2D, bbox_overlaps 7 | from .samplers import (BaseSampler, CombinedSampler, 8 | InstanceBalancedPosSampler, IoUBalancedNegSampler, 9 | OHEMSampler, PseudoSampler, RandomSampler, 10 | SamplingResult, ScoreHLRSampler) 11 | from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip, 12 | bbox_mapping, bbox_mapping_back, bbox_rescale, 13 | distance2bbox, roi2bbox) 14 | 15 | __all__ = [ 16 | 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner', 17 | 'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler', 18 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 19 | 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner', 20 | 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 21 | 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', 22 | 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', 23 | 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner', 24 | 'bbox_rescale' 25 | ] 26 | -------------------------------------------------------------------------------- /mmdet/core/bbox/assigners/__init__.py: -------------------------------------------------------------------------------- 1 | from .approx_max_iou_assigner import ApproxMaxIoUAssigner 2 | from .assign_result import AssignResult 3 | from .atss_assigner import ATSSAssigner 4 | from .base_assigner import BaseAssigner 5 | from .center_region_assigner import CenterRegionAssigner 6 | from .grid_assigner import GridAssigner 7 | from .max_iou_assigner import MaxIoUAssigner 8 | from .point_assigner import PointAssigner 9 | 10 | __all__ = [ 11 | 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', 12 | 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner' 13 | ] 14 | -------------------------------------------------------------------------------- /mmdet/core/bbox/assigners/base_assigner.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BaseAssigner(metaclass=ABCMeta): 5 | """Base assigner that assigns boxes to ground truth boxes.""" 6 | 7 | @abstractmethod 8 | def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): 9 | """Assign boxes to either a ground truth box or a negative sample.""" 10 | pass 11 | --------------------------------------------------------------------------------
/mmdet/core/bbox/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | 3 | BBOX_ASSIGNERS = Registry('bbox_assigner') 4 | BBOX_SAMPLERS = Registry('bbox_sampler') 5 | BBOX_CODERS = Registry('bbox_coder') 6 | 7 | 8 | def build_assigner(cfg, **default_args): 9 | """Builder of box assigner.""" 10 | return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args) 11 | 12 | 13 | def build_sampler(cfg, **default_args): 14 | """Builder of box sampler.""" 15 | return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) 16 | 17 | 18 | def build_bbox_coder(cfg, **default_args): 19 | """Builder of box coder.""" 20 | return build_from_cfg(cfg, BBOX_CODERS, default_args) 21 | -------------------------------------------------------------------------------- /mmdet/core/bbox/coder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_bbox_coder import BaseBBoxCoder 2 | from .bucketing_bbox_coder import BucketingBBoxCoder 3 | from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder 4 | from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder 5 | from .pseudo_bbox_coder import PseudoBBoxCoder 6 | from .tblr_bbox_coder import TBLRBBoxCoder 7 | from .yolo_bbox_coder import YOLOBBoxCoder 8 | 9 | __all__ = [ 10 | 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 11 | 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder', 12 | 'BucketingBBoxCoder' 13 | ] 14 | -------------------------------------------------------------------------------- /mmdet/core/bbox/coder/base_bbox_coder.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BaseBBoxCoder(metaclass=ABCMeta): 5 | """Base bounding box coder.""" 6 | 7 | def __init__(self, **kwargs): 8 | pass 9 | 10 | @abstractmethod 11 | def encode(self, bboxes, gt_bboxes): 12 | """Encode deltas between bboxes and ground truth boxes.""" 13 | pass 14 | 15 | @abstractmethod 16 | def decode(self, bboxes, bboxes_pred): 17 | """Decode the predicted bboxes according to prediction and base 18 | boxes.""" 19 | pass 20 | -------------------------------------------------------------------------------- /mmdet/core/bbox/coder/pseudo_bbox_coder.py: -------------------------------------------------------------------------------- 1 | from ..builder import BBOX_CODERS 2 | from .base_bbox_coder import BaseBBoxCoder 3 | 4 | 5 | @BBOX_CODERS.register_module() 6 | class PseudoBBoxCoder(BaseBBoxCoder): 7 | """Pseudo bounding box coder.""" 8 | 9 | def __init__(self, **kwargs): 10 | super(PseudoBBoxCoder, self).__init__(**kwargs) 11 | 12 | def encode(self, bboxes, gt_bboxes): 13 | """torch.Tensor: return the given ``gt_bboxes``""" 14 | return gt_bboxes 15 | 16 | def decode(self, bboxes, pred_bboxes): 17 | """torch.Tensor: return the given ``pred_bboxes``""" 18 | return pred_bboxes 19 | -------------------------------------------------------------------------------- /mmdet/core/bbox/demodata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def ensure_rng(rng=None): 6 | """Simple version of ``kwarray.ensure_rng``. 7 | 8 | Args: 9 | rng (int | numpy.random.RandomState | None): 10 | if None, then defaults to the global rng.
Otherwise this can be an 11 | integer or a RandomState class 12 | Returns: 13 | (numpy.random.RandomState) : rng - 14 | a numpy random number generator 15 | 16 | References: 17 | https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 18 | """ 19 | 20 | if rng is None: 21 | rng = np.random.mtrand._rand 22 | elif isinstance(rng, int): 23 | rng = np.random.RandomState(rng) 24 | else: 25 | rng = rng 26 | return rng 27 | 28 | 29 | def random_boxes(num=1, scale=1, rng=None): 30 | """Simple version of ``kwimage.Boxes.random`` 31 | 32 | Returns: 33 | Tensor: shape (n, 4) in x1, y1, x2, y2 format. 34 | 35 | References: 36 | https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390 37 | 38 | Example: 39 | >>> num = 3 40 | >>> scale = 512 41 | >>> rng = 0 42 | >>> boxes = random_boxes(num, scale, rng) 43 | >>> print(boxes) 44 | tensor([[280.9925, 278.9802, 308.6148, 366.1769], 45 | [216.9113, 330.6978, 224.0446, 456.5878], 46 | [405.3632, 196.3221, 493.3953, 270.7942]]) 47 | """ 48 | rng = ensure_rng(rng) 49 | 50 | tlbr = rng.rand(num, 4).astype(np.float32) 51 | 52 | tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2]) 53 | tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3]) 54 | br_x = np.maximum(tlbr[:, 0], tlbr[:, 2]) 55 | br_y = np.maximum(tlbr[:, 1], tlbr[:, 3]) 56 | 57 | tlbr[:, 0] = tl_x * scale 58 | tlbr[:, 1] = tl_y * scale 59 | tlbr[:, 2] = br_x * scale 60 | tlbr[:, 3] = br_y * scale 61 | 62 | boxes = torch.from_numpy(tlbr) 63 | return boxes 64 | -------------------------------------------------------------------------------- /mmdet/core/bbox/iou_calculators/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_iou_calculator 2 | from .iou2d_calculator import BboxOverlaps2D, bbox_overlaps 3 | 4 | __all__ = ['build_iou_calculator', 'BboxOverlaps2D', 'bbox_overlaps'] 5 | -------------------------------------------------------------------------------- /mmdet/core/bbox/iou_calculators/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | 3 | IOU_CALCULATORS = Registry('IoU calculator') 4 | 5 | 6 | def build_iou_calculator(cfg, default_args=None): 7 | """Builder of IoU calculator.""" 8 | return build_from_cfg(cfg, IOU_CALCULATORS, default_args) 9 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import BaseSampler 2 | from .combined_sampler import CombinedSampler 3 | from .instance_balanced_pos_sampler import InstanceBalancedPosSampler 4 | from .iou_balanced_neg_sampler import IoUBalancedNegSampler 5 | from .ohem_sampler import OHEMSampler 6 | from .pseudo_sampler import PseudoSampler 7 | from .random_sampler import RandomSampler 8 | from .sampling_result import SamplingResult 9 | from .score_hlr_sampler import ScoreHLRSampler 10 | 11 | __all__ = [ 12 | 'BaseSampler', 'PseudoSampler', 'RandomSampler', 13 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 14 | 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler' 15 | ] 16 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/combined_sampler.py: -------------------------------------------------------------------------------- 1 | from ..builder import BBOX_SAMPLERS, build_sampler 2 | 
from .base_sampler import BaseSampler 3 | 4 | 5 | @BBOX_SAMPLERS.register_module() 6 | class CombinedSampler(BaseSampler): 7 | """A sampler that combines a positive sampler and a negative sampler.""" 8 | 9 | def __init__(self, pos_sampler, neg_sampler, **kwargs): 10 | super(CombinedSampler, self).__init__(**kwargs) 11 | self.pos_sampler = build_sampler(pos_sampler, **kwargs) 12 | self.neg_sampler = build_sampler(neg_sampler, **kwargs) 13 | 14 | def _sample_pos(self, **kwargs): 15 | """Sample positive samples.""" 16 | raise NotImplementedError 17 | 18 | def _sample_neg(self, **kwargs): 19 | """Sample negative samples.""" 20 | raise NotImplementedError 21 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..builder import BBOX_SAMPLERS 5 | from .random_sampler import RandomSampler 6 | 7 | 8 | @BBOX_SAMPLERS.register_module() 9 | class InstanceBalancedPosSampler(RandomSampler): 10 | """Instance balanced sampler that samples an equal number of positive 11 | samples for each instance.""" 12 | 13 | def _sample_pos(self, assign_result, num_expected, **kwargs): 14 | """Sample positive boxes. 15 | 16 | Args: 17 | assign_result (:obj:`AssignResult`): The assigned results of boxes. 18 | num_expected (int): The number of expected positive samples. 19 | 20 | Returns: 21 | Tensor or ndarray: sampled indices. 22 | """ 23 | pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) 24 | if pos_inds.numel() != 0: 25 | pos_inds = pos_inds.squeeze(1) 26 | if pos_inds.numel() <= num_expected: 27 | return pos_inds 28 | else: 29 | unique_gt_inds = assign_result.gt_inds[pos_inds].unique() 30 | num_gts = len(unique_gt_inds) 31 | num_per_gt = int(round(num_expected / float(num_gts)) + 1) 32 | sampled_inds = [] 33 | for i in unique_gt_inds: 34 | inds = torch.nonzero( 35 | assign_result.gt_inds == i.item(), as_tuple=False) 36 | if inds.numel() != 0: 37 | inds = inds.squeeze(1) 38 | else: 39 | continue 40 | if len(inds) > num_per_gt: 41 | inds = self.random_choice(inds, num_per_gt) 42 | sampled_inds.append(inds) 43 | sampled_inds = torch.cat(sampled_inds) 44 | if len(sampled_inds) < num_expected: 45 | num_extra = num_expected - len(sampled_inds) 46 | extra_inds = np.array( 47 | list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) 48 | if len(extra_inds) > num_extra: 49 | extra_inds = self.random_choice(extra_inds, num_extra) 50 | extra_inds = torch.from_numpy(extra_inds).to( 51 | assign_result.gt_inds.device).long() 52 | sampled_inds = torch.cat([sampled_inds, extra_inds]) 53 | elif len(sampled_inds) > num_expected: 54 | sampled_inds = self.random_choice(sampled_inds, num_expected) 55 | return sampled_inds 56 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/pseudo_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..builder import BBOX_SAMPLERS 4 | from .base_sampler import BaseSampler 5 | from .sampling_result import SamplingResult 6 | 7 | 8 | @BBOX_SAMPLERS.register_module() 9 | class PseudoSampler(BaseSampler): 10 | """A pseudo sampler that does not actually do sampling.""" 11 | 12 | def __init__(self, **kwargs): 13 | pass 14 | 15 | def _sample_pos(self, **kwargs): 16 | """Sample positive samples.""" 17 | raise NotImplementedError 18 | 19 | def _sample_neg(self,
**kwargs): 20 | """Sample negative samples.""" 21 | raise NotImplementedError 22 | 23 | def sample(self, assign_result, bboxes, gt_bboxes, **kwargs): 24 | """Directly returns the positive and negative indices of samples. 25 | 26 | Args: 27 | assign_result (:obj:`AssignResult`): Assigned results. 28 | bboxes (torch.Tensor): Bounding boxes. 29 | gt_bboxes (torch.Tensor): Ground truth boxes. 30 | 31 | Returns: 32 | :obj:`SamplingResult`: sampler results 33 | """ 34 | pos_inds = torch.nonzero( 35 | assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() 36 | neg_inds = torch.nonzero( 37 | assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() 38 | gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8) 39 | sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, 40 | assign_result, gt_flags) 41 | return sampling_result 42 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/random_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..builder import BBOX_SAMPLERS 4 | from .base_sampler import BaseSampler 5 | 6 | 7 | @BBOX_SAMPLERS.register_module() 8 | class RandomSampler(BaseSampler): 9 | """Random sampler. 10 | 11 | Args: 12 | num (int): Number of samples 13 | pos_fraction (float): Fraction of positive samples 14 | neg_pos_ub (int, optional): Upper bound on the ratio of negative to 15 | positive samples. Defaults to -1 (no upper bound). 16 | add_gt_as_proposals (bool, optional): Whether to add ground truth 17 | boxes as proposals. Defaults to True. 18 | """ 19 | 20 | def __init__(self, 21 | num, 22 | pos_fraction, 23 | neg_pos_ub=-1, 24 | add_gt_as_proposals=True, 25 | **kwargs): 26 | from mmdet.core.bbox import demodata 27 | super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub, 28 | add_gt_as_proposals) 29 | self.rng = demodata.ensure_rng(kwargs.get('rng', None)) 30 | 31 | def random_choice(self, gallery, num): 32 | """Randomly select some elements from the gallery. 33 | 34 | If `gallery` is a Tensor, the returned indices will be a Tensor; 35 | If `gallery` is a ndarray or list, the returned indices will be a 36 | ndarray. 37 | 38 | Args: 39 | gallery (Tensor | ndarray | list): indices pool. 40 | num (int): expected sample num. 41 | 42 | Returns: 43 | Tensor or ndarray: sampled indices.
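Example: A minimal sketch of typical usage (the constructor arguments here are arbitrary; any valid ``num``/``pos_fraction`` would do): >>> import torch >>> sampler = RandomSampler(num=8, pos_fraction=0.5) >>> gallery = torch.arange(10) >>> inds = sampler.random_choice(gallery, 4) >>> assert inds.numel() == 4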
44 | """ 45 | assert len(gallery) >= num 46 | 47 | is_tensor = isinstance(gallery, torch.Tensor) 48 | if not is_tensor: 49 | if torch.cuda.is_available(): 50 | device = torch.cuda.current_device() 51 | else: 52 | device = 'cpu' 53 | gallery = torch.tensor(gallery, dtype=torch.long, device=device) 54 | perm = torch.randperm(gallery.numel(), device=gallery.device)[:num] 55 | rand_inds = gallery[perm] 56 | if not is_tensor: 57 | rand_inds = rand_inds.cpu().numpy() 58 | return rand_inds 59 | 60 | def _sample_pos(self, assign_result, num_expected, **kwargs): 61 | """Randomly sample some positive samples.""" 62 | pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) 63 | if pos_inds.numel() != 0: 64 | pos_inds = pos_inds.squeeze(1) 65 | if pos_inds.numel() <= num_expected: 66 | return pos_inds 67 | else: 68 | return self.random_choice(pos_inds, num_expected) 69 | 70 | def _sample_neg(self, assign_result, num_expected, **kwargs): 71 | """Randomly sample some negative samples.""" 72 | neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False) 73 | if neg_inds.numel() != 0: 74 | neg_inds = neg_inds.squeeze(1) 75 | if len(neg_inds) <= num_expected: 76 | return neg_inds 77 | else: 78 | return self.random_choice(neg_inds, num_expected) 79 | -------------------------------------------------------------------------------- /mmdet/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .class_names import (cityscapes_classes, coco_classes, dataset_aliases, 2 | get_classes, imagenet_det_classes, 3 | imagenet_vid_classes, voc_classes) 4 | from .eval_hooks import DistEvalHook, EvalHook 5 | from .mean_ap import average_precision, eval_map, print_map_summary 6 | from .recall import (eval_recalls, plot_iou_recall, plot_num_recall, 7 | print_recall_summary) 8 | 9 | __all__ = [ 10 | 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes', 11 | 'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes', 12 | 'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map', 13 | 'print_map_summary', 'eval_recalls', 'print_recall_summary', 14 | 'plot_num_recall', 'plot_iou_recall' 15 | ] 16 | -------------------------------------------------------------------------------- /mmdet/core/evaluation/bbox_overlaps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6): 5 | """Calculate the ious between each bbox of bboxes1 and bboxes2. 
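In ``iof`` mode, the intersection is normalized by the area of the box from ``bboxes1`` alone rather than by the union.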
6 | 7 | Args: 8 | bboxes1 (ndarray): shape (n, 4) 9 | bboxes2 (ndarray): shape (k, 4) 10 | mode (str): iou (intersection over union) or iof (intersection 11 | over foreground) 12 | 13 | Returns: 14 | ious (ndarray): shape (n, k) 15 | """ 16 | 17 | assert mode in ['iou', 'iof'] 18 | 19 | bboxes1 = bboxes1.astype(np.float32) 20 | bboxes2 = bboxes2.astype(np.float32) 21 | rows = bboxes1.shape[0] 22 | cols = bboxes2.shape[0] 23 | ious = np.zeros((rows, cols), dtype=np.float32) 24 | if rows * cols == 0: 25 | return ious 26 | exchange = False 27 | if bboxes1.shape[0] > bboxes2.shape[0]: 28 | bboxes1, bboxes2 = bboxes2, bboxes1 29 | ious = np.zeros((cols, rows), dtype=np.float32) 30 | exchange = True 31 | area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) 32 | area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) 33 | for i in range(bboxes1.shape[0]): 34 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 35 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 36 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 37 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 38 | overlap = np.maximum(x_end - x_start, 0) * np.maximum( 39 | y_end - y_start, 0) 40 | if mode == 'iou': 41 | union = area1[i] + area2 - overlap 42 | else: 43 | union = area1[i] if not exchange else area2 44 | union = np.maximum(union, eps) 45 | ious[i, :] = overlap / union 46 | if exchange: 47 | ious = ious.T 48 | return ious 49 | -------------------------------------------------------------------------------- /mmdet/core/export/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch2onnx import (build_model_from_cfg, 2 | generate_inputs_and_wrap_model, 3 | preprocess_example_input) 4 | 5 | __all__ = [ 6 | 'build_model_from_cfg', 'generate_inputs_and_wrap_model', 7 | 'preprocess_example_input' 8 | ] 9 | -------------------------------------------------------------------------------- /mmdet/core/fp16/__init__.py: -------------------------------------------------------------------------------- 1 | from .deprecated_fp16_utils import \ 2 | DeprecatedFp16OptimizerHook as Fp16OptimizerHook 3 | from .deprecated_fp16_utils import deprecated_auto_fp16 as auto_fp16 4 | from .deprecated_fp16_utils import deprecated_force_fp32 as force_fp32 5 | from .deprecated_fp16_utils import \ 6 | deprecated_wrap_fp16_model as wrap_fp16_model 7 | 8 | __all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model'] 9 | -------------------------------------------------------------------------------- /mmdet/core/fp16/deprecated_fp16_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from mmcv.runner import (Fp16OptimizerHook, auto_fp16, force_fp32, 4 | wrap_fp16_model) 5 | 6 | 7 | class DeprecatedFp16OptimizerHook(Fp16OptimizerHook): 8 | """A wrapper class for the FP16 optimizer hook. This class wraps 9 | :class:`Fp16OptimizerHook` in `mmcv.runner` and shows a warning that the 10 | :class:`Fp16OptimizerHook` from `mmdet.core` will be deprecated. 11 | 12 | Refer to :class:`Fp16OptimizerHook` in `mmcv.runner` for more details. 13 | 14 | Args: 15 | loss_scale (float): Scale factor multiplied with loss. 16 | """ 17 | 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | warnings.warn( 21 | 'Importing Fp16OptimizerHook from "mmdet.core" will be ' 22 | 'deprecated in the future.
Please import them from "mmcv.runner" ' 23 | 'instead') 24 | 25 | 26 | def deprecated_auto_fp16(*args, **kwargs): 27 | warnings.warn( 28 | 'Importing auto_fp16 from "mmdet.core" will be ' 29 | 'deprecated in the future. Please import them from "mmcv.runner" ' 30 | 'instead') 31 | return auto_fp16(*args, **kwargs) 32 | 33 | 34 | def deprecated_force_fp32(*args, **kwargs): 35 | warnings.warn( 36 | 'Importing force_fp32 from "mmdet.core" will be ' 37 | 'deprecated in the future. Please import them from "mmcv.runner" ' 38 | 'instead') 39 | return force_fp32(*args, **kwargs) 40 | 41 | 42 | def deprecated_wrap_fp16_model(*args, **kwargs): 43 | warnings.warn( 44 | 'Importing wrap_fp16_model from "mmdet.core" will be ' 45 | 'deprecated in the future. Please import them from "mmcv.runner" ' 46 | 'instead') 47 | wrap_fp16_model(*args, **kwargs) 48 | -------------------------------------------------------------------------------- /mmdet/core/mask/__init__.py: -------------------------------------------------------------------------------- 1 | from .mask_target import mask_target 2 | from .structures import BaseInstanceMasks, BitmapMasks, PolygonMasks 3 | from .utils import encode_mask_results, split_combined_polys 4 | 5 | __all__ = [ 6 | 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks', 7 | 'PolygonMasks', 'encode_mask_results' 8 | ] 9 | -------------------------------------------------------------------------------- /mmdet/core/mask/mask_target.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn.modules.utils import _pair 4 | 5 | 6 | def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, 7 | cfg): 8 | """Compute mask target for positive proposals in multiple images. 9 | 10 | Args: 11 | pos_proposals_list (list[Tensor]): Positive proposals in multiple 12 | images. 13 | pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each 14 | positive proposal. 15 | gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of 16 | each image. 17 | cfg (dict): Config dict that specifies the mask size. 18 | 19 | Returns: 20 | Tensor: Concatenated mask targets of all images. 21 | """ 22 | cfg_list = [cfg for _ in range(len(pos_proposals_list))] 23 | mask_targets = map(mask_target_single, pos_proposals_list, 24 | pos_assigned_gt_inds_list, gt_masks_list, cfg_list) 25 | mask_targets = list(mask_targets) 26 | if len(mask_targets) > 0: 27 | mask_targets = torch.cat(mask_targets) 28 | return mask_targets 29 | 30 | 31 | def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg): 32 | """Compute mask target for each positive proposal in the image. 33 | 34 | Args: 35 | pos_proposals (Tensor): Positive proposals. 36 | pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals. 37 | gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap 38 | or Polygon. 39 | cfg (dict): Config dict that indicates the mask size. 40 | 41 | Returns: 42 | Tensor: Mask target of each positive proposal in the image.
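Example: A minimal sketch with random bitmap masks (``BitmapMasks`` and ``mmcv.Config`` stand in for real pipeline outputs; all values below are arbitrary): >>> import mmcv >>> import numpy as np >>> import torch >>> from mmdet.core.mask import BitmapMasks >>> cfg = mmcv.Config({'mask_size': (7, 11)}) >>> gt_masks = BitmapMasks(np.random.rand(2, 32, 32), 32, 32) >>> pos_proposals = torch.FloatTensor([[0, 0, 16, 16], [4, 2, 20, 30]]) >>> pos_assigned_gt_inds = torch.LongTensor([0, 1]) >>> targets = mask_target_single( ... pos_proposals, pos_assigned_gt_inds, gt_masks, cfg) >>> assert tuple(targets.shape) == (2, 7, 11)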
43 | """ 44 | device = pos_proposals.device 45 | mask_size = _pair(cfg.mask_size) 46 | num_pos = pos_proposals.size(0) 47 | if num_pos > 0: 48 | proposals_np = pos_proposals.cpu().numpy() 49 | maxh, maxw = gt_masks.height, gt_masks.width 50 | proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw) 51 | proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh) 52 | pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() 53 | 54 | mask_targets = gt_masks.crop_and_resize( 55 | proposals_np, mask_size, device=device, 56 | inds=pos_assigned_gt_inds).to_ndarray() 57 | 58 | mask_targets = torch.from_numpy(mask_targets).float().to(device) 59 | else: 60 | mask_targets = pos_proposals.new_zeros((0, ) + mask_size) 61 | 62 | return mask_targets 63 | -------------------------------------------------------------------------------- /mmdet/core/mask/utils.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | import pycocotools.mask as mask_util 4 | 5 | 6 | def split_combined_polys(polys, poly_lens, polys_per_mask): 7 | """Split the combined 1-D polys into masks. 8 | 9 | A mask is represented as a list of polys, and a poly is represented as 10 | a 1-D array. In the dataset, all masks are concatenated into a single 1-D 11 | tensor. Here we need to split the tensor into original representations. 12 | 13 | Args: 14 | polys (list): a list (length = image num) of 1-D tensors 15 | poly_lens (list): a list (length = image num) of poly length 16 | polys_per_mask (list): a list (length = image num) of poly number 17 | of each mask 18 | 19 | Returns: 20 | list: a list (length = image num) of list (length = mask num) of \ 21 | list (length = poly num) of numpy array. 22 | """ 23 | mask_polys_list = [] 24 | for img_id in range(len(polys)): 25 | polys_single = polys[img_id] 26 | polys_lens_single = poly_lens[img_id].tolist() 27 | polys_per_mask_single = polys_per_mask[img_id].tolist() 28 | 29 | split_polys = mmcv.slice_list(polys_single, polys_lens_single) 30 | mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single) 31 | mask_polys_list.append(mask_polys) 32 | return mask_polys_list 33 | 34 | 35 | # TODO: move this function to a more proper place 36 | def encode_mask_results(mask_results): 37 | """Encode bitmap mask to RLE code. 38 | 39 | Args: 40 | mask_results (list | tuple[list]): bitmap mask results. 41 | In mask scoring rcnn, mask_results is a tuple of (segm_results, 42 | segm_cls_score). 43 | 44 | Returns: 45 | list | tuple: RLE encoded mask.
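Example: A minimal sketch with a single class and one binary mask (shapes are arbitrary): >>> import numpy as np >>> mask = np.zeros((10, 10), dtype=np.uint8) >>> mask[2:5, 3:7] = 1 >>> encoded = encode_mask_results([[mask]]) >>> sorted(encoded[0][0].keys()) ['counts', 'size']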
46 | """ 47 | if isinstance(mask_results, tuple): # mask scoring 48 | cls_segms, cls_mask_scores = mask_results 49 | else: 50 | cls_segms = mask_results 51 | num_classes = len(cls_segms) 52 | encoded_mask_results = [[] for _ in range(num_classes)] 53 | for i in range(len(cls_segms)): 54 | for cls_segm in cls_segms[i]: 55 | encoded_mask_results[i].append( 56 | mask_util.encode( 57 | np.array( 58 | cls_segm[:, :, np.newaxis], order='F', 59 | dtype='uint8'))[0]) # encoded with RLE 60 | if isinstance(mask_results, tuple): 61 | return encoded_mask_results, cls_mask_scores 62 | else: 63 | return encoded_mask_results 64 | -------------------------------------------------------------------------------- /mmdet/core/post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_nms import fast_nms, multiclass_nms, WeaklyMulticlassNMS 2 | from .merge_augs import (merge_aug_bboxes, merge_aug_masks, 3 | merge_aug_proposals, merge_aug_scores) 4 | 5 | __all__ = [ 6 | 'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes', 7 | 'merge_aug_scores', 'merge_aug_masks', 'fast_nms', 'WeaklyMulticlassNMS' 8 | ] 9 | -------------------------------------------------------------------------------- /mmdet/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist_utils import DistOptimizerHook, allreduce_grads, reduce_mean 2 | from .misc import multi_apply, unmap 3 | 4 | __all__ = [ 5 | 'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply', 6 | 'unmap' 7 | ] 8 | -------------------------------------------------------------------------------- /mmdet/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | 4 | import torch.distributed as dist 5 | from mmcv.runner import OptimizerHook 6 | from torch._utils import (_flatten_dense_tensors, _take_tensors, 7 | _unflatten_dense_tensors) 8 | 9 | 10 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 11 | if bucket_size_mb > 0: 12 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 13 | buckets = _take_tensors(tensors, bucket_size_bytes) 14 | else: 15 | buckets = OrderedDict() 16 | for tensor in tensors: 17 | tp = tensor.type() 18 | if tp not in buckets: 19 | buckets[tp] = [] 20 | buckets[tp].append(tensor) 21 | buckets = buckets.values() 22 | 23 | for bucket in buckets: 24 | flat_tensors = _flatten_dense_tensors(bucket) 25 | dist.all_reduce(flat_tensors) 26 | flat_tensors.div_(world_size) 27 | for tensor, synced in zip( 28 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)): 29 | tensor.copy_(synced) 30 | 31 | 32 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): 33 | """Allreduce gradients. 34 | 35 | Args: 36 | params (list[torch.nn.Parameter]): List of parameters of a model. 37 | coalesce (bool, optional): Whether to allreduce parameters as a whole. 38 | Defaults to True. 39 | bucket_size_mb (int, optional): Size of bucket, the unit is MB. 40 | Defaults to -1.
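Example: A sketch of typical usage inside a distributed training step (requires an initialized process group, so it is not runnable standalone): >>> loss.backward() # doctest: +SKIP >>> allreduce_grads(list(model.parameters())) # doctest: +SKIP >>> optimizer.step() # doctest: +SKIP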
41 | """ 42 | grads = [ 43 | param.grad.data for param in params 44 | if param.requires_grad and param.grad is not None 45 | ] 46 | world_size = dist.get_world_size() 47 | if coalesce: 48 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 49 | else: 50 | for tensor in grads: 51 | dist.all_reduce(tensor.div_(world_size)) 52 | 53 | 54 | class DistOptimizerHook(OptimizerHook): 55 | """Deprecated optimizer hook for distributed training.""" 56 | 57 | def __init__(self, *args, **kwargs): 58 | warnings.warn('"DistOptimizerHook" is deprecated, please switch to ' 59 | '"mmcv.runner.OptimizerHook".') 60 | super().__init__(*args, **kwargs) 61 | 62 | 63 | def reduce_mean(tensor): 64 | """Obtain the mean of a tensor across different GPUs.""" 65 | if not (dist.is_available() and dist.is_initialized()): 66 | return tensor 67 | tensor = tensor.clone() 68 | dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) 69 | return tensor 70 | -------------------------------------------------------------------------------- /mmdet/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from six.moves import map, zip 5 | 6 | 7 | def multi_apply(func, *args, **kwargs): 8 | """Apply function to a list of arguments. 9 | 10 | Note: 11 | This function applies the ``func`` to multiple inputs and 12 | maps the multiple outputs of the ``func`` into different 13 | lists. Each list contains the same type of outputs corresponding 14 | to different inputs. 15 | 16 | Args: 17 | func (Function): A function that will be applied to a list of 18 | arguments. 19 | 20 | Returns: 21 | tuple(list): A tuple containing multiple lists, each of which \ 22 | contains one kind of the results returned by the function. 23 | """ 24 | pfunc = partial(func, **kwargs) if kwargs else func 25 | map_results = map(pfunc, *args) 26 | return tuple(map(list, zip(*map_results))) 27 | 28 | 29 | def unmap(data, count, inds, fill=0): 30 | """Unmap a subset of items (data) back to the original set of items (of 31 | size count).""" 32 | if data.dim() == 1: 33 | ret = data.new_full((count, ), fill) 34 | ret[inds.type(torch.bool)] = data 35 | else: 36 | new_size = (count, ) + data.size()[1:] 37 | ret = data.new_full(new_size, fill) 38 | ret[inds.type(torch.bool), :] = data 39 | return ret 40 | -------------------------------------------------------------------------------- /mmdet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 2 | from .cityscapes import CityscapesDataset 3 | from .coco import CocoDataset 4 | from .custom import CustomDataset 5 | from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, 6 | RepeatDataset) 7 | from .deepfashion import DeepFashionDataset 8 | from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset 9 | from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler 10 | from .utils import replace_ImageToTensor 11 | from .voc import VOCDataset 12 | from .wider_face import WIDERFaceDataset 13 | from .xml_style import XMLDataset 14 | 15 | from .voc_ss import VOCSSDataset 16 | 17 | __all__ = [ 18 | 'CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset', 19 | 'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', 20 | 'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler', 21 | 'DistributedSampler', 'build_dataloader', 'ConcatDataset',
'RepeatDataset', 22 | 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES', 23 | 'build_dataset', 'replace_ImageToTensor', 24 | 'VOCSSDataset' 25 | ] 26 | -------------------------------------------------------------------------------- /mmdet/datasets/deepfashion.py: -------------------------------------------------------------------------------- 1 | from .builder import DATASETS 2 | from .coco import CocoDataset 3 | 4 | 5 | @DATASETS.register_module() 6 | class DeepFashionDataset(CocoDataset): 7 | 8 | CLASSES = ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants', 'bag', 9 | 'neckwear', 'headwear', 'eyeglass', 'belt', 'footwear', 'hair', 10 | 'skin', 'face') 11 | -------------------------------------------------------------------------------- /mmdet/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform, 2 | ContrastTransform, EqualizeTransform, Rotate, Shear, 3 | Translate) 4 | from .compose import Compose 5 | from .formating import (Collect, DefaultFormatBundle, ImageToTensor, 6 | ToDataContainer, ToTensor, Transpose, to_tensor) 7 | from .instaboost import InstaBoost 8 | from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam, 9 | LoadMultiChannelImageFromFiles, LoadProposals) 10 | from .test_time_aug import MultiScaleFlipAug 11 | from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize, 12 | Pad, PhotoMetricDistortion, RandomCenterCropPad, 13 | RandomCrop, RandomFlip, Resize, SegRescale) 14 | 15 | __all__ = [ 16 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 17 | 'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations', 18 | 'LoadImageFromFile', 'LoadImageFromWebcam', 19 | 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'MultiScaleFlipAug', 20 | 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 21 | 'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu', 22 | 'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear', 23 | 'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform', 24 | 'ContrastTransform', 'Translate' 25 | ] 26 | -------------------------------------------------------------------------------- /mmdet/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | from mmcv.utils import build_from_cfg 4 | 5 | from ..builder import PIPELINES 6 | 7 | 8 | @PIPELINES.register_module() 9 | class Compose(object): 10 | """Compose multiple transforms sequentially. 11 | 12 | Args: 13 | transforms (Sequence[dict | callable]): Sequence of transform objects or 14 | config dicts to be composed. 15 | """ 16 | 17 | def __init__(self, transforms): 18 | assert isinstance(transforms, collections.abc.Sequence) 19 | self.transforms = [] 20 | for transform in transforms: 21 | if isinstance(transform, dict): 22 | transform = build_from_cfg(transform, PIPELINES) 23 | self.transforms.append(transform) 24 | elif callable(transform): 25 | self.transforms.append(transform) 26 | else: 27 | raise TypeError('transform must be callable or a dict') 28 | 29 | def __call__(self, data): 30 | """Call function to apply transforms sequentially. 31 | 32 | Args: 33 | data (dict): A result dict containing the data to transform. 34 | 35 | Returns: 36 | dict: Transformed data.
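Example: A minimal sketch using a plain callable (registered transforms would normally be given as config dicts instead): >>> pipeline = Compose([lambda results: dict(results, flag=True)]) >>> pipeline({'img': None}) {'img': None, 'flag': True}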
37 | """ 38 | 39 | for t in self.transforms: 40 | data = t(data) 41 | if data is None: 42 | return None 43 | return data 44 | 45 | def __repr__(self): 46 | format_string = self.__class__.__name__ + '(' 47 | for t in self.transforms: 48 | format_string += '\n' 49 | format_string += f' {t}' 50 | format_string += '\n)' 51 | return format_string 52 | -------------------------------------------------------------------------------- /mmdet/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_sampler import DistributedSampler 2 | from .group_sampler import DistributedGroupSampler, GroupSampler 3 | 4 | __all__ = ['DistributedSampler', 'DistributedGroupSampler', 'GroupSampler'] 5 | -------------------------------------------------------------------------------- /mmdet/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DistributedSampler as _DistributedSampler 3 | 4 | 5 | class DistributedSampler(_DistributedSampler): 6 | 7 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 8 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 9 | self.shuffle = shuffle 10 | 11 | def __iter__(self): 12 | # deterministically shuffle based on epoch 13 | if self.shuffle: 14 | g = torch.Generator() 15 | g.manual_seed(self.epoch) 16 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 17 | else: 18 | indices = torch.arange(len(self.dataset)).tolist() 19 | 20 | # add extra samples to make it evenly divisible 21 | indices += indices[:(self.total_size - len(indices))] 22 | assert len(indices) == self.total_size 23 | 24 | # subsample 25 | indices = indices[self.rank:self.total_size:self.num_replicas] 26 | assert len(indices) == self.num_samples 27 | 28 | return iter(indices) 29 | -------------------------------------------------------------------------------- /mmdet/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import warnings 3 | 4 | 5 | def replace_ImageToTensor(pipelines): 6 | """Replace the ImageToTensor transform in a data pipeline with 7 | DefaultFormatBundle, which is normally useful in batch inference. 8 | 9 | Args: 10 | pipelines (list[dict]): Data pipeline configs. 11 | 12 | Returns: 13 | list: The new pipeline list with all ImageToTensor replaced by 14 | DefaultFormatBundle. 15 | 16 | Examples: 17 | >>> pipelines = [ 18 | ... dict(type='LoadImageFromFile'), 19 | ... dict( 20 | ... type='MultiScaleFlipAug', 21 | ... img_scale=(1333, 800), 22 | ... flip=False, 23 | ... transforms=[ 24 | ... dict(type='Resize', keep_ratio=True), 25 | ... dict(type='RandomFlip'), 26 | ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]), 27 | ... dict(type='Pad', size_divisor=32), 28 | ... dict(type='ImageToTensor', keys=['img']), 29 | ... dict(type='Collect', keys=['img']), 30 | ... ]) 31 | ... ] 32 | >>> expected_pipelines = [ 33 | ... dict(type='LoadImageFromFile'), 34 | ... dict( 35 | ... type='MultiScaleFlipAug', 36 | ... img_scale=(1333, 800), 37 | ... flip=False, 38 | ... transforms=[ 39 | ... dict(type='Resize', keep_ratio=True), 40 | ... dict(type='RandomFlip'), 41 | ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]), 42 | ... dict(type='Pad', size_divisor=32), 43 | ... dict(type='DefaultFormatBundle'), 44 | ... dict(type='Collect', keys=['img']), 45 | ... ]) 46 | ...
] 47 | >>> assert expected_pipelines == replace_ImageToTensor(pipelines) 48 | """ 49 | pipelines = copy.deepcopy(pipelines) 50 | for i, pipeline in enumerate(pipelines): 51 | if pipeline['type'] == 'MultiScaleFlipAug': 52 | assert 'transforms' in pipeline 53 | pipeline['transforms'] = replace_ImageToTensor( 54 | pipeline['transforms']) 55 | elif pipeline['type'] == 'ImageToTensor': 56 | warnings.warn( 57 | '"ImageToTensor" pipeline is replaced by ' 58 | '"DefaultFormatBundle" for batch inference. It is ' 59 | 'recommended to manually replace it in the test ' 60 | 'data pipeline in your config file.', UserWarning) 61 | pipelines[i] = {'type': 'DefaultFormatBundle'} 62 | return pipelines 63 | -------------------------------------------------------------------------------- /mmdet/datasets/voc_ss.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import xml.etree.ElementTree as ET 3 | 4 | import mmcv 5 | from PIL import Image 6 | 7 | from .builder import DATASETS 8 | from .voc import VOCDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class VOCSSDataset(VOCDataset): 13 | """VOC dataset that additionally carries a superpixel map per image.""" 14 | 15 | def __init__(self, **kwargs): 16 | super(VOCSSDataset, self).__init__(**kwargs) 17 | 18 | def load_annotations(self, ann_file): 19 | data_infos = [] 20 | img_ids = mmcv.list_from_file(ann_file) 21 | for img_id in img_ids: 22 | filename = f'JPEGImages/{img_id}.jpg' 23 | ssname = f'SuperPixels/{img_id}.jpg' 24 | xml_path = osp.join(self.img_prefix, 'Annotations', 25 | f'{img_id}.xml') 26 | tree = ET.parse(xml_path) 27 | root = tree.getroot() 28 | size = root.find('size') 29 | width = 0 30 | height = 0 31 | if size is not None: 32 | width = int(size.find('width').text) 33 | height = int(size.find('height').text) 34 | else: 35 | img_path = osp.join(self.img_prefix, 'JPEGImages', 36 | '{}.jpg'.format(img_id)) 37 | img = Image.open(img_path) 38 | width, height = img.size 39 | data_infos.append( 40 | dict(id=img_id, filename=filename, width=width, height=height, ssname=ssname)) 41 | 42 | return data_infos 43 | 44 | -------------------------------------------------------------------------------- /mmdet/datasets/wider_face.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import xml.etree.ElementTree as ET 3 | 4 | import mmcv 5 | 6 | from .builder import DATASETS 7 | from .xml_style import XMLDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class WIDERFaceDataset(XMLDataset): 12 | """Reader for the WIDER Face dataset in PASCAL VOC format. 13 | 14 | Conversion scripts can be found in 15 | https://github.com/sovrasov/wider-face-pascal-voc-annotations 16 | """ 17 | CLASSES = ('face', ) 18 | 19 | def __init__(self, **kwargs): 20 | super(WIDERFaceDataset, self).__init__(**kwargs) 21 | 22 | def load_annotations(self, ann_file): 23 | """Load annotation from WIDERFace XML style annotation file. 24 | 25 | Args: 26 | ann_file (str): Path of XML file. 27 | 28 | Returns: 29 | list[dict]: Annotation info from XML file.
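Example: A sketch of typical construction (the paths are illustrative and must point to the converted annotations on disk): >>> dataset = WIDERFaceDataset( # doctest: +SKIP ... ann_file='data/WIDERFace/train.txt', ... img_prefix='data/WIDERFace/WIDER_train', ... pipeline=[])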
30 | """ 31 | 32 | data_infos = [] 33 | img_ids = mmcv.list_from_file(ann_file) 34 | for img_id in img_ids: 35 | filename = f'{img_id}.jpg' 36 | xml_path = osp.join(self.img_prefix, 'Annotations', 37 | f'{img_id}.xml') 38 | tree = ET.parse(xml_path) 39 | root = tree.getroot() 40 | size = root.find('size') 41 | width = int(size.find('width').text) 42 | height = int(size.find('height').text) 43 | folder = root.find('folder').text 44 | data_infos.append( 45 | dict( 46 | id=img_id, 47 | filename=osp.join(folder, filename), 48 | width=width, 49 | height=height)) 50 | 51 | return data_infos 52 | -------------------------------------------------------------------------------- /mmdet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401,F403 2 | from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS, 3 | ROI_EXTRACTORS, SHARED_HEADS, build_backbone, 4 | build_detector, build_head, build_loss, build_neck, 5 | build_roi_extractor, build_shared_head) 6 | from .dense_heads import * # noqa: F401,F403 7 | from .detectors import * # noqa: F401,F403 8 | from .losses import * # noqa: F401,F403 9 | from .necks import * # noqa: F401,F403 10 | from .roi_heads import * # noqa: F401,F403 11 | 12 | __all__ = [ 13 | 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES', 14 | 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor', 15 | 'build_shared_head', 'build_head', 'build_loss', 'build_detector' 16 | ] 17 | -------------------------------------------------------------------------------- /mmdet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .darknet import Darknet 2 | from .detectors_resnet import DetectoRS_ResNet 3 | from .detectors_resnext import DetectoRS_ResNeXt 4 | from .hourglass import HourglassNet 5 | from .hrnet import HRNet 6 | from .regnet import RegNet 7 | from .res2net import Res2Net 8 | from .resnest import ResNeSt 9 | from .resnet import ResNet, ResNetV1d 10 | from .resnext import ResNeXt 11 | from .ssd_vgg import SSDVGG 12 | from .vgg import VGG16 13 | 14 | __all__ = [ 15 | 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', 16 | 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet', 17 | 'ResNeSt', 'VGG16' 18 | ] 19 | -------------------------------------------------------------------------------- /mmdet/models/backbones/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..builder import BACKBONES 5 | 6 | @BACKBONES.register_module() 7 | class VGG16(nn.Module): 8 | def __init__(self): 9 | super(VGG16, self).__init__() 10 | 11 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True) 12 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True) 13 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) 14 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True) 15 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True) 16 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) 17 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True) 18 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 19 | self.conv3_3 =
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 20 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) 21 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True) 22 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True) 23 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True) 24 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=True) 25 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=True) 26 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=True) 27 | 28 | # Freeze the first two conv blocks; requires_grad must be set on the parameters, since setting it on an nn.Module has no effect. 29 | for layer in (self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2): 30 | for param in layer.parameters(): 31 | param.requires_grad = False 32 | 33 | def forward(self, x): 34 | x = F.relu(self.conv1_1(x), inplace=True) 35 | x = F.relu(self.conv1_2(x), inplace=True) 36 | x = self.pool1(x) 37 | x = F.relu(self.conv2_1(x), inplace=True) 38 | x = F.relu(self.conv2_2(x), inplace=True) 39 | x = self.pool2(x) 40 | x = F.relu(self.conv3_1(x), inplace=True) 41 | x = F.relu(self.conv3_2(x), inplace=True) 42 | x = F.relu(self.conv3_3(x), inplace=True) 43 | x = self.pool3(x) 44 | x = F.relu(self.conv4_1(x), inplace=True) 45 | x = F.relu(self.conv4_2(x), inplace=True) 46 | x = F.relu(self.conv4_3(x), inplace=True) 47 | x = F.relu(self.conv5_1(x), inplace=True) 48 | x = F.relu(self.conv5_2(x), inplace=True) 49 | x = F.relu(self.conv5_3(x), inplace=True) 50 | return [x] 51 | 52 | def init_weights(self, pretrained=None): 53 | pass # no-op: pretrained weights, if any, are loaded elsewhere 54 | -------------------------------------------------------------------------------- /mmdet/models/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | from torch import nn 3 | 4 | BACKBONES = Registry('backbone') 5 | NECKS = Registry('neck') 6 | ROI_EXTRACTORS = Registry('roi_extractor') 7 | SHARED_HEADS = Registry('shared_head') 8 | HEADS = Registry('head') 9 | LOSSES = Registry('loss') 10 | DETECTORS = Registry('detector') 11 | 12 | 13 | def build(cfg, registry, default_args=None): 14 | """Build a module. 15 | 16 | Args: 17 | cfg (dict, list[dict]): The config of modules; it is either a dict 18 | or a list of configs. 19 | registry (:obj:`Registry`): A registry the module belongs to. 20 | default_args (dict, optional): Default arguments to build the module. 21 | Defaults to None. 22 | 23 | Returns: 24 | nn.Module: A built nn module.
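Example: A minimal sketch using this repo's registered VGG16 backbone (any registered module with a matching cfg works the same way): >>> cfg = dict(type='VGG16') >>> model = build(cfg, BACKBONES) >>> isinstance(model, nn.Module) True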
25 | """ 26 | if isinstance(cfg, list): 27 | modules = [ 28 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 29 | ] 30 | return nn.Sequential(*modules) 31 | else: 32 | return build_from_cfg(cfg, registry, default_args) 33 | 34 | 35 | def build_backbone(cfg): 36 | """Build backbone.""" 37 | return build(cfg, BACKBONES) 38 | 39 | 40 | def build_neck(cfg): 41 | """Build neck.""" 42 | return build(cfg, NECKS) 43 | 44 | 45 | def build_roi_extractor(cfg): 46 | """Build roi extractor.""" 47 | return build(cfg, ROI_EXTRACTORS) 48 | 49 | 50 | def build_shared_head(cfg): 51 | """Build shared head.""" 52 | return build(cfg, SHARED_HEADS) 53 | 54 | 55 | def build_head(cfg): 56 | """Build head.""" 57 | return build(cfg, HEADS) 58 | 59 | 60 | def build_loss(cfg): 61 | """Build loss.""" 62 | return build(cfg, LOSSES) 63 | 64 | 65 | def build_detector(cfg, train_cfg=None, test_cfg=None): 66 | """Build detector.""" 67 | return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 68 | -------------------------------------------------------------------------------- /mmdet/models/dense_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_free_head import AnchorFreeHead 2 | from .anchor_head import AnchorHead 3 | from .atss_head import ATSSHead 4 | from .centripetal_head import CentripetalHead 5 | from .corner_head import CornerHead 6 | from .fcos_head import FCOSHead 7 | from .fovea_head import FoveaHead 8 | from .free_anchor_retina_head import FreeAnchorRetinaHead 9 | from .fsaf_head import FSAFHead 10 | from .ga_retina_head import GARetinaHead 11 | from .ga_rpn_head import GARPNHead 12 | from .gfl_head import GFLHead 13 | from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead 14 | from .nasfcos_head import NASFCOSHead 15 | from .paa_head import PAAHead 16 | from .pisa_retinanet_head import PISARetinaHead 17 | from .pisa_ssd_head import PISASSDHead 18 | from .reppoints_head import RepPointsHead 19 | from .retina_head import RetinaHead 20 | from .retina_sepbn_head import RetinaSepBNHead 21 | from .rpn_head import RPNHead 22 | from .sabl_retina_head import SABLRetinaHead 23 | from .ssd_head import SSDHead 24 | from .vfnet_head import VFNetHead 25 | from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead 26 | from .yolo_head import YOLOV3Head 27 | 28 | __all__ = [ 29 | 'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 30 | 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 31 | 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', 32 | 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', 33 | 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead', 34 | 'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', 35 | 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead' 36 | ] 37 | -------------------------------------------------------------------------------- /mmdet/models/dense_heads/base_dense_head.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class BaseDenseHead(nn.Module, metaclass=ABCMeta): 7 | """Base class for DenseHeads.""" 8 | 9 | def __init__(self): 10 | super(BaseDenseHead, self).__init__() 11 | 12 | @abstractmethod 13 | def loss(self, **kwargs): 14 | """Compute losses of the head.""" 15 | pass 16 | 17 | @abstractmethod 18 | def get_bboxes(self, **kwargs): 19 | """Transform network output for a batch 
into bbox predictions.""" 20 | pass 21 | 22 | def forward_train(self, 23 | x, 24 | img_metas, 25 | gt_bboxes, 26 | gt_labels=None, 27 | gt_bboxes_ignore=None, 28 | proposal_cfg=None, 29 | **kwargs): 30 | """ 31 | Args: 32 | x (list[Tensor]): Features from FPN. 33 | img_metas (list[dict]): Meta information of each image, e.g., 34 | image size, scaling factor, etc. 35 | gt_bboxes (Tensor): Ground truth bboxes of the image, 36 | shape (num_gts, 4). 37 | gt_labels (Tensor): Ground truth labels of each box, 38 | shape (num_gts,). 39 | gt_bboxes_ignore (Tensor): Ground truth bboxes to be 40 | ignored, shape (num_ignored_gts, 4). 41 | proposal_cfg (mmcv.Config): Test / postprocessing configuration, 42 | if None, test_cfg would be used 43 | 44 | Returns: 45 | tuple: 46 | losses: (dict[str, Tensor]): A dictionary of loss components. 47 | proposal_list (list[Tensor]): Proposals of each image. 48 | """ 49 | outs = self(x) 50 | if gt_labels is None: 51 | loss_inputs = outs + (gt_bboxes, img_metas) 52 | else: 53 | loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) 54 | losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) 55 | if proposal_cfg is None: 56 | return losses 57 | else: 58 | proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) 59 | return losses, proposal_list 60 | -------------------------------------------------------------------------------- /mmdet/models/dense_heads/nasfcos_head.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch.nn as nn 4 | from mmcv.cnn import (ConvModule, Scale, bias_init_with_prob, 5 | caffe2_xavier_init, normal_init) 6 | 7 | from mmdet.models.dense_heads.fcos_head import FCOSHead 8 | from ..builder import HEADS 9 | 10 | 11 | @HEADS.register_module() 12 | class NASFCOSHead(FCOSHead): 13 | """Anchor-free head used in `NASFCOS `_. 14 | 15 | It is quite similar with FCOS head, except for the searched structure of 16 | classification branch and bbox regression branch, where a structure of 17 | "dconv3x3, conv3x3, dconv3x3, conv1x1" is utilized instead. 
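Here "dconv" denotes a deformable convolution (DCNv2 in this implementation), and the same searched structure is instantiated separately for the two branches in ``_init_layers`` below.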
18 | """ 19 | 20 | def _init_layers(self): 21 | """Initialize layers of the head.""" 22 | dconv3x3_config = dict( 23 | type='DCNv2', 24 | kernel_size=3, 25 | use_bias=True, 26 | deform_groups=2, 27 | padding=1) 28 | conv3x3_config = dict(type='Conv', kernel_size=3, padding=1) 29 | conv1x1_config = dict(type='Conv', kernel_size=1) 30 | 31 | self.arch_config = [ 32 | dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config 33 | ] 34 | self.cls_convs = nn.ModuleList() 35 | self.reg_convs = nn.ModuleList() 36 | for i, op_ in enumerate(self.arch_config): 37 | op = copy.deepcopy(op_) 38 | chn = self.in_channels if i == 0 else self.feat_channels 39 | assert isinstance(op, dict) 40 | use_bias = op.pop('use_bias', False) 41 | padding = op.pop('padding', 0) 42 | kernel_size = op.pop('kernel_size') 43 | module = ConvModule( 44 | chn, 45 | self.feat_channels, 46 | kernel_size, 47 | stride=1, 48 | padding=padding, 49 | norm_cfg=self.norm_cfg, 50 | bias=use_bias, 51 | conv_cfg=op) 52 | 53 | self.cls_convs.append(copy.deepcopy(module)) 54 | self.reg_convs.append(copy.deepcopy(module)) 55 | 56 | self.conv_cls = nn.Conv2d( 57 | self.feat_channels, self.cls_out_channels, 3, padding=1) 58 | self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) 59 | self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1) 60 | 61 | self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) 62 | 63 | def init_weights(self): 64 | """Initialize weights of the head.""" 65 | # retinanet_bias_init 66 | bias_cls = bias_init_with_prob(0.01) 67 | normal_init(self.conv_reg, std=0.01) 68 | normal_init(self.conv_centerness, std=0.01) 69 | normal_init(self.conv_cls, std=0.01, bias=bias_cls) 70 | 71 | for branch in [self.cls_convs, self.reg_convs]: 72 | for module in branch.modules(): 73 | if isinstance(module, ConvModule) \ 74 | and isinstance(module.conv, nn.Conv2d): 75 | caffe2_xavier_init(module.conv) 76 | -------------------------------------------------------------------------------- /mmdet/models/dense_heads/rpn_test_mixin.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from mmdet.core import merge_aug_proposals 4 | 5 | if sys.version_info >= (3, 7): 6 | from mmdet.utils.contextmanagers import completed 7 | 8 | 9 | class RPNTestMixin(object): 10 | """Test methods of RPN.""" 11 | 12 | if sys.version_info >= (3, 7): 13 | 14 | async def async_simple_test_rpn(self, x, img_metas): 15 | sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025) 16 | async with completed( 17 | __name__, 'rpn_head_forward', 18 | sleep_interval=sleep_interval): 19 | rpn_outs = self(x) 20 | 21 | proposal_list = self.get_bboxes(*rpn_outs, img_metas) 22 | return proposal_list 23 | 24 | def simple_test_rpn(self, x, img_metas): 25 | """Test without augmentation. 26 | 27 | Args: 28 | x (tuple[Tensor]): Features from the upstream network, each is 29 | a 4D-tensor. 30 | img_metas (list[dict]): Meta info of each image. 31 | 32 | Returns: 33 | list[Tensor]: Proposals of each image. 
34 | """ 35 | rpn_outs = self(x) 36 | proposal_list = self.get_bboxes(*rpn_outs, img_metas) 37 | return proposal_list 38 | 39 | def aug_test_rpn(self, feats, img_metas): 40 | samples_per_gpu = len(img_metas[0]) 41 | aug_proposals = [[] for _ in range(samples_per_gpu)] 42 | for x, img_meta in zip(feats, img_metas): 43 | proposal_list = self.simple_test_rpn(x, img_meta) 44 | for i, proposals in enumerate(proposal_list): 45 | aug_proposals[i].append(proposals) 46 | # reorganize the order of 'img_metas' to match the dimensions 47 | # of 'aug_proposals' 48 | aug_img_metas = [] 49 | for i in range(samples_per_gpu): 50 | aug_img_meta = [] 51 | for j in range(len(img_metas)): 52 | aug_img_meta.append(img_metas[j][i]) 53 | aug_img_metas.append(aug_img_meta) 54 | # after merging, proposals will be rescaled to the original image size 55 | merged_proposals = [ 56 | merge_aug_proposals(proposals, aug_img_meta, self.test_cfg) 57 | for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas) 58 | ] 59 | return merged_proposals 60 | -------------------------------------------------------------------------------- /mmdet/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .atss import ATSS 2 | from .base import BaseDetector 3 | from .cascade_rcnn import CascadeRCNN 4 | from .cornernet import CornerNet 5 | from .fast_rcnn import FastRCNN 6 | from .faster_rcnn import FasterRCNN 7 | from .fcos import FCOS 8 | from .fovea import FOVEA 9 | from .fsaf import FSAF 10 | from .gfl import GFL 11 | from .grid_rcnn import GridRCNN 12 | from .htc import HybridTaskCascade 13 | from .mask_rcnn import MaskRCNN 14 | from .mask_scoring_rcnn import MaskScoringRCNN 15 | from .nasfcos import NASFCOS 16 | from .paa import PAA 17 | from .point_rend import PointRend 18 | from .reppoints_detector import RepPointsDetector 19 | from .retinanet import RetinaNet 20 | from .rpn import RPN 21 | from .single_stage import SingleStageDetector 22 | from .two_stage import TwoStageDetector 23 | from .vfnet import VFNet 24 | from .yolact import YOLACT 25 | from .yolo import YOLOV3 26 | 27 | from .weak_rcnn import WeakRCNN 28 | 29 | __all__ = [ 30 | 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 31 | 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 32 | 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', 33 | 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 34 | 'YOLOV3', 'YOLACT', 'VFNet', 35 | 'WeakRCNN' 36 | ] 37 | -------------------------------------------------------------------------------- /mmdet/models/detectors/atss.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class ATSS(SingleStageDetector): 7 | """Implementation of `ATSS `_.""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | bbox_head, 13 | train_cfg=None, 14 | test_cfg=None, 15 | pretrained=None): 16 | super(ATSS, self).__init__(backbone, neck, bbox_head, train_cfg, 17 | test_cfg, pretrained) 18 | -------------------------------------------------------------------------------- /mmdet/models/detectors/cascade_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class 
CascadeRCNN(TwoStageDetector): 7 | r"""Implementation of `Cascade R-CNN: Delving into High Quality Object 8 | Detection `_""" 9 | 10 | def __init__(self, 11 | backbone, 12 | neck=None, 13 | rpn_head=None, 14 | roi_head=None, 15 | train_cfg=None, 16 | test_cfg=None, 17 | pretrained=None): 18 | super(CascadeRCNN, self).__init__( 19 | backbone=backbone, 20 | neck=neck, 21 | rpn_head=rpn_head, 22 | roi_head=roi_head, 23 | train_cfg=train_cfg, 24 | test_cfg=test_cfg, 25 | pretrained=pretrained) 26 | 27 | def show_result(self, data, result, **kwargs): 28 | """Show prediction results of the detector.""" 29 | if self.with_mask: 30 | ms_bbox_result, ms_segm_result = result 31 | if isinstance(ms_bbox_result, dict): 32 | result = (ms_bbox_result['ensemble'], 33 | ms_segm_result['ensemble']) 34 | else: 35 | if isinstance(result, dict): 36 | result = result['ensemble'] 37 | return super(CascadeRCNN, self).show_result(data, result, **kwargs) 38 | -------------------------------------------------------------------------------- /mmdet/models/detectors/fast_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class FastRCNN(TwoStageDetector): 7 | """Implementation of `Fast R-CNN `_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | roi_head, 12 | train_cfg, 13 | test_cfg, 14 | neck=None, 15 | pretrained=None): 16 | super(FastRCNN, self).__init__( 17 | backbone=backbone, 18 | neck=neck, 19 | roi_head=roi_head, 20 | train_cfg=train_cfg, 21 | test_cfg=test_cfg, 22 | pretrained=pretrained) 23 | 24 | def forward_test(self, imgs, img_metas, proposals, **kwargs): 25 | """ 26 | Args: 27 | imgs (List[Tensor]): the outer list indicates test-time 28 | augmentations and inner Tensor should have a shape NxCxHxW, 29 | which contains all images in the batch. 30 | img_metas (List[List[dict]]): the outer list indicates test-time 31 | augs (multiscale, flip, etc.) and the inner list indicates 32 | images in a batch. 33 | proposals (List[List[Tensor]]): the outer list indicates test-time 34 | augs (multiscale, flip, etc.) and the inner list indicates 35 | images in a batch. The Tensor should have a shape Px4, where 36 | P is the number of proposals. 
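Note that Fast R-CNN consumes pre-computed proposals instead of generating them with an RPN, which is why ``proposals`` is a required argument here.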
37 | """ 38 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: 39 | if not isinstance(var, list): 40 | raise TypeError(f'{name} must be a list, but got {type(var)}') 41 | 42 | num_augs = len(imgs) 43 | if num_augs != len(img_metas): 44 | raise ValueError(f'num of augmentations ({len(imgs)}) ' 45 | f'!= num of image meta ({len(img_metas)})') 46 | 47 | if num_augs == 1: 48 | return self.simple_test(imgs[0], img_metas[0], proposals[0], 49 | **kwargs) 50 | else: 51 | # TODO: support test-time augmentation 52 | assert NotImplementedError 53 | -------------------------------------------------------------------------------- /mmdet/models/detectors/faster_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class FasterRCNN(TwoStageDetector): 7 | """Implementation of `Faster R-CNN `_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | rpn_head, 12 | roi_head, 13 | train_cfg, 14 | test_cfg, 15 | neck=None, 16 | pretrained=None): 17 | super(FasterRCNN, self).__init__( 18 | backbone=backbone, 19 | neck=neck, 20 | rpn_head=rpn_head, 21 | roi_head=roi_head, 22 | train_cfg=train_cfg, 23 | test_cfg=test_cfg, 24 | pretrained=pretrained) 25 | -------------------------------------------------------------------------------- /mmdet/models/detectors/fcos.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class FCOS(SingleStageDetector): 7 | """Implementation of `FCOS `_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | bbox_head, 13 | train_cfg=None, 14 | test_cfg=None, 15 | pretrained=None): 16 | super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg, 17 | test_cfg, pretrained) 18 | -------------------------------------------------------------------------------- /mmdet/models/detectors/fovea.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class FOVEA(SingleStageDetector): 7 | """Implementation of `FoveaBox `_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | bbox_head, 13 | train_cfg=None, 14 | test_cfg=None, 15 | pretrained=None): 16 | super(FOVEA, self).__init__(backbone, neck, bbox_head, train_cfg, 17 | test_cfg, pretrained) 18 | -------------------------------------------------------------------------------- /mmdet/models/detectors/fsaf.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class FSAF(SingleStageDetector): 7 | """Implementation of `FSAF `_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | bbox_head, 13 | train_cfg=None, 14 | test_cfg=None, 15 | pretrained=None): 16 | super(FSAF, self).__init__(backbone, neck, bbox_head, train_cfg, 17 | test_cfg, pretrained) 18 | -------------------------------------------------------------------------------- /mmdet/models/detectors/gfl.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class 
GFL(SingleStageDetector): 7 | 8 | def __init__(self, 9 | backbone, 10 | neck, 11 | bbox_head, 12 | train_cfg=None, 13 | test_cfg=None, 14 | pretrained=None): 15 | super(GFL, self).__init__(backbone, neck, bbox_head, train_cfg, 16 | test_cfg, pretrained) 17 | -------------------------------------------------------------------------------- /mmdet/models/detectors/grid_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class GridRCNN(TwoStageDetector): 7 | """Grid R-CNN. 8 | 9 | This detector is the implementation of: 10 | - Grid R-CNN (https://arxiv.org/abs/1811.12030) 11 | - Grid R-CNN Plus: Faster and Better (https://arxiv.org/abs/1906.05688) 12 | """ 13 | 14 | def __init__(self, 15 | backbone, 16 | rpn_head, 17 | roi_head, 18 | train_cfg, 19 | test_cfg, 20 | neck=None, 21 | pretrained=None): 22 | super(GridRCNN, self).__init__( 23 | backbone=backbone, 24 | neck=neck, 25 | rpn_head=rpn_head, 26 | roi_head=roi_head, 27 | train_cfg=train_cfg, 28 | test_cfg=test_cfg, 29 | pretrained=pretrained) 30 | -------------------------------------------------------------------------------- /mmdet/models/detectors/htc.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .cascade_rcnn import CascadeRCNN 3 | 4 | 5 | @DETECTORS.register_module() 6 | class HybridTaskCascade(CascadeRCNN): 7 | """Implementation of `HTC `_""" 8 | 9 | def __init__(self, **kwargs): 10 | super(HybridTaskCascade, self).__init__(**kwargs) 11 | 12 | @property 13 | def with_semantic(self): 14 | """bool: whether the detector has a semantic head""" 15 | return self.roi_head.with_semantic 16 | -------------------------------------------------------------------------------- /mmdet/models/detectors/mask_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class MaskRCNN(TwoStageDetector): 7 | """Implementation of `Mask R-CNN `_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | rpn_head, 12 | roi_head, 13 | train_cfg, 14 | test_cfg, 15 | neck=None, 16 | pretrained=None): 17 | super(MaskRCNN, self).__init__( 18 | backbone=backbone, 19 | neck=neck, 20 | rpn_head=rpn_head, 21 | roi_head=roi_head, 22 | train_cfg=train_cfg, 23 | test_cfg=test_cfg, 24 | pretrained=pretrained) 25 | -------------------------------------------------------------------------------- /mmdet/models/detectors/mask_scoring_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class MaskScoringRCNN(TwoStageDetector): 7 | """Mask Scoring RCNN. 
8 | 9 | https://arxiv.org/abs/1903.00241 10 | """ 11 | 12 | def __init__(self, 13 | backbone, 14 | rpn_head, 15 | roi_head, 16 | train_cfg, 17 | test_cfg, 18 | neck=None, 19 | pretrained=None): 20 | super(MaskScoringRCNN, self).__init__( 21 | backbone=backbone, 22 | neck=neck, 23 | rpn_head=rpn_head, 24 | roi_head=roi_head, 25 | train_cfg=train_cfg, 26 | test_cfg=test_cfg, 27 | pretrained=pretrained) 28 | -------------------------------------------------------------------------------- /mmdet/models/detectors/nasfcos.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class NASFCOS(SingleStageDetector): 7 | """NAS-FCOS: Fast Neural Architecture Search for Object Detection. 8 | 9 | https://arxiv.org/abs/1906.04423 10 | """ 11 | 12 | def __init__(self, 13 | backbone, 14 | neck, 15 | bbox_head, 16 | train_cfg=None, 17 | test_cfg=None, 18 | pretrained=None): 19 | super(NASFCOS, self).__init__(backbone, neck, bbox_head, train_cfg, 20 | test_cfg, pretrained) 21 | -------------------------------------------------------------------------------- /mmdet/models/detectors/paa.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class PAA(SingleStageDetector): 7 | """Implementation of `PAA <https://arxiv.org/abs/2007.08103>`_.""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | bbox_head, 13 | train_cfg=None, 14 | test_cfg=None, 15 | pretrained=None): 16 | super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg, 17 | test_cfg, pretrained) 18 | -------------------------------------------------------------------------------- /mmdet/models/detectors/point_rend.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class PointRend(TwoStageDetector): 7 | """PointRend: Image Segmentation as Rendering 8 | 9 | This detector is the implementation of 10 | `PointRend <https://arxiv.org/abs/1912.08193>`_. 11 | 12 | """ 13 | 14 | def __init__(self, 15 | backbone, 16 | rpn_head, 17 | roi_head, 18 | train_cfg, 19 | test_cfg, 20 | neck=None, 21 | pretrained=None): 22 | super(PointRend, self).__init__( 23 | backbone=backbone, 24 | neck=neck, 25 | rpn_head=rpn_head, 26 | roi_head=roi_head, 27 | train_cfg=train_cfg, 28 | test_cfg=test_cfg, 29 | pretrained=pretrained) 30 | -------------------------------------------------------------------------------- /mmdet/models/detectors/reppoints_detector.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class RepPointsDetector(SingleStageDetector): 7 | """RepPoints: Point Set Representation for Object Detection. 
8 | 9 | This detector is the implementation of: 10 | - RepPoints detector (https://arxiv.org/pdf/1904.11490) 11 | """ 12 | 13 | def __init__(self, 14 | backbone, 15 | neck, 16 | bbox_head, 17 | train_cfg=None, 18 | test_cfg=None, 19 | pretrained=None): 20 | super(RepPointsDetector, 21 | self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, 22 | pretrained) 23 | -------------------------------------------------------------------------------- /mmdet/models/detectors/retinanet.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class RetinaNet(SingleStageDetector): 7 | """Implementation of `RetinaNet <https://arxiv.org/abs/1708.02002>`_""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | bbox_head, 13 | train_cfg=None, 14 | test_cfg=None, 15 | pretrained=None): 16 | super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg, 17 | test_cfg, pretrained) 18 | -------------------------------------------------------------------------------- /mmdet/models/detectors/vfnet.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .single_stage import SingleStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class VFNet(SingleStageDetector): 7 | """Implementation of `VarifocalNet 8 | (VFNet) <https://arxiv.org/abs/2008.13367>`_.""" 9 | 10 | def __init__(self, 11 | backbone, 12 | neck, 13 | bbox_head, 14 | train_cfg=None, 15 | test_cfg=None, 16 | pretrained=None): 17 | super(VFNet, self).__init__(backbone, neck, bbox_head, train_cfg, 18 | test_cfg, pretrained) 19 | -------------------------------------------------------------------------------- /mmdet/models/detectors/weak_rcnn.py: -------------------------------------------------------------------------------- 1 | from ..builder import DETECTORS 2 | from .two_stage import TwoStageDetector 3 | 4 | 5 | @DETECTORS.register_module() 6 | class WeakRCNN(TwoStageDetector): 7 | """Weakly-supervised two-stage detector that classifies pre-computed proposals from image-level labels only, following the `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_ architecture.""" 8 | 9 | def __init__(self, 10 | backbone, 11 | neck, 12 | roi_head, 13 | train_cfg, 14 | test_cfg, 15 | pretrained=None): 16 | super(WeakRCNN, self).__init__( 17 | backbone=backbone, 18 | neck=neck, 19 | roi_head=roi_head, 20 | train_cfg=train_cfg, 21 | test_cfg=test_cfg, 22 | pretrained=pretrained) 23 | 24 | def forward_train(self, 25 | img, 26 | img_metas, 27 | gt_labels, 28 | proposals=None, 29 | **kwargs): 30 | 31 | x = self.extract_feat(img) 32 | 33 | losses = dict() 34 | 35 | proposal_list = proposals 36 | 37 | roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, 38 | gt_labels, **kwargs) 39 | losses.update(roi_losses) 40 | return losses 41 | 42 | 43 | def forward_test(self, imgs, img_metas, proposals, **kwargs): 44 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: 45 | if not isinstance(var, list): 46 | raise TypeError(f'{name} must be a list, but got {type(var)}') 47 | 48 | num_augs = len(imgs) 49 | if num_augs != len(img_metas): 50 | raise ValueError(f'num of augmentations ({len(imgs)}) ' 51 | f'!= num of image meta ({len(img_metas)})') 52 | 53 | if num_augs == 1: 54 | return self.simple_test(imgs[0], img_metas[0], proposals[0], 55 | **kwargs) 56 | else: 57 | assert imgs[0].size(0) == 1, 'aug test does not support ' \ 58 | 'inference with batch size ' \ 59 | f'{imgs[0].size(0)}' 60 | return self.aug_test(imgs, img_metas, proposals, **kwargs) 61 | 62 | 63 | def aug_test(self, imgs, img_metas, proposal_list, 
rescale=False): 64 | """Test with augmentations. 65 | 66 | If rescale is False, then returned bboxes and masks will fit the scale 67 | of imgs[0]. 68 | """ 69 | x = self.extract_feats(imgs) 70 | return self.roi_head.aug_test( 71 | x, proposal_list, img_metas, rescale=rescale) 72 | -------------------------------------------------------------------------------- /mmdet/models/detectors/yolo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Western Digital Corporation or its affiliates. 2 | 3 | from ..builder import DETECTORS 4 | from .single_stage import SingleStageDetector 5 | 6 | 7 | @DETECTORS.register_module() 8 | class YOLOV3(SingleStageDetector): 9 | 10 | def __init__(self, 11 | backbone, 12 | neck, 13 | bbox_head, 14 | train_cfg=None, 15 | test_cfg=None, 16 | pretrained=None): 17 | super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg, 18 | test_cfg, pretrained) 19 | -------------------------------------------------------------------------------- /mmdet/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy, accuracy 2 | from .ae_loss import AssociativeEmbeddingLoss 3 | from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss 4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 5 | cross_entropy, mask_cross_entropy) 6 | from .focal_loss import FocalLoss, sigmoid_focal_loss 7 | from .gaussian_focal_loss import GaussianFocalLoss 8 | from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss 9 | from .ghm_loss import GHMC, GHMR 10 | from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss, 11 | bounded_iou_loss, iou_loss) 12 | from .mse_loss import MSELoss, mse_loss 13 | from .pisa_loss import carl_loss, isr_p 14 | from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss 15 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 16 | from .varifocal_loss import VarifocalLoss 17 | 18 | __all__ = [ 19 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 20 | 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss', 21 | 'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss', 22 | 'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss', 23 | 'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 'GHMC', 24 | 'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss', 25 | 'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss', 26 | 'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss', 27 | 'VarifocalLoss' 28 | ] 29 | -------------------------------------------------------------------------------- /mmdet/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def accuracy(pred, target, topk=1, thresh=None): 5 | """Calculate accuracy according to the prediction and target. 6 | 7 | Args: 8 | pred (torch.Tensor): The model prediction, shape (N, num_class) 9 | target (torch.Tensor): The target of each prediction, shape (N, ) 10 | topk (int | tuple[int], optional): If the predictions in ``topk`` 11 | matches the target, the predictions will be regarded as 12 | correct ones. Defaults to 1. 13 | thresh (float, optional): If not None, predictions with scores under 14 | this threshold are considered incorrect. Default to None. 
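Example (a minimal sketch with toy values; both top-1 predictions below are correct): >>> import torch >>> pred = torch.tensor([[0.2, 0.8], [0.9, 0.1]]) >>> target = torch.tensor([1, 0]) >>> accuracy(pred, target) tensor([100.])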
15 | 16 | Returns: 17 | float | tuple[float]: If the input ``topk`` is a single integer, 18 | the function will return a single float as accuracy. If 19 | ``topk`` is a tuple containing multiple integers, the 20 | function will return a tuple containing accuracies of 21 | each ``topk`` number. 22 | """ 23 | assert isinstance(topk, (int, tuple)) 24 | if isinstance(topk, int): 25 | topk = (topk, ) 26 | return_single = True 27 | else: 28 | return_single = False 29 | 30 | maxk = max(topk) 31 | if pred.size(0) == 0: 32 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 33 | return accu[0] if return_single else accu 34 | assert pred.ndim == 2 and target.ndim == 1 35 | assert pred.size(0) == target.size(0) 36 | assert maxk <= pred.size(1), \ 37 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 38 | pred_value, pred_label = pred.topk(maxk, dim=1) 39 | pred_label = pred_label.t() # transpose to shape (maxk, N) 40 | correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) 41 | if thresh is not None: 42 | # Only prediction values larger than thresh are counted as correct 43 | correct = correct & (pred_value > thresh).t() 44 | res = [] 45 | for k in topk: 46 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 47 | res.append(correct_k.mul_(100.0 / pred.size(0))) 48 | return res[0] if return_single else res 49 | 50 | 51 | class Accuracy(nn.Module): 52 | 53 | def __init__(self, topk=(1, ), thresh=None): 54 | """Module to calculate the accuracy. 55 | 56 | Args: 57 | topk (tuple, optional): The criterion used to calculate the 58 | accuracy. Defaults to (1,). 59 | thresh (float, optional): If not None, predictions with scores 60 | under this threshold are considered incorrect. Default to None. 61 | """ 62 | super().__init__() 63 | self.topk = topk 64 | self.thresh = thresh 65 | 66 | def forward(self, pred, target): 67 | """Forward function to calculate accuracy. 68 | 69 | Args: 70 | pred (torch.Tensor): Prediction of models. 71 | target (torch.Tensor): Target for each prediction. 72 | 73 | Returns: 74 | tuple[float]: The accuracies under different topk criterions. 75 | """ 76 | return accuracy(pred, target, self.topk, self.thresh) 77 | -------------------------------------------------------------------------------- /mmdet/models/losses/gaussian_focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..builder import LOSSES 4 | from .utils import weighted_loss 5 | 6 | 7 | @weighted_loss 8 | def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0): 9 | """`Focal Loss `_ for targets in gaussian 10 | distribution. 11 | 12 | Args: 13 | pred (torch.Tensor): The prediction. 14 | gaussian_target (torch.Tensor): The learning target of the prediction 15 | in gaussian distribution. 16 | alpha (float, optional): A balanced form for Focal Loss. 17 | Defaults to 2.0. 18 | gamma (float, optional): The gamma for calculating the modulating 19 | factor. Defaults to 4.0. 20 | """ 21 | eps = 1e-12 22 | pos_weights = gaussian_target.eq(1) 23 | neg_weights = (1 - gaussian_target).pow(gamma) 24 | pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights 25 | neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights 26 | return pos_loss + neg_loss 27 | 28 | 29 | @LOSSES.register_module() 30 | class GaussianFocalLoss(nn.Module): 31 | """GaussianFocalLoss is a variant of focal loss. 
32 | 33 | More details can be found in the `paper 34 | <https://arxiv.org/abs/1808.01244>`_ 35 | Code is modified from `kp_utils.py 36 | <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501 37 | Note that the target in GaussianFocalLoss is a gaussian heatmap, 38 | not a 0/1 binary target. 39 | 40 | Args: 41 | alpha (float): Power of prediction. 42 | gamma (float): Power of target for negative samples. 43 | reduction (str): Options are "none", "mean" and "sum". 44 | loss_weight (float): Loss weight of current loss. 45 | """ 46 | 47 | def __init__(self, 48 | alpha=2.0, 49 | gamma=4.0, 50 | reduction='mean', 51 | loss_weight=1.0): 52 | super(GaussianFocalLoss, self).__init__() 53 | self.alpha = alpha 54 | self.gamma = gamma 55 | self.reduction = reduction 56 | self.loss_weight = loss_weight 57 | 58 | def forward(self, 59 | pred, 60 | target, 61 | weight=None, 62 | avg_factor=None, 63 | reduction_override=None): 64 | """Forward function. 65 | 66 | Args: 67 | pred (torch.Tensor): The prediction. 68 | target (torch.Tensor): The learning target of the prediction 69 | in gaussian distribution. 70 | weight (torch.Tensor, optional): The weight of loss for each 71 | prediction. Defaults to None. 72 | avg_factor (int, optional): Average factor that is used to average 73 | the loss. Defaults to None. 74 | reduction_override (str, optional): The reduction method used to 75 | override the original reduction method of the loss. 76 | Defaults to None. 77 | """ 78 | assert reduction_override in (None, 'none', 'mean', 'sum') 79 | reduction = ( 80 | reduction_override if reduction_override else self.reduction) 81 | loss_reg = self.loss_weight * gaussian_focal_loss( 82 | pred, 83 | target, 84 | weight, 85 | alpha=self.alpha, 86 | gamma=self.gamma, 87 | reduction=reduction, 88 | avg_factor=avg_factor) 89 | return loss_reg 90 | -------------------------------------------------------------------------------- /mmdet/models/losses/mse_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..builder import LOSSES 5 | from .utils import weighted_loss 6 | 7 | 8 | @weighted_loss 9 | def mse_loss(pred, target): 10 | """Wrapper of mse loss.""" 11 | return F.mse_loss(pred, target, reduction='none') 12 | 13 | 14 | @LOSSES.register_module() 15 | class MSELoss(nn.Module): 16 | """MSELoss. 17 | 18 | Args: 19 | reduction (str, optional): The method that reduces the loss to a 20 | scalar. Options are "none", "mean" and "sum". 21 | loss_weight (float, optional): The weight of the loss. Defaults to 1.0. 22 | """ 23 | 24 | def __init__(self, reduction='mean', loss_weight=1.0): 25 | super().__init__() 26 | self.reduction = reduction 27 | self.loss_weight = loss_weight 28 | 29 | def forward(self, pred, target, weight=None, avg_factor=None): 30 | """Forward function of loss. 31 | 32 | Args: 33 | pred (torch.Tensor): The prediction. 34 | target (torch.Tensor): The learning target of the prediction. 35 | weight (torch.Tensor, optional): Weight of the loss for each 36 | prediction. Defaults to None. 37 | avg_factor (int, optional): Average factor that is used to average 38 | the loss. Defaults to None. 
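Example (an illustrative sketch with toy values; the squared errors [0., 4.] average to 2.): >>> import torch >>> loss = MSELoss() >>> loss(torch.tensor([1., 3.]), torch.tensor([1., 1.])) tensor(2.)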
39 | 40 | Returns: 41 | torch.Tensor: The calculated loss 42 | """ 43 | loss = self.loss_weight * mse_loss( 44 | pred, 45 | target, 46 | weight, 47 | reduction=self.reduction, 48 | avg_factor=avg_factor) 49 | return loss 50 | -------------------------------------------------------------------------------- /mmdet/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn.functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are "none", "mean" and "sum". 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | elif reduction_enum == 2: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. 32 | reduction (str): Same as built-in losses of PyTorch. 33 | avg_factor (float): Average factor when computing the mean of losses. 34 | 35 | Returns: 36 | Tensor: Processed loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | loss = loss * weight 41 | 42 | # if avg_factor is not specified, just reduce the loss 43 | if avg_factor is None: 44 | loss = reduce_loss(loss, reduction) 45 | else: 46 | # if reduction is mean, then average the loss by avg_factor 47 | if reduction == 'mean': 48 | loss = loss.sum() / avg_factor 49 | # if reduction is 'none', then do nothing, otherwise raise an error 50 | elif reduction != 'none': 51 | raise ValueError('avg_factor cannot be used with reduction="sum"') 52 | return loss 53 | 54 | 55 | def weighted_loss(loss_func): 56 | """Create a weighted version of a given loss function. 57 | 58 | To use this decorator, the loss function must have the signature like 59 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 60 | element-wise loss without any reduction. This decorator will add weight 61 | and reduction arguments to the function. The decorated function will have 62 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 63 | avg_factor=None, **kwargs)`. 64 | 65 | :Example: 66 | 67 | >>> import torch 68 | >>> @weighted_loss 69 | >>> def l1_loss(pred, target): 70 | >>> return (pred - target).abs() 71 | 72 | >>> pred = torch.Tensor([0, 2, 3]) 73 | >>> target = torch.Tensor([1, 1, 1]) 74 | >>> weight = torch.Tensor([1, 0, 1]) 75 | 76 | >>> l1_loss(pred, target) 77 | tensor(1.3333) 78 | >>> l1_loss(pred, target, weight) 79 | tensor(1.) 
80 | >>> l1_loss(pred, target, reduction='none') 81 | tensor([1., 1., 2.]) 82 | >>> l1_loss(pred, target, weight, avg_factor=2) 83 | tensor(1.5000) 84 | """ 85 | 86 | @functools.wraps(loss_func) 87 | def wrapper(pred, 88 | target, 89 | weight=None, 90 | reduction='mean', 91 | avg_factor=None, 92 | **kwargs): 93 | # get element-wise loss 94 | loss = loss_func(pred, target, **kwargs) 95 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 96 | return loss 97 | 98 | return wrapper 99 | -------------------------------------------------------------------------------- /mmdet/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .bfp import BFP 2 | from .channel_mapper import ChannelMapper 3 | from .fpn import FPN 4 | from .fpn_carafe import FPN_CARAFE 5 | from .hrfpn import HRFPN 6 | from .nas_fpn import NASFPN 7 | from .nasfcos_fpn import NASFCOS_FPN 8 | from .pafpn import PAFPN 9 | from .rfp import RFP 10 | from .yolo_neck import YOLOV3Neck 11 | 12 | __all__ = [ 13 | 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', 14 | 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck' 15 | ] 16 | -------------------------------------------------------------------------------- /mmdet/models/necks/channel_mapper.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.cnn import ConvModule, xavier_init 3 | 4 | from ..builder import NECKS 5 | 6 | 7 | @NECKS.register_module() 8 | class ChannelMapper(nn.Module): 9 | r"""Channel Mapper to reduce/increase channels of backbone features. 10 | 11 | This is used to reduce/increase channels of backbone features. 12 | 13 | Args: 14 | in_channels (List[int]): Number of input channels per scale. 15 | out_channels (int): Number of output channels (used at each scale). 16 | kernel_size (int, optional): kernel_size for reducing channels (used 17 | at each scale). Default: 3. 18 | conv_cfg (dict, optional): Config dict for convolution layer. 19 | Default: None. 20 | norm_cfg (dict, optional): Config dict for normalization layer. 21 | Default: None. 22 | act_cfg (dict, optional): Config dict for activation layer in 23 | ConvModule. Default: dict(type='ReLU'). 24 | 25 | Example: 26 | >>> import torch 27 | >>> in_channels = [2, 3, 5, 7] 28 | >>> scales = [340, 170, 84, 43] 29 | >>> inputs = [torch.rand(1, c, s, s) 30 | ... for c, s in zip(in_channels, scales)] 31 | >>> self = ChannelMapper(in_channels, 11, 3).eval() 32 | >>> outputs = self.forward(inputs) 33 | >>> for i in range(len(outputs)): 34 | ... 
print(f'outputs[{i}].shape = {outputs[i].shape}') 35 | outputs[0].shape = torch.Size([1, 11, 340, 340]) 36 | outputs[1].shape = torch.Size([1, 11, 170, 170]) 37 | outputs[2].shape = torch.Size([1, 11, 84, 84]) 38 | outputs[3].shape = torch.Size([1, 11, 43, 43]) 39 | """ 40 | 41 | def __init__(self, 42 | in_channels, 43 | out_channels, 44 | kernel_size=3, 45 | conv_cfg=None, 46 | norm_cfg=None, 47 | act_cfg=dict(type='ReLU')): 48 | super(ChannelMapper, self).__init__() 49 | assert isinstance(in_channels, list) 50 | 51 | self.convs = nn.ModuleList() 52 | for in_channel in in_channels: 53 | self.convs.append( 54 | ConvModule( 55 | in_channel, 56 | out_channels, 57 | kernel_size, 58 | padding=(kernel_size - 1) // 2, 59 | conv_cfg=conv_cfg, 60 | norm_cfg=norm_cfg, 61 | act_cfg=act_cfg)) 62 | 63 | # default init_weights for conv(msra) and norm in ConvModule 64 | def init_weights(self): 65 | """Initialize the weights of ChannelMapper module.""" 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | xavier_init(m, distribution='uniform') 69 | 70 | def forward(self, inputs): 71 | """Forward function.""" 72 | assert len(inputs) == len(self.convs) 73 | outs = [self.convs[i](inputs[i]) for i in range(len(inputs))] 74 | return tuple(outs) 75 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_roi_head import BaseRoIHead 2 | from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead, 3 | Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) 4 | from .cascade_roi_head import CascadeRoIHead 5 | from .double_roi_head import DoubleHeadRoIHead 6 | from .dynamic_roi_head import DynamicRoIHead 7 | from .grid_roi_head import GridRoIHead 8 | from .htc_roi_head import HybridTaskCascadeRoIHead 9 | from .mask_heads import (CoarseMaskHead, FCNMaskHead, FusedSemanticHead, 10 | GridHead, HTCMaskHead, MaskIoUHead, MaskPointHead) 11 | from .mask_scoring_roi_head import MaskScoringRoIHead 12 | from .pisa_roi_head import PISARoIHead 13 | from .point_rend_roi_head import PointRendRoIHead 14 | from .roi_extractors import SingleRoIExtractor 15 | from .shared_heads import ResLayer 16 | from .standard_roi_head import StandardRoIHead 17 | 18 | from .wsddn_roi_head import WSDDNRoIHead 19 | from .oicr_roi_head import OICRRoIHead 20 | from .wsod2_roi_head import WSOD2RoIHead 21 | 22 | __all__ = [ 23 | 'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead', 24 | 'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead', 25 | 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'StandardRoIHead', 26 | 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'FCNMaskHead', 27 | 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 'MaskIoUHead', 28 | 'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead', 29 | 'CoarseMaskHead', 'DynamicRoIHead', 30 | 'WSDDNRoIHead', 'OICRRoIHead', 'WSOD2RoIHead' 31 | ] 32 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/base_roi_head.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch.nn as nn 4 | 5 | from ..builder import build_shared_head 6 | 7 | 8 | class BaseRoIHead(nn.Module, metaclass=ABCMeta): 9 | """Base class for RoIHeads.""" 10 | 11 | def __init__(self, 12 | bbox_roi_extractor=None, 13 | bbox_head=None, 14 | mask_roi_extractor=None, 15 | 
mask_head=None, 16 | shared_head=None, 17 | train_cfg=None, 18 | test_cfg=None): 19 | super(BaseRoIHead, self).__init__() 20 | self.train_cfg = train_cfg 21 | self.test_cfg = test_cfg 22 | if shared_head is not None: 23 | self.shared_head = build_shared_head(shared_head) 24 | 25 | if bbox_head is not None: 26 | self.init_bbox_head(bbox_roi_extractor, bbox_head) 27 | 28 | if mask_head is not None: 29 | self.init_mask_head(mask_roi_extractor, mask_head) 30 | 31 | self.init_assigner_sampler() 32 | 33 | @property 34 | def with_bbox(self): 35 | """bool: whether the RoI head contains a `bbox_head`""" 36 | return hasattr(self, 'bbox_head') and self.bbox_head is not None 37 | 38 | @property 39 | def with_mask(self): 40 | """bool: whether the RoI head contains a `mask_head`""" 41 | return hasattr(self, 'mask_head') and self.mask_head is not None 42 | 43 | @property 44 | def with_shared_head(self): 45 | """bool: whether the RoI head contains a `shared_head`""" 46 | return hasattr(self, 'shared_head') and self.shared_head is not None 47 | 48 | @abstractmethod 49 | def init_weights(self, pretrained): 50 | """Initialize the weights in head. 51 | 52 | Args: 53 | pretrained (str, optional): Path to pre-trained weights. 54 | Defaults to None. 55 | """ 56 | pass 57 | 58 | @abstractmethod 59 | def init_bbox_head(self): 60 | """Initialize ``bbox_head``""" 61 | pass 62 | 63 | @abstractmethod 64 | def init_mask_head(self): 65 | """Initialize ``mask_head``""" 66 | pass 67 | 68 | @abstractmethod 69 | def init_assigner_sampler(self): 70 | """Initialize assigner and sampler.""" 71 | pass 72 | 73 | @abstractmethod 74 | def forward_train(self, 75 | x, 76 | img_meta, 77 | proposal_list, 78 | gt_bboxes, 79 | gt_labels, 80 | gt_bboxes_ignore=None, 81 | gt_masks=None, 82 | **kwargs): 83 | """Forward function during training.""" 84 | pass 85 | 86 | async def async_simple_test(self, x, img_meta, **kwargs): 87 | """Asynchronized test function.""" 88 | raise NotImplementedError 89 | 90 | def simple_test(self, 91 | x, 92 | proposal_list, 93 | img_meta, 94 | proposals=None, 95 | rescale=False, 96 | **kwargs): 97 | """Test without augmentation.""" 98 | pass 99 | 100 | def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs): 101 | """Test with augmentations. 102 | 103 | If rescale is False, then returned bboxes and masks will fit the scale 104 | of imgs[0]. 105 | """ 106 | pass 107 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/bbox_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_head import BBoxHead 2 | from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead, 3 | Shared4Conv1FCBBoxHead) 4 | from .double_bbox_head import DoubleConvFCBBoxHead 5 | from .sabl_head import SABLHead 6 | 7 | from .wsddn_head import WSDDNHead 8 | from .oicr_head import OICRHead 9 | 10 | __all__ = [ 11 | 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 12 | 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead', 13 | 'WSDDNHead', 'OICRHead' 14 | ] 15 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/double_roi_head.py: -------------------------------------------------------------------------------- 1 | from ..builder import HEADS 2 | from .standard_roi_head import StandardRoIHead 3 | 4 | 5 | @HEADS.register_module() 6 | class DoubleHeadRoIHead(StandardRoIHead): 7 | """RoI head for Double Head RCNN. 
8 | 9 | https://arxiv.org/abs/1904.06493 10 | """ 11 | 12 | def __init__(self, reg_roi_scale_factor, **kwargs): 13 | super(DoubleHeadRoIHead, self).__init__(**kwargs) 14 | self.reg_roi_scale_factor = reg_roi_scale_factor 15 | 16 | def _bbox_forward(self, x, rois): 17 | """Box head forward function used in both training and testing time.""" 18 | bbox_cls_feats = self.bbox_roi_extractor( 19 | x[:self.bbox_roi_extractor.num_inputs], rois) 20 | bbox_reg_feats = self.bbox_roi_extractor( 21 | x[:self.bbox_roi_extractor.num_inputs], 22 | rois, 23 | roi_scale_factor=self.reg_roi_scale_factor) 24 | if self.with_shared_head: 25 | bbox_cls_feats = self.shared_head(bbox_cls_feats) 26 | bbox_reg_feats = self.shared_head(bbox_reg_feats) 27 | cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats) 28 | 29 | bbox_results = dict( 30 | cls_score=cls_score, 31 | bbox_pred=bbox_pred, 32 | bbox_feats=bbox_cls_feats) 33 | return bbox_results 34 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/mask_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .coarse_mask_head import CoarseMaskHead 2 | from .fcn_mask_head import FCNMaskHead 3 | from .fused_semantic_head import FusedSemanticHead 4 | from .grid_head import GridHead 5 | from .htc_mask_head import HTCMaskHead 6 | from .mask_point_head import MaskPointHead 7 | from .maskiou_head import MaskIoUHead 8 | 9 | __all__ = [ 10 | 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 11 | 'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead' 12 | ] 13 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/mask_heads/coarse_mask_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.cnn import ConvModule, Linear, constant_init, xavier_init 3 | from mmcv.runner import auto_fp16 4 | 5 | from mmdet.models.builder import HEADS 6 | from .fcn_mask_head import FCNMaskHead 7 | 8 | 9 | @HEADS.register_module() 10 | class CoarseMaskHead(FCNMaskHead): 11 | """Coarse mask head used in PointRend. 12 | 13 | Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` will downsample 14 | the input feature map instead of upsample it. 15 | 16 | Args: 17 | num_convs (int): Number of conv layers in the head. Default: 0. 18 | num_fcs (int): Number of fc layers in the head. Default: 2. 19 | fc_out_channels (int): Number of output channels of fc layer. 20 | Default: 1024. 21 | downsample_factor (int): The factor that feature map is downsampled by. 22 | Default: 2. 
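For example, with the 14x14 RoI features that ``FCNMaskHead`` uses by default (an assumption of this note, not stated above), ``downsample_factor=2`` produces coarse mask logits on a 7x7 grid, matching ``output_size`` computed in ``__init__`` below.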
23 | """ 24 | 25 | def __init__(self, 26 | num_convs=0, 27 | num_fcs=2, 28 | fc_out_channels=1024, 29 | downsample_factor=2, 30 | *arg, 31 | **kwarg): 32 | super(CoarseMaskHead, self).__init__( 33 | *arg, num_convs=num_convs, upsample_cfg=dict(type=None), **kwarg) 34 | self.num_fcs = num_fcs 35 | assert self.num_fcs > 0 36 | self.fc_out_channels = fc_out_channels 37 | self.downsample_factor = downsample_factor 38 | assert self.downsample_factor >= 1 39 | # remove conv_logit 40 | delattr(self, 'conv_logits') 41 | 42 | if downsample_factor > 1: 43 | downsample_in_channels = ( 44 | self.conv_out_channels 45 | if self.num_convs > 0 else self.in_channels) 46 | self.downsample_conv = ConvModule( 47 | downsample_in_channels, 48 | self.conv_out_channels, 49 | kernel_size=downsample_factor, 50 | stride=downsample_factor, 51 | padding=0, 52 | conv_cfg=self.conv_cfg, 53 | norm_cfg=self.norm_cfg) 54 | else: 55 | self.downsample_conv = None 56 | 57 | self.output_size = (self.roi_feat_size[0] // downsample_factor, 58 | self.roi_feat_size[1] // downsample_factor) 59 | self.output_area = self.output_size[0] * self.output_size[1] 60 | 61 | last_layer_dim = self.conv_out_channels * self.output_area 62 | 63 | self.fcs = nn.ModuleList() 64 | for i in range(num_fcs): 65 | fc_in_channels = ( 66 | last_layer_dim if i == 0 else self.fc_out_channels) 67 | self.fcs.append(Linear(fc_in_channels, self.fc_out_channels)) 68 | last_layer_dim = self.fc_out_channels 69 | output_channels = self.num_classes * self.output_area 70 | self.fc_logits = Linear(last_layer_dim, output_channels) 71 | 72 | def init_weights(self): 73 | for m in self.fcs.modules(): 74 | if isinstance(m, nn.Linear): 75 | xavier_init(m) 76 | constant_init(self.fc_logits, 0.001) 77 | 78 | @auto_fp16() 79 | def forward(self, x): 80 | for conv in self.convs: 81 | x = conv(x) 82 | 83 | if self.downsample_conv is not None: 84 | x = self.downsample_conv(x) 85 | 86 | x = x.flatten(1) 87 | for fc in self.fcs: 88 | x = self.relu(fc(x)) 89 | mask_pred = self.fc_logits(x).view( 90 | x.size(0), self.num_classes, *self.output_size) 91 | return mask_pred 92 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/mask_heads/htc_mask_head.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn import ConvModule 2 | 3 | from mmdet.models.builder import HEADS 4 | from .fcn_mask_head import FCNMaskHead 5 | 6 | 7 | @HEADS.register_module() 8 | class HTCMaskHead(FCNMaskHead): 9 | 10 | def __init__(self, with_conv_res=True, *args, **kwargs): 11 | super(HTCMaskHead, self).__init__(*args, **kwargs) 12 | self.with_conv_res = with_conv_res 13 | if self.with_conv_res: 14 | self.conv_res = ConvModule( 15 | self.conv_out_channels, 16 | self.conv_out_channels, 17 | 1, 18 | conv_cfg=self.conv_cfg, 19 | norm_cfg=self.norm_cfg) 20 | 21 | def init_weights(self): 22 | super(HTCMaskHead, self).init_weights() 23 | if self.with_conv_res: 24 | self.conv_res.init_weights() 25 | 26 | def forward(self, x, res_feat=None, return_logits=True, return_feat=True): 27 | if res_feat is not None: 28 | assert self.with_conv_res 29 | res_feat = self.conv_res(res_feat) 30 | x = x + res_feat 31 | for conv in self.convs: 32 | x = conv(x) 33 | res_feat = x 34 | outs = [] 35 | if return_logits: 36 | x = self.upsample(x) 37 | if self.upsample_method == 'deconv': 38 | x = self.relu(x) 39 | mask_pred = self.conv_logits(x) 40 | outs.append(mask_pred) 41 | if return_feat: 42 | outs.append(res_feat) 43 | return outs if 
len(outs) > 1 else outs[0] 44 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/roi_extractors/__init__.py: -------------------------------------------------------------------------------- 1 | from .generic_roi_extractor import GenericRoIExtractor 2 | from .single_level_roi_extractor import SingleRoIExtractor 3 | 4 | __all__ = [ 5 | 'SingleRoIExtractor', 6 | 'GenericRoIExtractor', 7 | ] 8 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv import ops 6 | 7 | 8 | class BaseRoIExtractor(nn.Module, metaclass=ABCMeta): 9 | """Base class for RoI extractor. 10 | 11 | Args: 12 | roi_layer (dict): Specify RoI layer type and arguments. 13 | out_channels (int): Output channels of RoI layers. 14 | featmap_strides (int): Strides of input feature maps. 15 | """ 16 | 17 | def __init__(self, roi_layer, out_channels, featmap_strides): 18 | super(BaseRoIExtractor, self).__init__() 19 | self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) 20 | self.out_channels = out_channels 21 | self.featmap_strides = featmap_strides 22 | self.fp16_enabled = False 23 | 24 | @property 25 | def num_inputs(self): 26 | """int: Number of input feature maps.""" 27 | return len(self.featmap_strides) 28 | 29 | def init_weights(self): 30 | pass 31 | 32 | def build_roi_layers(self, layer_cfg, featmap_strides): 33 | """Build RoI operator to extract feature from each level feature map. 34 | 35 | Args: 36 | layer_cfg (dict): Dictionary to construct and config RoI layer 37 | operation. Options are modules under ``mmcv/ops`` such as 38 | ``RoIAlign``. 39 | featmap_strides (int): The stride of input feature map w.r.t to the 40 | original image size, which would be used to scale RoI 41 | coordinate (original image coordinate system) to feature 42 | coordinate system. 43 | 44 | Returns: 45 | nn.ModuleList: The RoI extractor modules for each level feature 46 | map. 47 | """ 48 | 49 | cfg = layer_cfg.copy() 50 | layer_type = cfg.pop('type') 51 | assert hasattr(ops, layer_type) 52 | layer_cls = getattr(ops, layer_type) 53 | roi_layers = nn.ModuleList( 54 | [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides]) 55 | return roi_layers 56 | 57 | def roi_rescale(self, rois, scale_factor): 58 | """Scale RoI coordinates by scale factor. 59 | 60 | Args: 61 | rois (torch.Tensor): RoI (Region of Interest), shape (n, 5) 62 | scale_factor (float): Scale factor that RoI will be multiplied by. 63 | 64 | Returns: 65 | torch.Tensor: Scaled RoI. 
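For example, rescaling the RoI ``[0., 0., 0., 4., 4.]`` (batch index followed by x1, y1, x2, y2) by a factor of 2 keeps its center at (2, 2) and returns ``[0., -2., -2., 6., 6.]``.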
66 | """ 67 | 68 | cx = (rois[:, 1] + rois[:, 3]) * 0.5 69 | cy = (rois[:, 2] + rois[:, 4]) * 0.5 70 | w = rois[:, 3] - rois[:, 1] 71 | h = rois[:, 4] - rois[:, 2] 72 | new_w = w * scale_factor 73 | new_h = h * scale_factor 74 | x1 = cx - new_w * 0.5 75 | x2 = cx + new_w * 0.5 76 | y1 = cy - new_h * 0.5 77 | y2 = cy + new_h * 0.5 78 | new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1) 79 | return new_rois 80 | 81 | @abstractmethod 82 | def forward(self, feats, rois, roi_scale_factor=None): 83 | pass 84 | -------------------------------------------------------------------------------- /mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn.bricks import build_plugin_layer 2 | from mmcv.runner import force_fp32 3 | 4 | from mmdet.models.builder import ROI_EXTRACTORS 5 | from .base_roi_extractor import BaseRoIExtractor 6 | 7 | 8 | @ROI_EXTRACTORS.register_module() 9 | class GenericRoIExtractor(BaseRoIExtractor): 10 | """Extract RoI features from all level feature maps levels. 11 | 12 | This is the implementation of `A novel Region of Interest Extraction Layer 13 | for Instance Segmentation `_. 14 | 15 | Args: 16 | aggregation (str): The method to aggregate multiple feature maps. 17 | Options are 'sum', 'concat'. Default: 'sum'. 18 | pre_cfg (dict | None): Specify pre-processing modules. Default: None. 19 | post_cfg (dict | None): Specify post-processing modules. Default: None. 20 | kwargs (keyword arguments): Arguments that are the same 21 | as :class:`BaseRoIExtractor`. 22 | """ 23 | 24 | def __init__(self, 25 | aggregation='sum', 26 | pre_cfg=None, 27 | post_cfg=None, 28 | **kwargs): 29 | super(GenericRoIExtractor, self).__init__(**kwargs) 30 | 31 | assert aggregation in ['sum', 'concat'] 32 | 33 | self.aggregation = aggregation 34 | self.with_post = post_cfg is not None 35 | self.with_pre = pre_cfg is not None 36 | # build pre/post processing modules 37 | if self.with_post: 38 | self.post_module = build_plugin_layer(post_cfg, '_post_module')[1] 39 | if self.with_pre: 40 | self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1] 41 | 42 | @force_fp32(apply_to=('feats', ), out_fp16=True) 43 | def forward(self, feats, rois, roi_scale_factor=None): 44 | """Forward function.""" 45 | if len(feats) == 1: 46 | return self.roi_layers[0](feats[0], rois) 47 | 48 | out_size = self.roi_layers[0].output_size 49 | num_levels = len(feats) 50 | roi_feats = feats[0].new_zeros( 51 | rois.size(0), self.out_channels, *out_size) 52 | 53 | # some times rois is an empty tensor 54 | if roi_feats.shape[0] == 0: 55 | return roi_feats 56 | 57 | if roi_scale_factor is not None: 58 | rois = self.roi_rescale(rois, roi_scale_factor) 59 | 60 | # mark the starting channels for concat mode 61 | start_channels = 0 62 | for i in range(num_levels): 63 | roi_feats_t = self.roi_layers[i](feats[i], rois) 64 | end_channels = start_channels + roi_feats_t.size(1) 65 | if self.with_pre: 66 | # apply pre-processing to a RoI extracted from each layer 67 | roi_feats_t = self.pre_module(roi_feats_t) 68 | if self.aggregation == 'sum': 69 | # and sum them all 70 | roi_feats += roi_feats_t 71 | else: 72 | # and concat them along channel dimension 73 | roi_feats[:, start_channels:end_channels] = roi_feats_t 74 | # update channels starting position 75 | start_channels = end_channels 76 | # check if concat channels match at the end 77 | if self.aggregation == 'concat': 78 | assert start_channels == self.out_channels 79 | 
        if self.with_post:
            # apply post-processing before returning the result
            roi_feats = self.post_module(roi_feats)
        return roi_feats
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/shared_heads/__init__.py:
--------------------------------------------------------------------------------
from .res_layer import ResLayer

__all__ = ['ResLayer']
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/shared_heads/res_layer.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import auto_fp16, load_checkpoint

from mmdet.models.backbones import ResNet
from mmdet.models.builder import SHARED_HEADS
from mmdet.models.utils import ResLayer as _ResLayer
from mmdet.utils import get_root_logger


@SHARED_HEADS.register_module()
class ResLayer(nn.Module):
    """Shared head that runs a single ResNet stage (``layer{stage + 1}``)."""

    def __init__(self,
                 depth,
                 stage=3,
                 stride=2,
                 dilation=1,
                 style='pytorch',
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 with_cp=False,
                 dcn=None):
        super(ResLayer, self).__init__()
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.stage = stage
        self.fp16_enabled = False
        block, stage_blocks = ResNet.arch_settings[depth]
        stage_block = stage_blocks[stage]
        planes = 64 * 2**stage
        inplanes = 64 * 2**(stage - 1) * block.expansion

        res_layer = _ResLayer(
            block,
            inplanes,
            planes,
            stage_block,
            stride=stride,
            dilation=dilation,
            style=style,
            with_cp=with_cp,
            norm_cfg=self.norm_cfg,
            dcn=dcn)
        self.add_module(f'layer{stage + 1}', res_layer)

    def init_weights(self, pretrained=None):
        """Initialize the weights in the module.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    @auto_fp16()
    def forward(self, x):
        res_layer = getattr(self, f'layer{self.stage + 1}')
        out = res_layer(x)
        return out

    def train(self, mode=True):
        super(ResLayer, self).train(mode)
        if self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
--------------------------------------------------------------------------------
/mmdet/models/utils/__init__.py:
--------------------------------------------------------------------------------
from .gaussian_target import gaussian_radius, gen_gaussian_target
from .res_layer import ResLayer

__all__ = ['ResLayer', 'gaussian_radius', 'gen_gaussian_target']
--------------------------------------------------------------------------------
/mmdet/ops/__init__.py:
--------------------------------------------------------------------------------
# This file is added for back-compatibility. Thus, downstream codebase
# could still use and import mmdet.ops.
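# A minimal sketch of what the shim provides (assuming a matching mmcv
# version is installed):
#
#   from mmdet.ops import nms, RoIAlign   # via this back-compat module
#   from mmcv.ops import nms, RoIAlign    # the canonical location
#
# Both forms resolve to the same mmcv implementations re-exported below.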

# yapf: disable
from mmcv.ops import (ContextBlock, Conv2d, ConvTranspose2d, ConvWS2d,
                      CornerPool, DeformConv, DeformConvPack, DeformRoIPooling,
                      DeformRoIPoolingPack, GeneralizedAttention, Linear,
                      MaskedConv2d, MaxPool2d, ModulatedDeformConv,
                      ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack,
                      NonLocal2D, RoIAlign, RoIPool, SAConv2d,
                      SigmoidFocalLoss, SimpleRoIAlign, batched_nms,
                      build_plugin_layer, conv_ws_2d, deform_conv,
                      deform_roi_pooling, get_compiler_version,
                      get_compiling_cuda_version, modulated_deform_conv, nms,
                      nms_match, point_sample, rel_roi_point_to_rel_img_point,
                      roi_align, roi_pool, sigmoid_focal_loss, soft_nms)

# yapf: enable

__all__ = [
    'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
    'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
    'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
    'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
    'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss',
    'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D',
    'get_compiler_version', 'get_compiling_cuda_version', 'ConvWS2d',
    'conv_ws_2d', 'build_plugin_layer', 'batched_nms', 'Conv2d',
    'ConvTranspose2d', 'MaxPool2d', 'Linear', 'nms_match', 'CornerPool',
    'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
    'SAConv2d'
]
--------------------------------------------------------------------------------
/mmdet/utils/__init__.py:
--------------------------------------------------------------------------------
from .collect_env import collect_env
from .logger import get_root_logger

__all__ = ['get_root_logger', 'collect_env']
--------------------------------------------------------------------------------
/mmdet/utils/collect_env.py:
--------------------------------------------------------------------------------
from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash

import mmdet


def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMDetection'] = mmdet.__version__ + '+' + get_git_hash()[:7]
    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print(f'{name}: {val}')
--------------------------------------------------------------------------------
/mmdet/utils/logger.py:
--------------------------------------------------------------------------------
import logging

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get root logger.

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to logging.INFO.

    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)

    return logger
--------------------------------------------------------------------------------
/mmdet/utils/profiling.py:
--------------------------------------------------------------------------------
import contextlib
import sys
import time

import torch

if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU.

        Useful as a temporary context manager to find sweet spots of code
        suitable for async implementation.
        """
        if (not enabled) or not torch.cuda.is_available():
            yield
            return
        stream = stream if stream else torch.cuda.current_stream()
        end_stream = end_stream if end_stream else stream
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        try:
            cpu_start = time.monotonic()
            yield
        finally:
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            end.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000
            gpu_time = start.elapsed_time(end)
            msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
            msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
            print(msg, end_stream)
--------------------------------------------------------------------------------
/mmdet/version.py:
--------------------------------------------------------------------------------
# Copyright (c) Open-MMLab. All rights reserved.

__version__ = '2.6.0'
short_version = __version__


def parse_version_info(version_str):
    """Parse a version string, e.g. '2.6.0rc1' -> (2, 6, 0, 'rc1')."""
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            version_info.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
    return tuple(version_info)


version_info = parse_version_info(__version__)
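# Version tuples compare element-wise, so downstream checks can be written
# as, e.g., `version_info >= (2, 6)`; string entries like 'rc1' only appear
# for release candidates such as '2.6.0rc1' -> (2, 6, 0, 'rc1').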
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
addopts = --xdoctest --xdoctest-style=auto
norecursedirs = .git ignore build __pycache__ data docker docs .eggs

filterwarnings= default
    ignore:.*No cfgstr given in Cacher constructor or call.*:Warning
    ignore:.*Define the __nice__ method for.*:Warning
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
-r requirements/build.txt
-r requirements/optional.txt
-r requirements/runtime.txt
-r requirements/tests.txt
--------------------------------------------------------------------------------
/requirements/build.txt:
--------------------------------------------------------------------------------
# These must be installed before building mmdetection
cython
numpy
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
recommonmark
sphinx
sphinx_markdown_tables
sphinx_rtd_theme
--------------------------------------------------------------------------------
/requirements/optional.txt:
--------------------------------------------------------------------------------
albumentations>=0.3.2
cityscapesscripts
imagecorruptions
mmlvis
--------------------------------------------------------------------------------
/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
mmcv
torch
torchvision
--------------------------------------------------------------------------------
/requirements/runtime.txt:
--------------------------------------------------------------------------------
matplotlib
mmpycocotools
numpy
six
terminaltables
--------------------------------------------------------------------------------
/requirements/tests.txt:
--------------------------------------------------------------------------------
asynctest
codecov
flake8
interrogate
isort==4.3.21
# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
kwarray
pytest
ubelt
xdoctest>=0.10.0
yapf
--------------------------------------------------------------------------------
/resources/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/resources/architecture.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[isort]
line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet
known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

[yapf]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
--------------------------------------------------------------------------------
/tests/async_benchmark.py:
--------------------------------------------------------------------------------
import asyncio
import os
import shutil
import urllib.request

import mmcv
import torch

from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector)
from mmdet.utils.contextmanagers import concurrent
from mmdet.utils.profiling import profile_time


async def main():
    """Benchmark between async and synchronous inference interfaces.

    Sample runs for 20 demo images on K80 GPU, model - mask_rcnn_r50_fpn_1x:

    async       sync

    7981.79 ms  9660.82 ms
    8074.52 ms  9660.94 ms
    7976.44 ms  9406.83 ms

    Async variant takes about 0.83-0.85 of the time of the synchronous
    interface.
28 | """ 29 | project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 30 | 31 | config_file = os.path.join(project_dir, 32 | 'configs/mask_rcnn_r50_fpn_1x_coco.py') 33 | checkpoint_file = os.path.join( 34 | project_dir, 'checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth') 35 | 36 | if not os.path.exists(checkpoint_file): 37 | url = ('https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection' 38 | '/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth') 39 | print(f'Downloading {url} ...') 40 | local_filename, _ = urllib.request.urlretrieve(url) 41 | os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True) 42 | shutil.move(local_filename, checkpoint_file) 43 | print(f'Saved as {checkpoint_file}') 44 | else: 45 | print(f'Using existing checkpoint {checkpoint_file}') 46 | 47 | device = 'cuda:0' 48 | model = init_detector( 49 | config_file, checkpoint=checkpoint_file, device=device) 50 | 51 | # queue is used for concurrent inference of multiple images 52 | streamqueue = asyncio.Queue() 53 | # queue size defines concurrency level 54 | streamqueue_size = 4 55 | 56 | for _ in range(streamqueue_size): 57 | streamqueue.put_nowait(torch.cuda.Stream(device=device)) 58 | 59 | # test a single image and show the results 60 | img = mmcv.imread(os.path.join(project_dir, 'demo/demo.jpg')) 61 | 62 | # warmup 63 | await async_inference_detector(model, img) 64 | 65 | async def detect(img): 66 | async with concurrent(streamqueue): 67 | return await async_inference_detector(model, img) 68 | 69 | num_of_images = 20 70 | with profile_time('benchmark', 'async'): 71 | tasks = [ 72 | asyncio.create_task(detect(img)) for _ in range(num_of_images) 73 | ] 74 | async_results = await asyncio.gather(*tasks) 75 | 76 | with torch.cuda.stream(torch.cuda.default_stream()): 77 | with profile_time('benchmark', 'sync'): 78 | sync_results = [ 79 | inference_detector(model, img) for _ in range(num_of_images) 80 | ] 81 | 82 | result_dir = os.path.join(project_dir, 'demo') 83 | show_result( 84 | img, 85 | async_results[0], 86 | model.CLASSES, 87 | score_thr=0.5, 88 | show=False, 89 | out_file=os.path.join(result_dir, 'result_async.jpg')) 90 | show_result( 91 | img, 92 | sync_results[0], 93 | model.CLASSES, 94 | score_thr=0.5, 95 | show=False, 96 | out_file=os.path.join(result_dir, 'result_sync.jpg')) 97 | 98 | 99 | if __name__ == '__main__': 100 | asyncio.run(main()) 101 | -------------------------------------------------------------------------------- /tests/data/coco_sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "images": [ 3 | { 4 | "file_name": "fake1.jpg", 5 | "height": 800, 6 | "width": 800, 7 | "id": 0 8 | }, 9 | { 10 | "file_name": "fake2.jpg", 11 | "height": 800, 12 | "width": 800, 13 | "id": 1 14 | }, 15 | { 16 | "file_name": "fake3.jpg", 17 | "height": 800, 18 | "width": 800, 19 | "id": 2 20 | } 21 | ], 22 | "annotations": [ 23 | { 24 | "bbox": [ 25 | 0, 26 | 0, 27 | 20, 28 | 20 29 | ], 30 | "area": 400.00, 31 | "score": 1.0, 32 | "category_id": 1, 33 | "id": 1, 34 | "image_id": 0 35 | }, 36 | { 37 | "bbox": [ 38 | 0, 39 | 0, 40 | 20, 41 | 20 42 | ], 43 | "area": 400.00, 44 | "score": 1.0, 45 | "category_id": 2, 46 | "id": 2, 47 | "image_id": 0 48 | }, 49 | { 50 | "bbox": [ 51 | 0, 52 | 0, 53 | 20, 54 | 20 55 | ], 56 | "area": 400.00, 57 | "score": 1.0, 58 | "category_id": 1, 59 | "id": 3, 60 | "image_id": 1 61 | } 62 | ], 63 | "categories": [ 64 | { 65 | "id": 1, 66 | "name": "bus", 67 | "supercategory": "none" 68 | }, 69 | { 70 | 
"id": 2, 71 | "name": "car", 72 | "supercategory": "none" 73 | } 74 | ], 75 | "licenses": [], 76 | "info": null 77 | } 78 | -------------------------------------------------------------------------------- /tests/data/color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/tests/data/color.jpg -------------------------------------------------------------------------------- /tests/data/gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/tests/data/gray.jpg -------------------------------------------------------------------------------- /tests/test_async.py: -------------------------------------------------------------------------------- 1 | """Tests for async interface.""" 2 | 3 | import asyncio 4 | import os 5 | import sys 6 | 7 | import asynctest 8 | import mmcv 9 | import torch 10 | 11 | from mmdet.apis import async_inference_detector, init_detector 12 | 13 | if sys.version_info >= (3, 7): 14 | from mmdet.utils.contextmanagers import concurrent 15 | 16 | 17 | class AsyncTestCase(asynctest.TestCase): 18 | use_default_loop = False 19 | forbid_get_event_loop = True 20 | 21 | TEST_TIMEOUT = int(os.getenv('ASYNCIO_TEST_TIMEOUT', '30')) 22 | 23 | def _run_test_method(self, method): 24 | result = method() 25 | if asyncio.iscoroutine(result): 26 | self.loop.run_until_complete( 27 | asyncio.wait_for(result, timeout=self.TEST_TIMEOUT)) 28 | 29 | 30 | class MaskRCNNDetector: 31 | 32 | def __init__(self, 33 | model_config, 34 | checkpoint=None, 35 | streamqueue_size=3, 36 | device='cuda:0'): 37 | 38 | self.streamqueue_size = streamqueue_size 39 | self.device = device 40 | # build the model and load checkpoint 41 | self.model = init_detector( 42 | model_config, checkpoint=None, device=self.device) 43 | self.streamqueue = None 44 | 45 | async def init(self): 46 | self.streamqueue = asyncio.Queue() 47 | for _ in range(self.streamqueue_size): 48 | stream = torch.cuda.Stream(device=self.device) 49 | self.streamqueue.put_nowait(stream) 50 | 51 | if sys.version_info >= (3, 7): 52 | 53 | async def apredict(self, img): 54 | if isinstance(img, str): 55 | img = mmcv.imread(img) 56 | async with concurrent(self.streamqueue): 57 | result = await async_inference_detector(self.model, img) 58 | return result 59 | 60 | 61 | class AsyncInferenceTestCase(AsyncTestCase): 62 | 63 | if sys.version_info >= (3, 7): 64 | 65 | async def test_simple_inference(self): 66 | if not torch.cuda.is_available(): 67 | import pytest 68 | 69 | pytest.skip('test requires GPU and torch+cuda') 70 | 71 | ori_grad_enabled = torch.is_grad_enabled() 72 | root_dir = os.path.dirname(os.path.dirname(__name__)) 73 | model_config = os.path.join( 74 | root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py') 75 | detector = MaskRCNNDetector(model_config) 76 | await detector.init() 77 | img_path = os.path.join(root_dir, 'demo/demo.jpg') 78 | bboxes, _ = await detector.apredict(img_path) 79 | self.assertTrue(bboxes) 80 | # asy inference detector will hack grad_enabled, 81 | # so restore here to avoid it to influence other tests 82 | torch.set_grad_enabled(ori_grad_enabled) 83 | -------------------------------------------------------------------------------- /tests/test_coder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from 
--------------------------------------------------------------------------------
/tests/test_coder.py:
--------------------------------------------------------------------------------
import torch

from mmdet.core.bbox.coder import YOLOBBoxCoder


def test_yolo_bbox_coder():
    coder = YOLOBBoxCoder()
    bboxes = torch.Tensor([[-42., -29., 74., 61.], [-10., -29., 106., 61.],
                           [22., -29., 138., 61.], [54., -29., 170., 61.]])
    pred_bboxes = torch.Tensor([[0.4709, 0.6152, 0.1690, -0.4056],
                                [0.5399, 0.6653, 0.1162, -0.4162],
                                [0.4654, 0.6618, 0.1548, -0.4301],
                                [0.4786, 0.6197, 0.1896, -0.4479]])
    grid_size = 32
    expected_decode_bboxes = torch.Tensor(
        [[-53.6102, -10.3096, 83.7478, 49.6824],
         [-15.8700, -8.3901, 114.4236, 50.9693],
         [11.1822, -8.0924, 146.6034, 50.4476],
         [41.2068, -8.9232, 181.4236, 48.5840]])
    assert expected_decode_bboxes.allclose(
        coder.decode(bboxes, pred_bboxes, grid_size))
--------------------------------------------------------------------------------
/tests/test_data/test_formatting.py:
--------------------------------------------------------------------------------
import os.path as osp

from mmcv.utils import build_from_cfg

from mmdet.datasets.builder import PIPELINES


def test_default_format_bundle():
    results = dict(
        img_prefix=osp.join(osp.dirname(__file__), '../data'),
        img_info=dict(filename='color.jpg'))
    load = dict(type='LoadImageFromFile')
    load = build_from_cfg(load, PIPELINES)
    bundle = dict(type='DefaultFormatBundle')
    bundle = build_from_cfg(bundle, PIPELINES)
    results = load(results)
    assert 'pad_shape' not in results
    assert 'scale_factor' not in results
    assert 'img_norm_cfg' not in results
    results = bundle(results)
    assert 'pad_shape' in results
    assert 'scale_factor' in results
    assert 'img_norm_cfg' in results
--------------------------------------------------------------------------------
/tests/test_data/test_utils.py:
--------------------------------------------------------------------------------
import pytest

from mmdet.datasets import replace_ImageToTensor


def test_replace_ImageToTensor():
    # with MultiScaleFlipAug
    pipelines = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(1333, 800),
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize'),
                dict(type='Pad', size_divisor=32),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    expected_pipelines = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(1333, 800),
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize'),
                dict(type='Pad', size_divisor=32),
                dict(type='DefaultFormatBundle'),
                dict(type='Collect', keys=['img']),
            ])
    ]
    with pytest.warns(UserWarning):
        assert expected_pipelines == replace_ImageToTensor(pipelines)

    # without MultiScaleFlipAug
    pipelines = [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', keep_ratio=True),
        dict(type='RandomFlip'),
        dict(type='Normalize'),
        dict(type='Pad', size_divisor=32),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img']),
    ]
    expected_pipelines = [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', keep_ratio=True),
        dict(type='RandomFlip'),
        dict(type='Normalize'),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img']),
    ]
    with pytest.warns(UserWarning):
        assert expected_pipelines == replace_ImageToTensor(pipelines)
--------------------------------------------------------------------------------
/tests/test_version.py:
--------------------------------------------------------------------------------
from mmdet import digit_version


def test_version_check():
    assert digit_version('1.0.5') > digit_version('1.0.5rc0')
    assert digit_version('1.0.5') > digit_version('1.0.4rc0')
    assert digit_version('1.0.5') > digit_version('1.0rc0')
    assert digit_version('1.0.0') > digit_version('0.6.2')
    assert digit_version('1.0.0') > digit_version('0.2.16')
    assert digit_version('1.0.5rc0') > digit_version('1.0.0rc0')
    assert digit_version('1.0.0rc1') > digit_version('1.0.0rc0')
    assert digit_version('1.0.0rc2') > digit_version('1.0.0rc0')
    assert digit_version('1.0.0rc2') > digit_version('1.0.0rc1')
    assert digit_version('1.0.1rc1') > digit_version('1.0.0rc1')
    assert digit_version('1.0.0') > digit_version('1.0.0rc1')
--------------------------------------------------------------------------------
/tools/benchmark.py:
--------------------------------------------------------------------------------
import argparse
import time

import torch
from mmcv import Config
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint, wrap_fp16_model

from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # import modules from string list.
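    # e.g. a config could declare (hypothetical module name):
    #   custom_imports = dict(
    #       imports=['mmdet.models.my_custom_head'],
    #       allow_failed_imports=False)
    # so that modules registered outside the default packages are imported.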
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' with 'DefaultFormatBundle'
        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with 2000 images and take the average
    for i, data in enumerate(data_loader):

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done image [{i + 1:<3}/ 2000], fps: {fps:.1f} img / s')

        if (i + 1) == 2000:
            # elapsed was already accumulated above, so only report here
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} img / s')
            break


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/browse_dataset.py:
--------------------------------------------------------------------------------
import argparse
import os
from pathlib import Path

import mmcv
from mmcv import Config

from mmdet.datasets.builder import build_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='pipeline steps to skip, as they are useless for visualization')
    parser.add_argument(
        '--output-dir',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    args = parser.parse_args()
    return args


def retrieve_data_cfg(config_path, skip_type):
    cfg = Config.fromfile(config_path)
    train_data_cfg = cfg.data.train
    train_data_cfg['pipeline'] = [
        x for x in train_data_cfg.pipeline if x['type'] not in skip_type
    ]

    return cfg


def main():
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type)

    dataset = build_dataset(cfg.data.train)

    progress_bar = mmcv.ProgressBar(len(dataset))
    for item in dataset:
        filename = os.path.join(args.output_dir,
                                Path(item['filename']).name
                                ) if args.output_dir is not None else None
        mmcv.imshow_det_bboxes(
            item['img'],
            item['gt_bboxes'],
            item['gt_labels'],
            class_names=dataset.CLASSES,
            show=not args.not_show,
            out_file=filename,
            wait_time=args.show_interval)
        progress_bar.update()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
/opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
/opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
--------------------------------------------------------------------------------
/tools/eval_metric.py:
--------------------------------------------------------------------------------
import argparse

import mmcv
from mmcv import Config, DictAction

from mmdet.datasets import build_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate metric of the '
                                     'results saved in pkl format')
    parser.add_argument('config', help='Config of the model')
    parser.add_argument('pkl_results', help='Results in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='Evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    assert args.eval or args.format_only, (
        'Please specify at least one operation (eval/format the results) with '
        'the argument "--eval", "--format-only"')
    if args.eval and args.format_only:
        raise ValueError('--eval and --format-only cannot both be specified')

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.data.test.test_mode = True

    dataset = build_dataset(cfg.data.test)
    outputs = mmcv.load(args.pkl_results)

    kwargs = {} if args.eval_options is None else args.eval_options
    if args.format_only:
        dataset.format_results(outputs, **kwargs)
    if args.eval:
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        print(dataset.evaluate(outputs, **eval_kwargs))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/get_flops.py:
--------------------------------------------------------------------------------
import argparse

import torch
from mmcv import Config

from mmdet.models import build_detector

try:
    from mmcv.cnn import get_model_complexity_info
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    parser = argparse.ArgumentParser(description='Get the FLOPs of a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1280, 800],
        help='input image size')
    args = parser.parse_args()
    return args


def main():

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/prepare.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# download vgg16 imagenet pre-trained weights
mkdir -p pretrain
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=10Vh2qFmGucO-9DZ3eY3HAvcAmtPFcFg2' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=10Vh2qFmGucO-9DZ3eY3HAvcAmtPFcFg2" -O pretrain/vgg16.pth && rm -rf /tmp/cookies.txt

# download pascal voc selective search region proposals
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-1EJ3Mm7KoXwaurYx4zpPcerkKbIAqU-' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-1EJ3Mm7KoXwaurYx4zpPcerkKbIAqU-" -O data/voc_selective_search.zip && rm -rf /tmp/cookies.txt
# unpack in a subshell so the relative paths below keep resolving from the repo root
(cd data && unzip voc_selective_search.zip && mv voc_selective_search/voc_2007* VOCdevkit/VOC2007/ && mv voc_selective_search/voc_2012* VOCdevkit/VOC2012/ && rm voc_selective_search.zip && rm -rf voc_selective_search)

# download pascal voc super pixels
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1tItyPAUz16iXOyIHpVfAmWbKfwQzkzoU' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1tItyPAUz16iXOyIHpVfAmWbKfwQzkzoU" -O data/voc_super_pixels.zip && rm -rf /tmp/cookies.txt
(cd data && unzip voc_super_pixels.zip && mv voc_super_pixels/superpixel2007 VOCdevkit/VOC2007/SuperPixels && mv voc_super_pixels/superpixel2012 VOCdevkit/VOC2012/SuperPixels && rm voc_super_pixels.zip && rm -rf voc_super_pixels)
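# After this script the expected layout is, roughly:
#   pretrain/vgg16.pth
#   data/VOCdevkit/VOC2007/{voc_2007_*, SuperPixels/}
#   data/VOCdevkit/VOC2012/{voc_2012_*, SuperPixels/}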
--------------------------------------------------------------------------------
/tools/print_config.py:
--------------------------------------------------------------------------------
import argparse

from mmcv import Config, DictAction


def parse_args():
    parser = argparse.ArgumentParser(description='Print the whole config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file')
    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    print(f'Config:\n{cfg.pretty_text}')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/publish_model.py:
--------------------------------------------------------------------------------
import argparse
import subprocess

import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    args = parser.parse_args()
    return args


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = out_file_name + f'-{sha[:8]}.pth'
    subprocess.Popen(['mv', out_file, final_file])


def main():
    args = parse_args()
    process_checkpoint(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
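# Typical usage (hypothetical paths):
#   python tools/publish_model.py work_dirs/wsod2/latest.pth wsod2_vgg16.pth
# which strips the optimizer state and renames the output to
# wsod2_vgg16-{sha256[:8]}.pth.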
--------------------------------------------------------------------------------
/tools/regnet2mmdet.py:
--------------------------------------------------------------------------------
import argparse
from collections import OrderedDict

import torch


def convert_stem(model_key, model_weight, state_dict, converted_names):
    new_key = model_key.replace('stem.conv', 'conv1')
    new_key = new_key.replace('stem.bn', 'bn1')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_head(model_key, model_weight, state_dict, converted_names):
    new_key = model_key.replace('head.fc', 'fc')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_reslayer(model_key, model_weight, state_dict, converted_names):
    split_keys = model_key.split('.')
    layer, block, module = split_keys[:3]
    block_id = int(block[1:])
    layer_name = f'layer{int(layer[1:])}'
    block_name = f'{block_id - 1}'

    if block_id == 1 and module == 'bn':
        new_key = f'{layer_name}.{block_name}.downsample.1.{split_keys[-1]}'
    elif block_id == 1 and module == 'proj':
        new_key = f'{layer_name}.{block_name}.downsample.0.{split_keys[-1]}'
    elif module == 'f':
        if split_keys[3] == 'a_bn':
            module_name = 'bn1'
        elif split_keys[3] == 'b_bn':
            module_name = 'bn2'
        elif split_keys[3] == 'c_bn':
            module_name = 'bn3'
        elif split_keys[3] == 'a':
            module_name = 'conv1'
        elif split_keys[3] == 'b':
            module_name = 'conv2'
        elif split_keys[3] == 'c':
            module_name = 'conv3'
        new_key = f'{layer_name}.{block_name}.{module_name}.{split_keys[-1]}'
    else:
        raise ValueError(f'Unsupported conversion of key {model_key}')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)


def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style."""
    # load the pycls model
    regnet_model = torch.load(src)
    blobs = regnet_model['model_state']
    # convert to pytorch style
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)

    # check if all layers are converted
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src pycls model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/slurm_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------
/tools/slurm_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:5}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------