├── .dev_scripts
│ ├── batch_test.py
│ ├── batch_test.sh
│ ├── benchmark_filter.py
│ ├── gather_models.py
│ └── linter.sh
├── .github
│ ├── CODE_OF_CONDUCT.md
│ ├── CONTRIBUTING.md
│ ├── ISSUE_TEMPLATE
│ │ ├── config.yml
│ │ ├── error-report.md
│ │ ├── feature_request.md
│ │ ├── general_questions.md
│ │ └── reimplementation_questions.md
│ └── workflows
│ │ ├── build.yml
│ │ └── deploy.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── LICENSE
├── README.md
├── configs
│ └── wsod
│ │ ├── base.py
│ │ ├── oicr_bbox_vgg16.py
│ │ ├── oicr_vgg16.py
│ │ ├── wsddn_vgg16.py
│ │ └── wsod2_vgg16.py
├── demo
│ ├── MMDet_Tutorial.ipynb
│ ├── demo.jpg
│ ├── image_demo.py
│ ├── inference_demo.ipynb
│ └── webcam_demo.py
├── docker
│ └── Dockerfile
├── docs
│ ├── 1_exist_data_model.md
│ ├── 2_new_data_model.md
│ ├── 3_exist_data_new_model.md
│ ├── Makefile
│ ├── api.rst
│ ├── changelog.md
│ ├── compatibility.md
│ ├── conf.py
│ ├── conventions.md
│ ├── faq.md
│ ├── get_started.md
│ ├── index.rst
│ ├── make.bat
│ ├── model_zoo.md
│ ├── projects.md
│ ├── robustness_benchmarking.md
│ ├── stat.py
│ ├── tutorials
│ │ ├── config.md
│ │ ├── customize_dataset.md
│ │ ├── customize_losses.md
│ │ ├── customize_models.md
│ │ ├── customize_runtime.md
│ │ ├── data_pipeline.md
│ │ ├── finetune.md
│ │ └── index.rst
│ └── useful_tools.md
├── mmdet
│ ├── __init__.py
│ ├── apis
│ │ ├── __init__.py
│ │ ├── inference.py
│ │ ├── test.py
│ │ └── train.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── anchor
│ │ │ ├── __init__.py
│ │ │ ├── anchor_generator.py
│ │ │ ├── builder.py
│ │ │ ├── point_generator.py
│ │ │ └── utils.py
│ │ ├── bbox
│ │ │ ├── __init__.py
│ │ │ ├── assigners
│ │ │ │ ├── __init__.py
│ │ │ │ ├── approx_max_iou_assigner.py
│ │ │ │ ├── assign_result.py
│ │ │ │ ├── atss_assigner.py
│ │ │ │ ├── base_assigner.py
│ │ │ │ ├── center_region_assigner.py
│ │ │ │ ├── grid_assigner.py
│ │ │ │ ├── max_iou_assigner.py
│ │ │ │ └── point_assigner.py
│ │ │ ├── builder.py
│ │ │ ├── coder
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_bbox_coder.py
│ │ │ │ ├── bucketing_bbox_coder.py
│ │ │ │ ├── delta_xywh_bbox_coder.py
│ │ │ │ ├── legacy_delta_xywh_bbox_coder.py
│ │ │ │ ├── pseudo_bbox_coder.py
│ │ │ │ ├── tblr_bbox_coder.py
│ │ │ │ └── yolo_bbox_coder.py
│ │ │ ├── demodata.py
│ │ │ ├── iou_calculators
│ │ │ │ ├── __init__.py
│ │ │ │ ├── builder.py
│ │ │ │ └── iou2d_calculator.py
│ │ │ ├── samplers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_sampler.py
│ │ │ │ ├── combined_sampler.py
│ │ │ │ ├── instance_balanced_pos_sampler.py
│ │ │ │ ├── iou_balanced_neg_sampler.py
│ │ │ │ ├── ohem_sampler.py
│ │ │ │ ├── pseudo_sampler.py
│ │ │ │ ├── random_sampler.py
│ │ │ │ ├── sampling_result.py
│ │ │ │ └── score_hlr_sampler.py
│ │ │ └── transforms.py
│ │ ├── evaluation
│ │ │ ├── __init__.py
│ │ │ ├── bbox_overlaps.py
│ │ │ ├── class_names.py
│ │ │ ├── eval_hooks.py
│ │ │ ├── mean_ap.py
│ │ │ └── recall.py
│ │ ├── export
│ │ │ ├── __init__.py
│ │ │ └── pytorch2onnx.py
│ │ ├── fp16
│ │ │ ├── __init__.py
│ │ │ └── deprecated_fp16_utils.py
│ │ ├── mask
│ │ │ ├── __init__.py
│ │ │ ├── mask_target.py
│ │ │ ├── structures.py
│ │ │ └── utils.py
│ │ ├── post_processing
│ │ │ ├── __init__.py
│ │ │ ├── bbox_nms.py
│ │ │ └── merge_augs.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── dist_utils.py
│ │ │ └── misc.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── builder.py
│ │ ├── cityscapes.py
│ │ ├── coco.py
│ │ ├── custom.py
│ │ ├── dataset_wrappers.py
│ │ ├── deepfashion.py
│ │ ├── lvis.py
│ │ ├── pipelines
│ │ │ ├── __init__.py
│ │ │ ├── auto_augment.py
│ │ │ ├── compose.py
│ │ │ ├── formating.py
│ │ │ ├── instaboost.py
│ │ │ ├── loading.py
│ │ │ ├── test_time_aug.py
│ │ │ └── transforms.py
│ │ ├── samplers
│ │ │ ├── __init__.py
│ │ │ ├── distributed_sampler.py
│ │ │ └── group_sampler.py
│ │ ├── utils.py
│ │ ├── voc.py
│ │ ├── voc_ss.py
│ │ ├── wider_face.py
│ │ └── xml_style.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── backbones
│ │ │ ├── __init__.py
│ │ │ ├── darknet.py
│ │ │ ├── detectors_resnet.py
│ │ │ ├── detectors_resnext.py
│ │ │ ├── hourglass.py
│ │ │ ├── hrnet.py
│ │ │ ├── regnet.py
│ │ │ ├── res2net.py
│ │ │ ├── resnest.py
│ │ │ ├── resnet.py
│ │ │ ├── resnext.py
│ │ │ ├── ssd_vgg.py
│ │ │ └── vgg.py
│ │ ├── builder.py
│ │ ├── dense_heads
│ │ │ ├── __init__.py
│ │ │ ├── anchor_free_head.py
│ │ │ ├── anchor_head.py
│ │ │ ├── atss_head.py
│ │ │ ├── base_dense_head.py
│ │ │ ├── centripetal_head.py
│ │ │ ├── corner_head.py
│ │ │ ├── dense_test_mixins.py
│ │ │ ├── fcos_head.py
│ │ │ ├── fovea_head.py
│ │ │ ├── free_anchor_retina_head.py
│ │ │ ├── fsaf_head.py
│ │ │ ├── ga_retina_head.py
│ │ │ ├── ga_rpn_head.py
│ │ │ ├── gfl_head.py
│ │ │ ├── guided_anchor_head.py
│ │ │ ├── nasfcos_head.py
│ │ │ ├── paa_head.py
│ │ │ ├── pisa_retinanet_head.py
│ │ │ ├── pisa_ssd_head.py
│ │ │ ├── reppoints_head.py
│ │ │ ├── retina_head.py
│ │ │ ├── retina_sepbn_head.py
│ │ │ ├── rpn_head.py
│ │ │ ├── rpn_test_mixin.py
│ │ │ ├── sabl_retina_head.py
│ │ │ ├── ssd_head.py
│ │ │ ├── vfnet_head.py
│ │ │ ├── yolact_head.py
│ │ │ └── yolo_head.py
│ │ ├── detectors
│ │ │ ├── __init__.py
│ │ │ ├── atss.py
│ │ │ ├── base.py
│ │ │ ├── cascade_rcnn.py
│ │ │ ├── cornernet.py
│ │ │ ├── fast_rcnn.py
│ │ │ ├── faster_rcnn.py
│ │ │ ├── fcos.py
│ │ │ ├── fovea.py
│ │ │ ├── fsaf.py
│ │ │ ├── gfl.py
│ │ │ ├── grid_rcnn.py
│ │ │ ├── htc.py
│ │ │ ├── mask_rcnn.py
│ │ │ ├── mask_scoring_rcnn.py
│ │ │ ├── nasfcos.py
│ │ │ ├── paa.py
│ │ │ ├── point_rend.py
│ │ │ ├── reppoints_detector.py
│ │ │ ├── retinanet.py
│ │ │ ├── rpn.py
│ │ │ ├── single_stage.py
│ │ │ ├── two_stage.py
│ │ │ ├── vfnet.py
│ │ │ ├── weak_rcnn.py
│ │ │ ├── yolact.py
│ │ │ └── yolo.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── accuracy.py
│ │ │ ├── ae_loss.py
│ │ │ ├── balanced_l1_loss.py
│ │ │ ├── cross_entropy_loss.py
│ │ │ ├── focal_loss.py
│ │ │ ├── gaussian_focal_loss.py
│ │ │ ├── gfocal_loss.py
│ │ │ ├── ghm_loss.py
│ │ │ ├── iou_loss.py
│ │ │ ├── mse_loss.py
│ │ │ ├── pisa_loss.py
│ │ │ ├── smooth_l1_loss.py
│ │ │ ├── utils.py
│ │ │ └── varifocal_loss.py
│ │ ├── necks
│ │ │ ├── __init__.py
│ │ │ ├── bfp.py
│ │ │ ├── channel_mapper.py
│ │ │ ├── fpn.py
│ │ │ ├── fpn_carafe.py
│ │ │ ├── hrfpn.py
│ │ │ ├── nas_fpn.py
│ │ │ ├── nasfcos_fpn.py
│ │ │ ├── pafpn.py
│ │ │ ├── rfp.py
│ │ │ └── yolo_neck.py
│ │ ├── roi_heads
│ │ │ ├── __init__.py
│ │ │ ├── base_roi_head.py
│ │ │ ├── bbox_heads
│ │ │ │ ├── __init__.py
│ │ │ │ ├── bbox_head.py
│ │ │ │ ├── convfc_bbox_head.py
│ │ │ │ ├── double_bbox_head.py
│ │ │ │ ├── oicr_head.py
│ │ │ │ ├── sabl_head.py
│ │ │ │ └── wsddn_head.py
│ │ │ ├── cascade_roi_head.py
│ │ │ ├── double_roi_head.py
│ │ │ ├── dynamic_roi_head.py
│ │ │ ├── grid_roi_head.py
│ │ │ ├── htc_roi_head.py
│ │ │ ├── mask_heads
│ │ │ │ ├── __init__.py
│ │ │ │ ├── coarse_mask_head.py
│ │ │ │ ├── fcn_mask_head.py
│ │ │ │ ├── fused_semantic_head.py
│ │ │ │ ├── grid_head.py
│ │ │ │ ├── htc_mask_head.py
│ │ │ │ ├── mask_point_head.py
│ │ │ │ └── maskiou_head.py
│ │ │ ├── mask_scoring_roi_head.py
│ │ │ ├── oicr_roi_head.py
│ │ │ ├── pisa_roi_head.py
│ │ │ ├── point_rend_roi_head.py
│ │ │ ├── roi_extractors
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_roi_extractor.py
│ │ │ │ ├── generic_roi_extractor.py
│ │ │ │ └── single_level_roi_extractor.py
│ │ │ ├── shared_heads
│ │ │ │ ├── __init__.py
│ │ │ │ └── res_layer.py
│ │ │ ├── standard_roi_head.py
│ │ │ ├── test_mixins.py
│ │ │ ├── wsddn_roi_head.py
│ │ │ └── wsod2_roi_head.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── gaussian_target.py
│ │ │ └── res_layer.py
│ ├── ops
│ │ └── __init__.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── collect_env.py
│ │ ├── contextmanagers.py
│ │ ├── logger.py
│ │ ├── profiling.py
│ │ └── util_mixins.py
│ └── version.py
├── pytest.ini
├── requirements.txt
├── requirements
│ ├── build.txt
│ ├── docs.txt
│ ├── optional.txt
│ ├── readthedocs.txt
│ ├── runtime.txt
│ └── tests.txt
├── resources
│ └── architecture.png
├── setup.cfg
├── setup.py
├── tests
│ ├── async_benchmark.py
│ ├── data
│ │ ├── coco_sample.json
│ │ ├── color.jpg
│ │ └── gray.jpg
│ ├── test_anchor.py
│ ├── test_assigner.py
│ ├── test_async.py
│ ├── test_coder.py
│ ├── test_config.py
│ ├── test_data
│ │ ├── test_dataset.py
│ │ ├── test_formatting.py
│ │ ├── test_img_augment.py
│ │ ├── test_loading.py
│ │ ├── test_models_aug_test.py
│ │ ├── test_rotate.py
│ │ ├── test_sampler.py
│ │ ├── test_shear.py
│ │ ├── test_transform.py
│ │ ├── test_translate.py
│ │ └── test_utils.py
│ ├── test_eval_hook.py
│ ├── test_fp16.py
│ ├── test_iou2d_calculator.py
│ ├── test_masks.py
│ ├── test_models
│ │ ├── test_backbones.py
│ │ ├── test_forward.py
│ │ ├── test_heads.py
│ │ ├── test_losses.py
│ │ ├── test_necks.py
│ │ ├── test_pisa_heads.py
│ │ └── test_roi_extractor.py
│ └── test_version.py
└── tools
  ├── analyze_logs.py
  ├── benchmark.py
  ├── browse_dataset.py
  ├── coco_error_analysis.py
  ├── convert_datasets
  │ ├── cityscapes.py
  │ └── pascal_voc.py
  ├── detectron2pytorch.py
  ├── dist_test.sh
  ├── dist_train.sh
  ├── eval_metric.py
  ├── get_flops.py
  ├── prepare.sh
  ├── print_config.py
  ├── publish_model.py
  ├── pytorch2onnx.py
  ├── regnet2mmdet.py
  ├── robustness_eval.py
  ├── slurm_test.sh
  ├── slurm_train.sh
  ├── test.py
  ├── test_robustness.py
  ├── train.py
  └── upgrade_model_version.py
/.dev_scripts/batch_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Keep resubmitting batch_test.py to slurm until the last line of
3 | # $json_out reads "finished".
4 | export PYTHONPATH=${PWD}
5 | 
6 | partition=$1
7 | model_dir=$2
8 | json_out=$3
9 | job_name=batch_test
10 | gpus=8
11 | gpu_per_node=8
12 | 
13 | touch $json_out
14 | lastLine=$(tail -n 1 $json_out)
15 | while [ "$lastLine" != "finished" ]
16 | do
17 |     srun -p ${partition} --gres=gpu:${gpu_per_node} -n${gpus} --ntasks-per-node=${gpu_per_node} \
18 |         --job-name=${job_name} --kill-on-bad-exit=1 \
19 |         python .dev_scripts/batch_test.py $model_dir $json_out --launcher='slurm'
20 |     lastLine=$(tail -n 1 $json_out)
21 |     echo $lastLine
22 | done
23 | 
--------------------------------------------------------------------------------
/.dev_scripts/linter.sh:
--------------------------------------------------------------------------------
1 | yapf -r -i mmdet/ configs/ tests/ tools/
2 | isort -rc mmdet/ configs/ tests/ tools/
3 | flake8 .
4 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at chenkaidev@gmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to mmdetection
2 |
3 | All kinds of contributions are welcome, including but not limited to the following.
4 |
5 | - Fixes (typos, bugs)
6 | - New features and components
7 |
8 | ## Workflow
9 |
10 | 1. fork and pull the latest mmdetection
11 | 2. checkout a new branch (do not use master branch for PRs)
12 | 3. commit your changes
13 | 4. create a PR
14 |
15 | Note
16 | - If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
17 | - If you are the author of some papers and would like to include your method in mmdetection,
18 |   please let us know (open an issue or contact the maintainers). We would much appreciate your contribution.
19 | - For new features and new modules, unit tests are required to improve the code's robustness.
20 |
21 | ## Code style
22 |
23 | ### Python
24 | We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
25 |
26 | We use the following tools for linting and formatting:
27 | - [flake8](http://flake8.pycqa.org/en/latest/): linter
28 | - [yapf](https://github.com/google/yapf): formatter
29 | - [isort](https://github.com/timothycrosley/isort): sort imports
30 |
31 | Style configurations of yapf and isort can be found in [setup.cfg](../setup.cfg).
32 |
33 | We use a [pre-commit hook](https://pre-commit.com/) that runs `flake8`, `yapf` and `isort`, trims trailing whitespace,
34 | fixes end-of-file issues, and sorts `requirements.txt` automatically on every commit.
35 | The config for the pre-commit hook is stored in [.pre-commit-config](../.pre-commit-config.yaml).
36 |
37 | After you clone the repository, you will need to install and initialize the pre-commit hook.
38 |
39 | ```
40 | pip install -U pre-commit
41 | ```
42 |
43 | From the repository folder
44 | ```
45 | pre-commit install
46 | ```
47 |
48 | After this, the code linters and formatter will be enforced on every commit.
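49 | You can also run all of the configured hooks manually with `pre-commit run --all-files`.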
49 |
50 |
51 | >Before you create a PR, make sure that your code lints and is formatted by yapf.
52 |
53 | ### C++ and CUDA
54 | We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
55 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
3 | contact_links:
4 | - name: Common Issues
5 | url: https://mmdetection.readthedocs.io/en/latest/faq.html
6 | about: Check if your issue already has solutions
7 | - name: MMDetection Documentation
8 | url: https://mmdetection.readthedocs.io/en/latest/
9 | about: Check if your question is answered in docs
10 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/error-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Error report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | Thanks for your error report; we appreciate it a lot.
11 |
12 | **Checklist**
13 | 1. I have searched related issues but cannot get the expected help.
14 | 2. The bug has not been fixed in the latest version.
15 |
16 | **Describe the bug**
17 | A clear and concise description of what the bug is.
18 |
19 | **Reproduction**
20 | 1. What command or script did you run?
21 | ```
22 | A placeholder for the command.
23 | ```
24 | 2. Did you make any modifications on the code or config? Did you understand what you have modified?
25 | 3. What dataset did you use?
26 |
27 | **Environment**
28 |
29 | 1. Please run `python mmdet/utils/collect_env.py` to collect necessary environment information and paste it here.
30 | 2. You may add additional information that may be helpful for locating the problem, such as
31 | - How you installed PyTorch [e.g., pip, conda, source]
32 | - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
33 |
34 | **Error traceback**
35 | If applicable, paste the error traceback here.
36 | ```
37 | A placeholder for the traceback.
38 | ```
39 |
40 | **Bug fix**
41 | If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
42 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the feature**
11 |
12 | **Motivation**
13 | A clear and concise description of the motivation of the feature.
14 | Ex1. It is inconvenient when [....].
15 | Ex2. There is a recent paper [....], which is very helpful for [....].
16 |
17 | **Related resources**
18 | If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
19 |
20 | **Additional context**
21 | Add any other context or screenshots about the feature request here.
22 | If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
23 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/general_questions.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: General questions
3 | about: Ask general questions to get help
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/reimplementation_questions.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Reimplementation Questions
3 | about: Ask about questions during model reimplementation
4 | title: ''
5 | labels: 'reimplementation'
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Notice**
11 |
12 | There are several common situations in reimplementation issues, as listed below
13 | 1. Reimplement a model in the model zoo using the provided configs
14 | 2. Reimplement a model in the model zoo on another dataset (e.g., custom datasets)
15 | 3. Reimplement a custom model but all the components are implemented in MMDetection
16 | 4. Reimplement a custom model with new modules implemented by yourself
17 |
18 | There are several things to do for the different cases, as below.
19 | - For cases 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.
20 | - For cases 2 & 4, please understand that we are not able to help much here, because we usually do not know the full code and users should be responsible for the code they write.
21 | - One suggestion for cases 2 & 4 is to first check whether the bug lies in the self-implemented code or the original code. For example, users can first make sure that the same model runs well on supported datasets. If you still need help, please describe in the issue what you have done and what you obtained, follow the steps in the following sections, and be as clear as possible so that we can better help you.
22 |
23 | **Checklist**
24 | 1. I have searched related issues but cannot get the expected help.
25 | 2. The issue has not been fixed in the latest version.
26 |
27 | **Describe the issue**
28 |
29 | A clear and concise description of the problem you met and what you have done.
30 |
31 | **Reproduction**
32 | 1. What command or script did you run?
33 | ```
34 | A placeholder for the command.
35 | ```
36 | 2. What config did you run?
37 | ```
38 | A placeholder for the config.
39 | ```
40 | 3. Did you make any modifications on the code or config? Did you understand what you have modified?
41 | 4. What dataset did you use?
42 |
43 | **Environment**
44 |
45 | 1. Please run `python mmdet/utils/collect_env.py` to collect necessary environment information and paste it here.
46 | 2. You may add additional information that may be helpful for locating the problem, such as
47 | - How you installed PyTorch [e.g., pip, conda, source]
48 | - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
49 |
50 | **Results**
51 |
52 | If applicable, paste the related results here, e.g., what you expect and what you get.
53 | ```
54 | A placeholder for results comparison
55 | ```
56 |
57 | **Issue fix**
58 |
59 | If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
60 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
1 | name: deploy
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | runs-on: ubuntu-latest
8 | if: startsWith(github.event.ref, 'refs/tags')
9 | steps:
10 | - uses: actions/checkout@v2
11 | - name: Set up Python 3.7
12 | uses: actions/setup-python@v2
13 | with:
14 | python-version: 3.7
15 | - name: Install torch
16 | run: pip install torch
17 | - name: Install wheel
18 | run: pip install wheel
19 | - name: Build MMDetection
20 | run: python setup.py sdist bdist_wheel
21 | - name: Publish distribution to PyPI
22 | run: |
23 | pip install twine
24 | twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | data/
107 | data
108 | .vscode
109 | .idea
110 | .DS_Store
111 |
112 | # custom
113 | *.pkl
114 | *.pkl.json
115 | *.log.json
116 | work_dirs/
117 |
118 | # Pytorch
119 | *.pth
120 | *.py~
121 | *.sh~
122 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://gitlab.com/pycqa/flake8.git
3 | rev: 3.8.3
4 | hooks:
5 | - id: flake8
6 | - repo: https://github.com/asottile/seed-isort-config
7 | rev: v2.2.0
8 | hooks:
9 | - id: seed-isort-config
10 | - repo: https://github.com/timothycrosley/isort
11 | rev: 4.3.21
12 | hooks:
13 | - id: isort
14 | - repo: https://github.com/pre-commit/mirrors-yapf
15 | rev: v0.30.0
16 | hooks:
17 | - id: yapf
18 | - repo: https://github.com/pre-commit/pre-commit-hooks
19 | rev: v3.1.0
20 | hooks:
21 | - id: trailing-whitespace
22 | - id: check-yaml
23 | - id: end-of-file-fixer
24 | - id: requirements-txt-fixer
25 | - id: double-quote-string-fixer
26 | - id: check-merge-conflict
27 | - id: fix-encoding-pragma
28 | args: ["--remove"]
29 | - id: mixed-line-ending
30 | args: ["--fix=lf"]
31 | - repo: https://github.com/myint/docformatter
32 | rev: v1.3.1
33 | hooks:
34 | - id: docformatter
35 | args: ["--in-place", "--wrap-descriptions", "79"]
36 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | python:
4 | version: 3.7
5 | install:
6 | - requirements: requirements/docs.txt
7 | - requirements: requirements/readthedocs.txt
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WSOD^2: Learning Bottom-up and Top-down Objectness Distillation for Weakly-supervised Object Detection
2 | By Zhaoyang Zeng, Bei Liu, Jianlong Fu, Hongyang Chao, and Lei Zhang
3 |
4 | ### Introduction
5 | This repo is a toolkit for weakly supervised object detection based on [mmdetection](https://github.com/open-mmlab/mmdetection), including implementations of [WSDDN](https://arxiv.org/abs/1511.02853), [OICR](https://arxiv.org/abs/1704.00138) and [WSOD^2](https://arxiv.org/abs/1909.04972). The implementations differ slightly from the original papers in aspects including but not limited to:
6 | * optimizer
7 | * training epoch
8 | * learning rate
9 | * input resolution
10 | * pseudo GTs mining
11 | * loss weight assignment
12 |
13 | The baselines in this repo can easily achieve 48+ mAP on the Pascal VOC 2007 dataset. Some hyperparameters are still being tuned and should bring further performance gains.
14 |
15 | ### Architecture
16 | 
17 | ![WSOD2 architecture](resources/architecture.png)
18 | 
19 | 
20 | 
21 | ### Results
22 |
23 | | Method | VOC2007 test *mAP* | VOC2007 trainval *CorLoc* | VOC2012 test *mAP* | VOC2012 trainval *CorLoc* |
24 | |:-------|:------------------:|:-------------------------:|:------------------:|:-------------------------:|
25 | | WSOD2 | 53.6 | 71.4 | 47.2 | 71.9 |
26 | | WSOD2\* | 56.0 | 71.4 | 52.7 | 72.2 |
27 |
28 | \* denotes training on VOC 07+12 *trainval* splits
29 |
30 | ### Installation
31 |
32 | Please refer to [here](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md) for installation.
33 |
34 | ### Getting Started
35 |
36 | 1. Download the training, validation and test data, and unzip them
37 | ```shell
38 | mkdir -p $WSOD_ROOT/data/voc
39 | cd $WSOD_ROOT/data/voc
40 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
41 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
42 | tar xf VOCtrainval_06-Nov-2007.tar
43 | tar xf VOCtest_06-Nov-2007.tar
44 | ```
45 |
46 | 2. Download the ImageNet pre-trained models, selective search boxes and superpixels
47 | ```shell
48 | bash $WSOD_ROOT/tools/prepare.sh
49 | ```
50 |
51 | If you cannot access Google Drive, you can also download the resources from [https://pan.baidu.com/s/1htyljhvYz5qwO-4oH8C3wg](https://pan.baidu.com/s/1htyljhvYz5qwO-4oH8C3wg) (password: u5r3) and unzip them; the directory structure should look like
52 | ```
53 | data
54 | - VOCdevkit
55 | - VOC2007
56 | - voc_2007_trainval.pkl
57 | - voc_2007_test.pkl
58 | - SuperPixels
59 | - VOC2012
60 | - voc_2012_trainval.pkl
61 | - voc_2012_test.pkl
62 | - SuperPixels
63 | pretrain
64 | - vgg16.pth
65 | ```
66 |
67 | 3. Train a WSOD model
68 | ```shell
69 | bash tools/dist_train.sh $config $num_gpus
70 | ```
71 |
72 | 4. Evaluate a WSOD model
73 | ```shell
74 | bash tools/dist_test.sh $config $checkpoint $num_gpus --eval mAP
75 | ```
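76 | For example, to train and then evaluate the provided WSOD2 config on 8 GPUs (the checkpoint path below is assumed from the default `work_dir` and checkpointing settings in `configs/wsod/`): `bash tools/dist_train.sh configs/wsod/wsod2_vgg16.py 8`, then `bash tools/dist_test.sh configs/wsod/wsod2_vgg16.py work_dirs/wsod2_vgg16/latest.pth 8 --eval mAP`.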
76 |
77 | ### License
78 | WSOD2 is released under the MIT License.
79 |
80 | ### Citing WSOD2
81 |
82 | If you find this repo useful in your research, please consider citing:
83 |
84 | ```BibTex
85 | @inproceedings{zeng2019wsod2,
86 | title={Wsod2: Learning bottom-up and top-down objectness distillation for weakly-supervised object detection},
87 | author={Zeng, Zhaoyang and Liu, Bei and Fu, Jianlong and Chao, Hongyang and Zhang, Lei},
88 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
89 | pages={8292--8300},
90 | year={2019}
91 | }
92 | ```
93 |
94 |
95 |
--------------------------------------------------------------------------------
/configs/wsod/base.py:
--------------------------------------------------------------------------------
1 | # model training and testing settings
2 | train_cfg = dict(
3 | rcnn=dict())
4 | test_cfg = dict(
5 | rcnn=dict(
6 | score_thr=0.0000,
7 | nms=dict(type='nms', iou_threshold=0.3),
8 | max_per_img=100))
9 |
10 | # dataset settings
11 | dataset_type = 'VOCDataset'
12 | data_root = '/datavoc/VOCdevkit/'
13 | img_norm_cfg = dict(
14 | mean=[104., 117., 124.], std=[1., 1., 1.], to_rgb=False)
15 | train_pipeline = [
16 | dict(type='LoadImageFromFile'),
17 | dict(type='LoadWeakAnnotations'),
18 | dict(type='LoadProposals'),
19 | dict(type='Resize', img_scale=[(488, 2000), (576, 2000), (688, 2000), (864, 2000), (1200, 2000)], keep_ratio=True, multiscale_mode='range'),
20 | dict(type='RandomFlip', flip_ratio=0.5),
21 | dict(type='Normalize', **img_norm_cfg),
22 | dict(type='Pad', size_divisor=32),
23 | dict(type='DefaultFormatBundle'),
24 | dict(type='Collect', keys=['img', 'gt_labels', 'proposals']),
25 | ]
26 | test_pipeline = [
27 | dict(type='LoadImageFromFile'),
28 | dict(type='LoadProposals'),
29 | dict(
30 | type='MultiScaleFlipAug',
31 | img_scale=(688, 2000),
32 | #img_scale=[(500, 2000), (600, 2000), (700, 2000), (800, 2000), (900, 2000)],
33 | flip=False,
34 | transforms=[
35 | dict(type='Resize', keep_ratio=True),
36 | dict(type='RandomFlip'),
37 | dict(type='Normalize', **img_norm_cfg),
38 | dict(type='Pad', size_divisor=32),
39 | dict(type='ImageToTensor', keys=['img']),
40 | dict(type='Collect', keys=['img', 'proposals']),
41 | ])
42 | ]
43 | data = dict(
44 | samples_per_gpu=1,
45 | workers_per_gpu=2,
46 | train=dict(
47 | type=dataset_type,
48 | ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt',
49 | img_prefix=data_root + 'VOC2007/',
50 | proposal_file='/datavoc/selective_search_data/voc_2007_trainval.pkl',
51 | pipeline=train_pipeline),
52 | val=dict(
53 | type=dataset_type,
54 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
55 | img_prefix=data_root + 'VOC2007/',
56 | proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl',
57 | pipeline=test_pipeline),
58 | test=dict(
59 | type=dataset_type,
60 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
61 | img_prefix=data_root + 'VOC2007/',
62 | proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl',
63 | pipeline=test_pipeline))
64 | evaluation = dict(interval=1, metric='mAP')
65 |
66 | # optimizer
67 | optimizer = dict(
68 | type='Adam',
69 | lr=1e-5,
70 | weight_decay=0.0005,
71 | paramwise_cfg=dict(
72 | bias_decay_mult=0.,
73 | bias_lr_mult=2.,
74 | custom_keys={
75 | 'refine': dict(lr_mult=10),
76 | })
77 | )
78 |
79 | optimizer_config = dict(grad_clip=None)
80 | # learning policy
81 | lr_config = dict(
82 | policy='step',
83 | warmup='linear',
84 | warmup_iters=500,
85 | warmup_ratio=0.001,
86 | step=[36])
87 | total_epochs = 64
88 |
89 | checkpoint_config = dict(interval=16)
90 | # yapf:disable
91 | log_config = dict(
92 | interval=100,
93 | hooks=[
94 | dict(type='TextLoggerHook'),
95 | # dict(type='TensorboardLoggerHook')
96 | ])
97 | # yapf:enable
98 | dist_params = dict(backend='nccl')
99 | log_level = 'INFO'
100 | load_from = 'pretrain/vgg16_v2.pth'
101 | resume_from = None
102 | workflow = [('train', 1)]
103 |
--------------------------------------------------------------------------------
/configs/wsod/oicr_bbox_vgg16.py:
--------------------------------------------------------------------------------
1 | _base_ = './base.py'
2 | # model settings
3 | model = dict(
4 | type='WeakRCNN',
5 | pretrained=None,
6 | backbone=dict(type='VGG16'),
7 | neck=None,
8 | roi_head=dict(
9 | type='OICRRoIHead',
10 | bbox_roi_extractor=dict(
11 | type='SingleRoIExtractor',
12 | roi_layer=dict(type='RoIPool', output_size=7),
13 | out_channels=512,
14 | featmap_strides=[8]),
15 | bbox_head=dict(
16 | type='OICRHead',
17 | in_channels=512,
18 | hidden_channels=4096,
19 | roi_feat_size=7,
20 | bbox_coder=dict(
21 | type='DeltaXYWHBBoxCoder',
22 | target_means=[0., 0., 0., 0.],
23 | target_stds=[0.1, 0.1, 0.2, 0.2]),
24 | num_classes=20))
25 | )
26 | work_dir = 'work_dirs/oicr_bbox_vgg16/'
27 |
--------------------------------------------------------------------------------
/configs/wsod/oicr_vgg16.py:
--------------------------------------------------------------------------------
1 | _base_ = './base.py'
2 | # model settings
3 | model = dict(
4 | type='WeakRCNN',
5 | pretrained=None,
6 | backbone=dict(type='VGG16'),
7 | neck=None,
8 | roi_head=dict(
9 | type='OICRRoIHead',
10 | bbox_roi_extractor=dict(
11 | type='SingleRoIExtractor',
12 | roi_layer=dict(type='RoIPool', output_size=7),
13 | out_channels=512,
14 | featmap_strides=[8]),
15 | bbox_head=dict(
16 | type='OICRHead',
17 | in_channels=512,
18 | hidden_channels=4096,
19 | roi_feat_size=7,
20 | num_classes=20))
21 | )
22 | work_dir = 'work_dirs/oicr_vgg16/'
23 |
--------------------------------------------------------------------------------
/configs/wsod/wsddn_vgg16.py:
--------------------------------------------------------------------------------
1 | _base_ = './base.py'
2 | # model settings
3 | model = dict(
4 | type='WeakRCNN',
5 | pretrained=None,
6 | backbone=dict(type='VGG16'),
7 | neck=None,
8 | roi_head=dict(
9 | type='WSDDNRoIHead',
10 | bbox_roi_extractor=dict(
11 | type='SingleRoIExtractor',
12 | roi_layer=dict(type='RoIPool', output_size=7),
13 | out_channels=512,
14 | featmap_strides=[8]),
15 | bbox_head=dict(
16 | type='WSDDNHead',
17 | in_channels=512,
18 | hidden_channels=4096,
19 | roi_feat_size=7,
20 | num_classes=20))
21 | )
22 | work_dir = 'work_dirs/wsddn_vgg16/'
23 |
--------------------------------------------------------------------------------
/configs/wsod/wsod2_vgg16.py:
--------------------------------------------------------------------------------
1 | _base_ = './base.py'
2 | # model settings
3 | model = dict(
4 | type='WeakRCNN',
5 | pretrained=None,
6 | backbone=dict(type='VGG16'),
7 | neck=None,
8 | roi_head=dict(
9 | type='WSOD2RoIHead',
10 | steps=40000,
11 | bbox_roi_extractor=dict(
12 | type='SingleRoIExtractor',
13 | roi_layer=dict(type='RoIPool', output_size=7),
14 | out_channels=512,
15 | featmap_strides=[8]),
16 | bbox_head=dict(
17 | type='OICRHead',
18 | in_channels=512,
19 | hidden_channels=4096,
20 | roi_feat_size=7,
21 | bbox_coder=dict(
22 | type='DeltaXYWHBBoxCoder',
23 | target_means=[0., 0., 0., 0.],
24 | target_stds=[0.1, 0.1, 0.2, 0.2]),
25 | num_classes=20))
26 | )
27 | # dataset settings
28 | dataset_type = 'VOCSSDataset'
29 | data_root = '/datavoc/VOCdevkit/'
30 | img_norm_cfg = dict(
31 | mean=[104., 117., 124.], std=[1., 1., 1.], to_rgb=False)
32 | train_pipeline = [
33 | dict(type='LoadImageFromFile'),
34 | dict(type='LoadSuperPixelFromFile'),
35 | dict(type='LoadWeakAnnotations'),
36 | dict(type='LoadProposals'),
37 | dict(type='Resize', img_scale=[(488, 2000), (576, 2000), (688, 2000), (864, 2000), (1200, 2000)], keep_ratio=True, multiscale_mode='value'),
38 | dict(type='RandomFlip', flip_ratio=0.5),
39 | dict(type='Normalize', **img_norm_cfg),
40 | dict(type='Pad', size_divisor=32),
41 | dict(type='DefaultFormatBundle'),
42 | dict(type='Collect', keys=['img', 'gt_labels', 'proposals', 'ss']),
43 | ]
44 | test_pipeline = [
45 | dict(type='LoadImageFromFile'),
46 | dict(type='LoadProposals'),
47 | dict(
48 | type='MultiScaleFlipAug',
49 | img_scale=(688, 2000),
50 | #img_scale=[(500, 2000), (600, 2000), (700, 2000), (800, 2000), (900, 2000)],
51 | flip=False,
52 | transforms=[
53 | dict(type='Resize', keep_ratio=True),
54 | dict(type='RandomFlip'),
55 | dict(type='Normalize', **img_norm_cfg),
56 | dict(type='Pad', size_divisor=32),
57 | dict(type='ImageToTensor', keys=['img']),
58 | dict(type='Collect', keys=['img', 'proposals']),
59 | ])
60 | ]
61 | data = dict(
62 | samples_per_gpu=1,
63 | workers_per_gpu=2,
64 | train=dict(
65 | type=dataset_type,
66 | ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt',
67 | img_prefix=data_root + 'VOC2007/',
68 | proposal_file='/datavoc/selective_search_data/voc_2007_trainval.pkl',
69 | pipeline=train_pipeline),
70 | val=dict(
71 | type=dataset_type,
72 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
73 | img_prefix=data_root + 'VOC2007/',
74 | proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl',
75 | pipeline=test_pipeline),
76 | test=dict(
77 | type=dataset_type,
78 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
79 | img_prefix=data_root + 'VOC2007/',
80 | proposal_file='/datavoc/selective_search_data/voc_2007_test.pkl',
81 | pipeline=test_pipeline))
82 |
83 | work_dir = 'work_dirs/wsod2_vgg16/'
84 |
--------------------------------------------------------------------------------
/demo/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/demo/demo.jpg
--------------------------------------------------------------------------------
/demo/image_demo.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from mmdet.apis import inference_detector, init_detector, show_result_pyplot
4 |
5 |
6 | def main():
7 | parser = ArgumentParser()
8 | parser.add_argument('img', help='Image file')
9 | parser.add_argument('config', help='Config file')
10 | parser.add_argument('checkpoint', help='Checkpoint file')
11 | parser.add_argument(
12 | '--device', default='cuda:0', help='Device used for inference')
13 | parser.add_argument(
14 | '--score-thr', type=float, default=0.3, help='bbox score threshold')
15 | args = parser.parse_args()
16 |
17 | # build the model from a config file and a checkpoint file
18 | model = init_detector(args.config, args.checkpoint, device=args.device)
19 | # test a single image
20 | result = inference_detector(model, args.img)
21 | # show the results
22 | show_result_pyplot(model, args.img, result, score_thr=args.score_thr)
23 |
24 |
25 | if __name__ == '__main__':
26 | main()
27 |
--------------------------------------------------------------------------------
/demo/webcam_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import cv2
4 | import torch
5 |
6 | from mmdet.apis import inference_detector, init_detector
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='MMDetection webcam demo')
11 | parser.add_argument('config', help='test config file path')
12 | parser.add_argument('checkpoint', help='checkpoint file')
13 | parser.add_argument(
14 | '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
15 | parser.add_argument(
16 | '--camera-id', type=int, default=0, help='camera device id')
17 | parser.add_argument(
18 | '--score-thr', type=float, default=0.5, help='bbox score threshold')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | def main():
24 | args = parse_args()
25 |
26 | device = torch.device(args.device)
27 |
28 | model = init_detector(args.config, args.checkpoint, device=device)
29 |
30 | camera = cv2.VideoCapture(args.camera_id)
31 |
32 | print('Press "Esc", "q" or "Q" to exit.')
33 |     while True:
34 |         ret_val, img = camera.read()
35 |         # stop when a frame cannot be read from the camera
36 |         if not ret_val:
37 |             break
38 |         result = inference_detector(model, img)
39 | 
40 |         ch = cv2.waitKey(1)
41 |         if ch == 27 or ch == ord('q') or ch == ord('Q'):
42 |             break
43 | 
44 |         model.show_result(
45 |             img, result, score_thr=args.score_thr, wait_time=1, show=True)
46 | 
47 | 
48 | if __name__ == '__main__':
49 |     main()
50 | 
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.6.0"
2 | ARG CUDA="10.1"
3 | ARG CUDNN="7"
4 |
5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 |
7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
10 |
11 | RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
12 | && apt-get clean \
13 | && rm -rf /var/lib/apt/lists/*
14 |
15 | # Install MMCV
16 | RUN pip install mmcv-full==latest+torch1.6.0+cu101 -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
17 |
18 | # Install MMDetection
19 | RUN conda clean --all
20 | RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
21 | WORKDIR /mmdetection
22 | ENV FORCE_CUDA="1"
23 | RUN pip install -r requirements/build.txt
24 | RUN pip install --no-cache-dir -e .
25 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =================
3 |
4 | mmdet.apis
5 | --------------
6 | .. automodule:: mmdet.apis
7 | :members:
8 |
9 | mmdet.core
10 | --------------
11 |
12 | anchor
13 | ^^^^^^^^^^
14 | .. automodule:: mmdet.core.anchor
15 | :members:
16 |
17 | bbox
18 | ^^^^^^^^^^
19 | .. automodule:: mmdet.core.bbox
20 | :members:
21 |
22 | export
23 | ^^^^^^^^^^
24 | .. automodule:: mmdet.core.export
25 | :members:
26 |
27 | mask
28 | ^^^^^^^^^^
29 | .. automodule:: mmdet.core.mask
30 | :members:
31 |
32 | evaluation
33 | ^^^^^^^^^^
34 | .. automodule:: mmdet.core.evaluation
35 | :members:
36 |
37 | post_processing
38 | ^^^^^^^^^^^^^^^
39 | .. automodule:: mmdet.core.post_processing
40 | :members:
41 |
42 | optimizer
43 | ^^^^^^^^^^
44 | .. automodule:: mmdet.core.optimizer
45 | :members:
46 |
47 | utils
48 | ^^^^^^^^^^
49 | .. automodule:: mmdet.core.utils
50 | :members:
51 |
52 | mmdet.datasets
53 | --------------
54 |
55 | datasets
56 | ^^^^^^^^^^
57 | .. automodule:: mmdet.datasets
58 | :members:
59 |
60 | pipelines
61 | ^^^^^^^^^^
62 | .. automodule:: mmdet.datasets.pipelines
63 | :members:
64 |
65 | mmdet.models
66 | --------------
67 |
68 | detectors
69 | ^^^^^^^^^^
70 | .. automodule:: mmdet.models.detectors
71 | :members:
72 |
73 | backbones
74 | ^^^^^^^^^^
75 | .. automodule:: mmdet.models.backbones
76 | :members:
77 |
78 | necks
79 | ^^^^^^^^^^^^
80 | .. automodule:: mmdet.models.necks
81 | :members:
82 |
83 | dense_heads
84 | ^^^^^^^^^^^^
85 | .. automodule:: mmdet.models.dense_heads
86 | :members:
87 |
88 | roi_heads
89 | ^^^^^^^^^^
90 | .. automodule:: mmdet.models.roi_heads
91 | :members:
92 |
93 | losses
94 | ^^^^^^^^^^
95 | .. automodule:: mmdet.models.losses
96 | :members:
97 |
98 | utils
99 | ^^^^^^^^^^
100 | .. automodule:: mmdet.models.utils
101 | :members:
102 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import subprocess
15 | import sys
16 |
17 | sys.path.insert(0, os.path.abspath('..'))
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = 'MMDetection'
22 | copyright = '2018-2020, OpenMMLab'
23 | author = 'MMDetection Authors'
24 | version_file = '../mmdet/version.py'
25 |
26 |
27 | def get_version():
28 | with open(version_file, 'r') as f:
29 | exec(compile(f.read(), version_file, 'exec'))
30 | return locals()['__version__']
31 |
32 |
33 | # The full version, including alpha/beta/rc tags
34 | release = get_version()
35 |
36 | # -- General configuration ---------------------------------------------------
37 |
38 | # Add any Sphinx extension module names here, as strings. They can be
39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
40 | # ones.
41 | extensions = [
42 | 'sphinx.ext.autodoc',
43 | 'sphinx.ext.napoleon',
44 | 'sphinx.ext.viewcode',
45 | 'recommonmark',
46 | 'sphinx_markdown_tables',
47 | ]
48 |
49 | autodoc_mock_imports = [
50 | 'matplotlib', 'pycocotools', 'terminaltables', 'mmdet.version', 'mmcv.ops'
51 | ]
52 |
53 | # Add any paths that contain templates here, relative to this directory.
54 | templates_path = ['_templates']
55 |
56 | # The suffix(es) of source filenames.
57 | # You can specify multiple suffix as a list of string:
58 | #
59 | source_suffix = {
60 | '.rst': 'restructuredtext',
61 | '.md': 'markdown',
62 | }
63 |
64 | # The master toctree document.
65 | master_doc = 'index'
66 |
67 | # List of patterns, relative to source directory, that match files and
68 | # directories to ignore when looking for source files.
69 | # This pattern also affects html_static_path and html_extra_path.
70 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
71 |
72 | # -- Options for HTML output -------------------------------------------------
73 |
74 | # The theme to use for HTML and HTML Help pages. See the documentation for
75 | # a list of builtin themes.
76 | #
77 | html_theme = 'sphinx_rtd_theme'
78 |
79 | # Add any paths that contain custom static files (such as style sheets) here,
80 | # relative to this directory. They are copied after the builtin static files,
81 | # so a file named "default.css" will overwrite the builtin "default.css".
82 | html_static_path = ['_static']
83 |
84 |
85 | def builder_inited_handler(app):
86 | subprocess.run(['./stat.py'])
87 |
88 |
89 | def setup(app):
90 | app.connect('builder-inited', builder_inited_handler)
91 |
--------------------------------------------------------------------------------
/docs/conventions.md:
--------------------------------------------------------------------------------
1 | # Conventions
2 |
3 | Please check the following conventions if you would like to modify MMDetection as your own project.
4 |
5 | ## Loss
6 | In MMDetection, a `dict` containing losses and metrics will be returned by `model(**data)`.
7 |
8 | For example, in bbox head,
9 | ```python
10 | class BBoxHead(nn.Module):
11 | ...
12 | def loss(self, ...):
13 | losses = dict()
14 | # classification loss
15 | losses['loss_cls'] = self.loss_cls(...)
16 | # classification accuracy
17 | losses['acc'] = accuracy(...)
18 | # bbox regression loss
19 | losses['loss_bbox'] = self.loss_bbox(...)
20 | return losses
21 | ```
22 | `bbox_head.loss()` will be called during model forward.
23 | The returned dict contains `'loss_bbox'`, `'loss_cls'` and `'acc'`.
24 | Only `'loss_bbox'` and `'loss_cls'` will be used during back propagation;
25 | `'acc'` will only be used as a metric to monitor the training process.
26 |
27 | By default, only values whose keys contain `'loss'` will be back propagated.
28 | This behavior could be changed by modifying `BaseDetector.train_step()`.
29 |
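30 | As a minimal sketch of this convention (illustrative only, not the exact MMDetection implementation):
31 | 
32 | ```python
33 | import torch
34 | 
35 | 
36 | def parse_losses(losses):
37 |     """Reduce raw head outputs and split them into loss vs. metric terms."""
38 |     log_vars = {}
39 |     for name, value in losses.items():
40 |         # heads may return a single tensor or a list of per-level tensors
41 |         if isinstance(value, torch.Tensor):
42 |             log_vars[name] = value.mean()
43 |         else:
44 |             log_vars[name] = sum(v.mean() for v in value)
45 |     # only keys containing 'loss' contribute to the back-propagated total
46 |     loss = sum(v for k, v in log_vars.items() if 'loss' in k)
47 |     return loss, log_vars
48 | ```
49 | 
50 | Here `loss` is what `train_step()` would back propagate, while every entry of `log_vars` (including `'acc'`) is only logged.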
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to MMDetection's documentation!
2 | =======================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Get Started
7 |
8 | get_started.md
9 | modelzoo_statistics.md
10 | model_zoo.md
11 |
12 | .. toctree::
13 | :maxdepth: 2
14 | :caption: Quick Run
15 |
16 | 1_exist_data_model.md
17 | 2_new_data_model.md
18 |
19 | .. toctree::
20 | :maxdepth: 2
21 | :caption: Tutorials
22 |
23 | tutorials/index.rst
24 |
25 | .. toctree::
26 | :maxdepth: 2
27 | :caption: Useful Tools and Scripts
28 |
29 | useful_tools.md
30 |
31 | .. toctree::
32 | :maxdepth: 2
33 | :caption: Notes
34 |
35 | conventions.md
36 | compatibility.md
37 | projects.md
38 | changelog.md
39 | faq.md
40 |
41 | .. toctree::
42 | :caption: API Reference
43 |
44 | api.rst
45 |
46 | Indices and tables
47 | ==================
48 |
49 | * :ref:`genindex`
50 | * :ref:`search`
51 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import glob
3 | import os.path as osp
4 | import re
5 |
6 | url_prefix = 'https://github.com/open-mmlab/mmdetection/blob/master/'
7 |
8 | files = sorted(glob.glob('../configs/*/README.md'))
9 |
10 | stats = []
11 | titles = []
12 | num_ckpts = 0
13 |
14 | for f in files:
15 | url = osp.dirname(f.replace('../', url_prefix))
16 |
17 | with open(f, 'r') as content_file:
18 | content = content_file.read()
19 |
20 | title = content.split('\n')[0].replace('# ', '')
21 |
22 | titles.append(title)
23 | ckpts = set(x.lower().strip()
24 | for x in re.findall(r'https?://download.*\.pth', content)
25 | if 'mmdetection' in x)
26 | num_ckpts += len(ckpts)
27 | statsmsg = f"""
28 | \t* [{title}]({url}) ({len(ckpts)} ckpts)
29 | """
30 | stats.append((title, ckpts, statsmsg))
31 |
32 | msglist = '\n'.join(x for _, _, x in stats)
33 |
34 | modelzoo = f"""
35 | # Model Zoo Statistics
36 |
37 | * Number of papers: {len(titles)}
38 | * Number of checkpoints: {num_ckpts}
39 | {msglist}
40 | """
41 |
42 | with open('modelzoo_statistics.md', 'w') as f:
43 | f.write(modelzoo)
44 |
--------------------------------------------------------------------------------
/docs/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | .. toctree::
2 | :maxdepth: 2
3 |
4 | config.md
5 | customize_dataset.md
6 | data_pipeline.md
7 | customize_models.md
8 | customize_runtime.md
9 | customize_losses.md
10 | finetune.md
11 |
--------------------------------------------------------------------------------
/mmdet/__init__.py:
--------------------------------------------------------------------------------
1 | import mmcv
2 |
3 | from .version import __version__, short_version
4 |
5 |
6 | def digit_version(version_str):
7 |     # Convert a version string into a list of ints for comparison.
8 |     # 'rc' pre-releases sort below the matching release,
9 |     # e.g. '1.3rc1' -> [1, 2, 1] < [1, 3].
10 |     digit_version = []
11 |     for x in version_str.split('.'):
12 |         if x.isdigit():
13 |             digit_version.append(int(x))
14 |         elif x.find('rc') != -1:
15 |             patch_version = x.split('rc')
16 |             digit_version.append(int(patch_version[0]) - 1)
17 |             digit_version.append(int(patch_version[1]))
18 |     return digit_version
19 | 
20 | 
21 | mmcv_minimum_version = '1.1.5'
22 | mmcv_maximum_version = '1.3'
23 | mmcv_version = digit_version(mmcv.__version__)
24 | 
25 | 
26 | assert (mmcv_version >= digit_version(mmcv_minimum_version)
27 |         and mmcv_version <= digit_version(mmcv_maximum_version)), \
28 |     f'MMCV=={mmcv.__version__} is used but incompatible. ' \
29 |     f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
30 | 
31 | __all__ = ['__version__', 'short_version']
29 |
--------------------------------------------------------------------------------
/mmdet/apis/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import (async_inference_detector, inference_detector,
2 | init_detector, show_result_pyplot)
3 | from .test import multi_gpu_test, single_gpu_test
4 | from .train import get_root_logger, set_random_seed, train_detector
5 |
6 | __all__ = [
7 | 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
8 | 'async_inference_detector', 'inference_detector', 'show_result_pyplot',
9 | 'multi_gpu_test', 'single_gpu_test'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmdet/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .anchor import * # noqa: F401, F403
2 | from .bbox import * # noqa: F401, F403
3 | from .evaluation import * # noqa: F401, F403
4 | from .export import * # noqa: F401, F403
5 | from .fp16 import * # noqa: F401, F403
6 | from .mask import * # noqa: F401, F403
7 | from .post_processing import * # noqa: F401, F403
8 | from .utils import * # noqa: F401, F403
9 |
--------------------------------------------------------------------------------
/mmdet/core/anchor/__init__.py:
--------------------------------------------------------------------------------
1 | from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
2 | YOLOAnchorGenerator)
3 | from .builder import ANCHOR_GENERATORS, build_anchor_generator
4 | from .point_generator import PointGenerator
5 | from .utils import anchor_inside_flags, calc_region, images_to_levels
6 |
7 | __all__ = [
8 | 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
9 | 'PointGenerator', 'images_to_levels', 'calc_region',
10 | 'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
11 | ]
12 |
--------------------------------------------------------------------------------
/mmdet/core/anchor/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import Registry, build_from_cfg
2 |
3 | ANCHOR_GENERATORS = Registry('Anchor generator')
4 |
5 |
6 | def build_anchor_generator(cfg, default_args=None):
7 | return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
8 |
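This registry/builder pair is the extension mechanism used throughout the codebase: a config dict's ``type`` key names a registered class, and the remaining keys become constructor arguments. A minimal sketch using the ``PointGenerator`` registered in the next file:

    >>> pg = build_anchor_generator(dict(type='PointGenerator'))
    >>> type(pg).__name__
    'PointGenerator'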
--------------------------------------------------------------------------------
/mmdet/core/anchor/point_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .builder import ANCHOR_GENERATORS
4 |
5 |
6 | @ANCHOR_GENERATORS.register_module()
7 | class PointGenerator(object):
8 |
9 | def _meshgrid(self, x, y, row_major=True):
10 | xx = x.repeat(len(y))
11 | yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
12 | if row_major:
13 | return xx, yy
14 | else:
15 | return yy, xx
16 |
17 | def grid_points(self, featmap_size, stride=16, device='cuda'):
18 | feat_h, feat_w = featmap_size
19 | shift_x = torch.arange(0., feat_w, device=device) * stride
20 | shift_y = torch.arange(0., feat_h, device=device) * stride
21 | shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
22 | stride = shift_x.new_full((shift_xx.shape[0], ), stride)
23 | shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
24 | all_points = shifts.to(device)
25 | return all_points
26 |
27 | def valid_flags(self, featmap_size, valid_size, device='cuda'):
28 | feat_h, feat_w = featmap_size
29 | valid_h, valid_w = valid_size
30 | assert valid_h <= feat_h and valid_w <= feat_w
31 | valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
32 | valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
33 | valid_x[:valid_w] = 1
34 | valid_y[:valid_h] = 1
35 | valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
36 | valid = valid_xx & valid_yy
37 | return valid
38 |
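Concretely, ``grid_points`` emits one ``(x, y, stride)`` row per feature-map location in row-major order. A quick CPU sketch for a 2 x 3 map with stride 16:

    >>> import torch
    >>> pg = PointGenerator()
    >>> pts = pg.grid_points((2, 3), stride=16, device='cpu')
    >>> pts.shape
    torch.Size([6, 3])
    >>> pts[0], pts[-1]
    (tensor([ 0.,  0., 16.]), tensor([32., 16., 16.]))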
--------------------------------------------------------------------------------
/mmdet/core/anchor/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def images_to_levels(target, num_levels):
5 | """Convert targets by image to targets by feature level.
6 |
7 | [target_img0, target_img1] -> [target_level0, target_level1, ...]
8 | """
9 | target = torch.stack(target, 0)
10 | level_targets = []
11 | start = 0
12 | for n in num_levels:
13 | end = start + n
14 | # level_targets.append(target[:, start:end].squeeze(0))
15 | level_targets.append(target[:, start:end])
16 | start = end
17 | return level_targets
18 |
19 |
20 | def anchor_inside_flags(flat_anchors,
21 | valid_flags,
22 | img_shape,
23 | allowed_border=0):
24 | """Check whether the anchors are inside the border.
25 |
26 | Args:
27 | flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
28 | valid_flags (torch.Tensor): An existing valid flags of anchors.
29 | img_shape (tuple(int)): Shape of current image.
30 |         allowed_border (int, optional): Margin by which anchors may exceed
31 |             the image border and still count as valid. Defaults to 0.
32 |
33 | Returns:
34 | torch.Tensor: Flags indicating whether the anchors are inside a \
35 | valid range.
36 | """
37 | img_h, img_w = img_shape[:2]
38 | if allowed_border >= 0:
39 | inside_flags = valid_flags & \
40 | (flat_anchors[:, 0] >= -allowed_border) & \
41 | (flat_anchors[:, 1] >= -allowed_border) & \
42 | (flat_anchors[:, 2] < img_w + allowed_border) & \
43 | (flat_anchors[:, 3] < img_h + allowed_border)
44 | else:
45 | inside_flags = valid_flags
46 | return inside_flags
47 |
48 |
49 | def calc_region(bbox, ratio, featmap_size=None):
50 | """Calculate a proportional bbox region.
51 |
52 |     The bbox center is fixed and the new h' and w' are h * ratio and w * ratio.
53 |
54 | Args:
55 | bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
56 | ratio (float): Ratio of the output region.
57 | featmap_size (tuple): Feature map size used for clipping the boundary.
58 |
59 | Returns:
60 | tuple: x1, y1, x2, y2
61 | """
62 | x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
63 | y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
64 | x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
65 | y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
66 | if featmap_size is not None:
67 | x1 = x1.clamp(min=0, max=featmap_size[1])
68 | y1 = y1.clamp(min=0, max=featmap_size[0])
69 | x2 = x2.clamp(min=0, max=featmap_size[1])
70 | y2 = y2.clamp(min=0, max=featmap_size[0])
71 | return (x1, y1, x2, y2)
72 |
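A small sketch of ``images_to_levels``: with two images and 12 anchors split across two levels of 8 and 4 anchors, the per-image targets are regrouped into per-level tensors with a leading batch dimension:

    >>> import torch
    >>> per_img = [torch.zeros(12), torch.ones(12)]
    >>> [t.shape for t in images_to_levels(per_img, [8, 4])]
    [torch.Size([2, 8]), torch.Size([2, 4])]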
--------------------------------------------------------------------------------
/mmdet/core/bbox/__init__.py:
--------------------------------------------------------------------------------
1 | from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
2 | MaxIoUAssigner)
3 | from .builder import build_assigner, build_bbox_coder, build_sampler
4 | from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
5 | TBLRBBoxCoder)
6 | from .iou_calculators import BboxOverlaps2D, bbox_overlaps
7 | from .samplers import (BaseSampler, CombinedSampler,
8 | InstanceBalancedPosSampler, IoUBalancedNegSampler,
9 | OHEMSampler, PseudoSampler, RandomSampler,
10 | SamplingResult, ScoreHLRSampler)
11 | from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip,
12 | bbox_mapping, bbox_mapping_back, bbox_rescale,
13 | distance2bbox, roi2bbox)
14 |
15 | __all__ = [
16 | 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
17 | 'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
18 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
19 | 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
20 | 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
21 | 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
22 | 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
23 | 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
24 | 'bbox_rescale'
25 | ]
26 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/assigners/__init__.py:
--------------------------------------------------------------------------------
1 | from .approx_max_iou_assigner import ApproxMaxIoUAssigner
2 | from .assign_result import AssignResult
3 | from .atss_assigner import ATSSAssigner
4 | from .base_assigner import BaseAssigner
5 | from .center_region_assigner import CenterRegionAssigner
6 | from .grid_assigner import GridAssigner
7 | from .max_iou_assigner import MaxIoUAssigner
8 | from .point_assigner import PointAssigner
9 |
10 | __all__ = [
11 | 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
12 | 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner'
13 | ]
14 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/assigners/base_assigner.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
4 | class BaseAssigner(metaclass=ABCMeta):
5 | """Base assigner that assigns boxes to ground truth boxes."""
6 |
7 | @abstractmethod
8 | def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
9 |         """Assign boxes to either a ground truth box or a negative sample."""
10 | pass
11 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import Registry, build_from_cfg
2 |
3 | BBOX_ASSIGNERS = Registry('bbox_assigner')
4 | BBOX_SAMPLERS = Registry('bbox_sampler')
5 | BBOX_CODERS = Registry('bbox_coder')
6 |
7 |
8 | def build_assigner(cfg, **default_args):
9 | """Builder of box assigner."""
10 | return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
11 |
12 |
13 | def build_sampler(cfg, **default_args):
14 | """Builder of box sampler."""
15 | return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
16 |
17 |
18 | def build_bbox_coder(cfg, **default_args):
19 | """Builder of box coder."""
20 | return build_from_cfg(cfg, BBOX_CODERS, default_args)
21 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/coder/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_bbox_coder import BaseBBoxCoder
2 | from .bucketing_bbox_coder import BucketingBBoxCoder
3 | from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
4 | from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
5 | from .pseudo_bbox_coder import PseudoBBoxCoder
6 | from .tblr_bbox_coder import TBLRBBoxCoder
7 | from .yolo_bbox_coder import YOLOBBoxCoder
8 |
9 | __all__ = [
10 | 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
11 | 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
12 | 'BucketingBBoxCoder'
13 | ]
14 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/coder/base_bbox_coder.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
4 | class BaseBBoxCoder(metaclass=ABCMeta):
5 | """Base bounding box coder."""
6 |
7 | def __init__(self, **kwargs):
8 | pass
9 |
10 | @abstractmethod
11 | def encode(self, bboxes, gt_bboxes):
12 | """Encode deltas between bboxes and ground truth boxes."""
13 | pass
14 |
15 | @abstractmethod
16 | def decode(self, bboxes, bboxes_pred):
17 | """Decode the predicted bboxes according to prediction and base
18 | boxes."""
19 | pass
20 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/coder/pseudo_bbox_coder.py:
--------------------------------------------------------------------------------
1 | from ..builder import BBOX_CODERS
2 | from .base_bbox_coder import BaseBBoxCoder
3 |
4 |
5 | @BBOX_CODERS.register_module()
6 | class PseudoBBoxCoder(BaseBBoxCoder):
7 | """Pseudo bounding box coder."""
8 |
9 | def __init__(self, **kwargs):
10 |         super(PseudoBBoxCoder, self).__init__(**kwargs)
11 |
12 | def encode(self, bboxes, gt_bboxes):
13 |         """torch.Tensor: return the given ``gt_bboxes``"""
14 | return gt_bboxes
15 |
16 | def decode(self, bboxes, pred_bboxes):
17 | """torch.Tensor: return the given ``pred_bboxes``"""
18 | return pred_bboxes
19 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/demodata.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def ensure_rng(rng=None):
6 |     """Simple version of ``kwarray.ensure_rng``.
7 |
8 | Args:
9 | rng (int | numpy.random.RandomState | None):
10 | if None, then defaults to the global rng. Otherwise this can be an
11 | integer or a RandomState class
12 | Returns:
13 | (numpy.random.RandomState) : rng -
14 | a numpy random number generator
15 |
16 | References:
17 | https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270
18 | """
19 |
20 | if rng is None:
21 | rng = np.random.mtrand._rand
22 | elif isinstance(rng, int):
23 | rng = np.random.RandomState(rng)
24 | else:
25 | rng = rng
26 | return rng
27 |
28 |
29 | def random_boxes(num=1, scale=1, rng=None):
30 |     """Simple version of ``kwimage.Boxes.random``.
31 |
32 | Returns:
33 | Tensor: shape (n, 4) in x1, y1, x2, y2 format.
34 |
35 | References:
36 | https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
37 |
38 | Example:
39 | >>> num = 3
40 | >>> scale = 512
41 | >>> rng = 0
42 | >>> boxes = random_boxes(num, scale, rng)
43 | >>> print(boxes)
44 | tensor([[280.9925, 278.9802, 308.6148, 366.1769],
45 | [216.9113, 330.6978, 224.0446, 456.5878],
46 | [405.3632, 196.3221, 493.3953, 270.7942]])
47 | """
48 | rng = ensure_rng(rng)
49 |
50 | tlbr = rng.rand(num, 4).astype(np.float32)
51 |
52 | tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
53 | tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
54 | br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
55 | br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
56 |
57 | tlbr[:, 0] = tl_x * scale
58 | tlbr[:, 1] = tl_y * scale
59 | tlbr[:, 2] = br_x * scale
60 | tlbr[:, 3] = br_y * scale
61 |
62 | boxes = torch.from_numpy(tlbr)
63 | return boxes
64 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/iou_calculators/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_iou_calculator
2 | from .iou2d_calculator import BboxOverlaps2D, bbox_overlaps
3 |
4 | __all__ = ['build_iou_calculator', 'BboxOverlaps2D', 'bbox_overlaps']
5 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/iou_calculators/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import Registry, build_from_cfg
2 |
3 | IOU_CALCULATORS = Registry('IoU calculator')
4 |
5 |
6 | def build_iou_calculator(cfg, default_args=None):
7 | """Builder of IoU calculator."""
8 | return build_from_cfg(cfg, IOU_CALCULATORS, default_args)
9 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_sampler import BaseSampler
2 | from .combined_sampler import CombinedSampler
3 | from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
4 | from .iou_balanced_neg_sampler import IoUBalancedNegSampler
5 | from .ohem_sampler import OHEMSampler
6 | from .pseudo_sampler import PseudoSampler
7 | from .random_sampler import RandomSampler
8 | from .sampling_result import SamplingResult
9 | from .score_hlr_sampler import ScoreHLRSampler
10 |
11 | __all__ = [
12 | 'BaseSampler', 'PseudoSampler', 'RandomSampler',
13 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
14 | 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler'
15 | ]
16 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/samplers/combined_sampler.py:
--------------------------------------------------------------------------------
1 | from ..builder import BBOX_SAMPLERS, build_sampler
2 | from .base_sampler import BaseSampler
3 |
4 |
5 | @BBOX_SAMPLERS.register_module()
6 | class CombinedSampler(BaseSampler):
7 | """A sampler that combines positive sampler and negative sampler."""
8 |
9 | def __init__(self, pos_sampler, neg_sampler, **kwargs):
10 | super(CombinedSampler, self).__init__(**kwargs)
11 | self.pos_sampler = build_sampler(pos_sampler, **kwargs)
12 | self.neg_sampler = build_sampler(neg_sampler, **kwargs)
13 |
14 | def _sample_pos(self, **kwargs):
15 | """Sample positive samples."""
16 | raise NotImplementedError
17 |
18 | def _sample_neg(self, **kwargs):
19 | """Sample negative samples."""
20 | raise NotImplementedError
21 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ..builder import BBOX_SAMPLERS
5 | from .random_sampler import RandomSampler
6 |
7 |
8 | @BBOX_SAMPLERS.register_module()
9 | class InstanceBalancedPosSampler(RandomSampler):
10 |     """Instance-balanced sampler that samples an equal number of positive
11 |     samples for each instance."""
12 |
13 | def _sample_pos(self, assign_result, num_expected, **kwargs):
14 | """Sample positive boxes.
15 |
16 | Args:
17 | assign_result (:obj:`AssignResult`): The assigned results of boxes.
18 |             num_expected (int): The number of expected positive samples.
19 |
20 | Returns:
21 | Tensor or ndarray: sampled indices.
22 | """
23 | pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
24 | if pos_inds.numel() != 0:
25 | pos_inds = pos_inds.squeeze(1)
26 | if pos_inds.numel() <= num_expected:
27 | return pos_inds
28 | else:
29 | unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
30 | num_gts = len(unique_gt_inds)
31 | num_per_gt = int(round(num_expected / float(num_gts)) + 1)
32 | sampled_inds = []
33 | for i in unique_gt_inds:
34 | inds = torch.nonzero(
35 | assign_result.gt_inds == i.item(), as_tuple=False)
36 | if inds.numel() != 0:
37 | inds = inds.squeeze(1)
38 | else:
39 | continue
40 | if len(inds) > num_per_gt:
41 | inds = self.random_choice(inds, num_per_gt)
42 | sampled_inds.append(inds)
43 | sampled_inds = torch.cat(sampled_inds)
44 | if len(sampled_inds) < num_expected:
45 | num_extra = num_expected - len(sampled_inds)
46 | extra_inds = np.array(
47 | list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
48 | if len(extra_inds) > num_extra:
49 | extra_inds = self.random_choice(extra_inds, num_extra)
50 | extra_inds = torch.from_numpy(extra_inds).to(
51 | assign_result.gt_inds.device).long()
52 | sampled_inds = torch.cat([sampled_inds, extra_inds])
53 | elif len(sampled_inds) > num_expected:
54 | sampled_inds = self.random_choice(sampled_inds, num_expected)
55 | return sampled_inds
56 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/samplers/pseudo_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..builder import BBOX_SAMPLERS
4 | from .base_sampler import BaseSampler
5 | from .sampling_result import SamplingResult
6 |
7 |
8 | @BBOX_SAMPLERS.register_module()
9 | class PseudoSampler(BaseSampler):
10 |     """A pseudo sampler that does not actually perform sampling."""
11 |
12 | def __init__(self, **kwargs):
13 | pass
14 |
15 | def _sample_pos(self, **kwargs):
16 | """Sample positive samples."""
17 | raise NotImplementedError
18 |
19 | def _sample_neg(self, **kwargs):
20 | """Sample negative samples."""
21 | raise NotImplementedError
22 |
23 | def sample(self, assign_result, bboxes, gt_bboxes, **kwargs):
24 | """Directly returns the positive and negative indices of samples.
25 |
26 | Args:
27 | assign_result (:obj:`AssignResult`): Assigned results
28 | bboxes (torch.Tensor): Bounding boxes
29 | gt_bboxes (torch.Tensor): Ground truth boxes
30 |
31 | Returns:
32 | :obj:`SamplingResult`: sampler results
33 | """
34 | pos_inds = torch.nonzero(
35 | assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
36 | neg_inds = torch.nonzero(
37 | assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
38 | gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
39 | sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
40 | assign_result, gt_flags)
41 | return sampling_result
42 |
--------------------------------------------------------------------------------
/mmdet/core/bbox/samplers/random_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..builder import BBOX_SAMPLERS
4 | from .base_sampler import BaseSampler
5 |
6 |
7 | @BBOX_SAMPLERS.register_module()
8 | class RandomSampler(BaseSampler):
9 | """Random sampler.
10 |
11 | Args:
12 | num (int): Number of samples
13 | pos_fraction (float): Fraction of positive samples
14 |         neg_pos_ub (int, optional): Upper bound on the ratio of negative
15 |             to positive samples. Defaults to -1 (unbounded).
16 | add_gt_as_proposals (bool, optional): Whether to add ground truth
17 | boxes as proposals. Defaults to True.
18 | """
19 |
20 | def __init__(self,
21 | num,
22 | pos_fraction,
23 | neg_pos_ub=-1,
24 | add_gt_as_proposals=True,
25 | **kwargs):
26 | from mmdet.core.bbox import demodata
27 | super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
28 | add_gt_as_proposals)
29 | self.rng = demodata.ensure_rng(kwargs.get('rng', None))
30 |
31 | def random_choice(self, gallery, num):
32 |         """Randomly select some elements from the gallery.
33 |
34 | If `gallery` is a Tensor, the returned indices will be a Tensor;
35 |         if `gallery` is an ndarray or list, the returned indices will be
36 |         an ndarray.
37 |
38 | Args:
39 | gallery (Tensor | ndarray | list): indices pool.
40 | num (int): expected sample num.
41 |
42 | Returns:
43 | Tensor or ndarray: sampled indices.
44 | """
45 | assert len(gallery) >= num
46 |
47 | is_tensor = isinstance(gallery, torch.Tensor)
48 | if not is_tensor:
49 | if torch.cuda.is_available():
50 | device = torch.cuda.current_device()
51 | else:
52 | device = 'cpu'
53 | gallery = torch.tensor(gallery, dtype=torch.long, device=device)
54 | perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
55 | rand_inds = gallery[perm]
56 | if not is_tensor:
57 | rand_inds = rand_inds.cpu().numpy()
58 | return rand_inds
59 |
60 | def _sample_pos(self, assign_result, num_expected, **kwargs):
61 | """Randomly sample some positive samples."""
62 | pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
63 | if pos_inds.numel() != 0:
64 | pos_inds = pos_inds.squeeze(1)
65 | if pos_inds.numel() <= num_expected:
66 | return pos_inds
67 | else:
68 | return self.random_choice(pos_inds, num_expected)
69 |
70 | def _sample_neg(self, assign_result, num_expected, **kwargs):
71 | """Randomly sample some negative samples."""
72 | neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
73 | if neg_inds.numel() != 0:
74 | neg_inds = neg_inds.squeeze(1)
75 | if len(neg_inds) <= num_expected:
76 | return neg_inds
77 | else:
78 | return self.random_choice(neg_inds, num_expected)
79 |
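As the docstring of ``random_choice`` notes, the return type mirrors the input type, so the samplers work on tensors and numpy arrays alike. A minimal sketch (the constructor arguments are simply forwarded to ``BaseSampler``):

    >>> import torch
    >>> sampler = RandomSampler(num=8, pos_fraction=0.5)
    >>> inds = sampler.random_choice(torch.arange(100), 8)
    >>> inds.shape, inds.dtype
    (torch.Size([8]), torch.int64)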
--------------------------------------------------------------------------------
/mmdet/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
2 | get_classes, imagenet_det_classes,
3 | imagenet_vid_classes, voc_classes)
4 | from .eval_hooks import DistEvalHook, EvalHook
5 | from .mean_ap import average_precision, eval_map, print_map_summary
6 | from .recall import (eval_recalls, plot_iou_recall, plot_num_recall,
7 | print_recall_summary)
8 |
9 | __all__ = [
10 | 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
11 | 'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes',
12 | 'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map',
13 | 'print_map_summary', 'eval_recalls', 'print_recall_summary',
14 | 'plot_num_recall', 'plot_iou_recall'
15 | ]
16 |
--------------------------------------------------------------------------------
/mmdet/core/evaluation/bbox_overlaps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6):
5 | """Calculate the ious between each bbox of bboxes1 and bboxes2.
6 |
7 |     Args:
8 |         bboxes1 (ndarray): Shape (n, 4).
9 |         bboxes2 (ndarray): Shape (k, 4).
10 |         mode (str): 'iou' (intersection over union) or 'iof'
11 |             (intersection over foreground).
12 |         eps (float): Constant added to the union to avoid division by zero.
13 | 
14 |     Returns:
15 |         ious (ndarray): Shape (n, k).
16 |     """
16 |
17 | assert mode in ['iou', 'iof']
18 |
19 | bboxes1 = bboxes1.astype(np.float32)
20 | bboxes2 = bboxes2.astype(np.float32)
21 | rows = bboxes1.shape[0]
22 | cols = bboxes2.shape[0]
23 | ious = np.zeros((rows, cols), dtype=np.float32)
24 | if rows * cols == 0:
25 | return ious
26 | exchange = False
27 | if bboxes1.shape[0] > bboxes2.shape[0]:
28 | bboxes1, bboxes2 = bboxes2, bboxes1
29 | ious = np.zeros((cols, rows), dtype=np.float32)
30 | exchange = True
31 | area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
32 | area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
33 | for i in range(bboxes1.shape[0]):
34 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
35 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
36 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
37 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
38 | overlap = np.maximum(x_end - x_start, 0) * np.maximum(
39 | y_end - y_start, 0)
40 | if mode == 'iou':
41 | union = area1[i] + area2 - overlap
42 | else:
43 | union = area1[i] if not exchange else area2
44 | union = np.maximum(union, eps)
45 | ious[i, :] = overlap / union
46 | if exchange:
47 | ious = ious.T
48 | return ious
49 |
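A worked example: two 10 x 10 boxes offset by 5 pixels intersect in a 5 x 5 patch, so IoU = 25 / (100 + 100 - 25) = 1/7 (note that this implementation uses the "no +1" area convention):

    import numpy as np
    boxes_a = np.array([[0., 0., 10., 10.]])
    boxes_b = np.array([[5., 5., 15., 15.]])
    ious = bbox_overlaps(boxes_a, boxes_b)  # shape (1, 1)
    print(float(ious[0, 0]))                # ~0.1429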
--------------------------------------------------------------------------------
/mmdet/core/export/__init__.py:
--------------------------------------------------------------------------------
1 | from .pytorch2onnx import (build_model_from_cfg,
2 | generate_inputs_and_wrap_model,
3 | preprocess_example_input)
4 |
5 | __all__ = [
6 | 'build_model_from_cfg', 'generate_inputs_and_wrap_model',
7 | 'preprocess_example_input'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmdet/core/fp16/__init__.py:
--------------------------------------------------------------------------------
1 | from .deprecated_fp16_utils import \
2 | DeprecatedFp16OptimizerHook as Fp16OptimizerHook
3 | from .deprecated_fp16_utils import deprecated_auto_fp16 as auto_fp16
4 | from .deprecated_fp16_utils import deprecated_force_fp32 as force_fp32
5 | from .deprecated_fp16_utils import \
6 | deprecated_wrap_fp16_model as wrap_fp16_model
7 |
8 | __all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model']
9 |
--------------------------------------------------------------------------------
/mmdet/core/fp16/deprecated_fp16_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from mmcv.runner import (Fp16OptimizerHook, auto_fp16, force_fp32,
4 | wrap_fp16_model)
5 |
6 |
7 | class DeprecatedFp16OptimizerHook(Fp16OptimizerHook):
8 | """A wrapper class for the FP16 optimizer hook. This class wraps
9 | :class:`Fp16OptimizerHook` in `mmcv.runner` and shows a warning that the
10 | :class:`Fp16OptimizerHook` from `mmdet.core` will be deprecated.
11 |
12 | Refer to :class:`Fp16OptimizerHook` in `mmcv.runner` for more details.
13 |
14 | Args:
15 | loss_scale (float): Scale factor multiplied with loss.
16 | """
17 |
18 |     def __init__(self, *args, **kwargs):
19 | super().__init__(*args, **kwargs)
20 | warnings.warn(
21 | 'Importing Fp16OptimizerHook from "mmdet.core" will be '
22 | 'deprecated in the future. Please import them from "mmcv.runner" '
23 | 'instead')
24 |
25 |
26 | def deprecated_auto_fp16(*args, **kwargs):
27 | warnings.warn(
28 | 'Importing auto_fp16 from "mmdet.core" will be '
29 | 'deprecated in the future. Please import them from "mmcv.runner" '
30 | 'instead')
31 | return auto_fp16(*args, **kwargs)
32 |
33 |
34 | def deprecated_force_fp32(*args, **kwargs):
35 | warnings.warn(
36 | 'Importing force_fp32 from "mmdet.core" will be '
37 | 'deprecated in the future. Please import them from "mmcv.runner" '
38 | 'instead')
39 | return force_fp32(*args, **kwargs)
40 |
41 |
42 | def deprecated_wrap_fp16_model(*args, **kwargs):
43 | warnings.warn(
44 | 'Importing wrap_fp16_model from "mmdet.core" will be '
45 | 'deprecated in the future. Please import them from "mmcv.runner" '
46 | 'instead')
47 | wrap_fp16_model(*args, **kwargs)
48 |
--------------------------------------------------------------------------------
/mmdet/core/mask/__init__.py:
--------------------------------------------------------------------------------
1 | from .mask_target import mask_target
2 | from .structures import BaseInstanceMasks, BitmapMasks, PolygonMasks
3 | from .utils import encode_mask_results, split_combined_polys
4 |
5 | __all__ = [
6 | 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks',
7 | 'PolygonMasks', 'encode_mask_results'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmdet/core/mask/mask_target.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn.modules.utils import _pair
4 |
5 |
6 | def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
7 | cfg):
8 | """Compute mask target for positive proposals in multiple images.
9 |
10 | Args:
11 | pos_proposals_list (list[Tensor]): Positive proposals in multiple
12 | images.
13 | pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each
14 |             positive proposal.
15 | gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
16 | each image.
17 | cfg (dict): Config dict that specifies the mask size.
18 |
19 | Returns:
20 | list[Tensor]: Mask target of each image.
21 | """
22 | cfg_list = [cfg for _ in range(len(pos_proposals_list))]
23 | mask_targets = map(mask_target_single, pos_proposals_list,
24 | pos_assigned_gt_inds_list, gt_masks_list, cfg_list)
25 | mask_targets = list(mask_targets)
26 | if len(mask_targets) > 0:
27 | mask_targets = torch.cat(mask_targets)
28 | return mask_targets
29 |
30 |
31 | def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
32 | """Compute mask target for each positive proposal in the image.
33 |
34 | Args:
35 | pos_proposals (Tensor): Positive proposals.
36 | pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals.
37 | gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap
38 | or Polygon.
39 |         cfg (dict): Config dict that indicates the mask size.
40 |
41 | Returns:
42 |         Tensor: Mask target of each positive proposal in the image.
43 | """
44 | device = pos_proposals.device
45 | mask_size = _pair(cfg.mask_size)
46 | num_pos = pos_proposals.size(0)
47 | if num_pos > 0:
48 | proposals_np = pos_proposals.cpu().numpy()
49 | maxh, maxw = gt_masks.height, gt_masks.width
50 | proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw)
51 | proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh)
52 | pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
53 |
54 | mask_targets = gt_masks.crop_and_resize(
55 | proposals_np, mask_size, device=device,
56 | inds=pos_assigned_gt_inds).to_ndarray()
57 |
58 | mask_targets = torch.from_numpy(mask_targets).float().to(device)
59 | else:
60 | mask_targets = pos_proposals.new_zeros((0, ) + mask_size)
61 |
62 | return mask_targets
63 |
--------------------------------------------------------------------------------
/mmdet/core/mask/utils.py:
--------------------------------------------------------------------------------
1 | import mmcv
2 | import numpy as np
3 | import pycocotools.mask as mask_util
4 |
5 |
6 | def split_combined_polys(polys, poly_lens, polys_per_mask):
7 | """Split the combined 1-D polys into masks.
8 |
9 | A mask is represented as a list of polys, and a poly is represented as
10 |     a 1-D array. In the dataset, all masks are concatenated into a single
11 |     1-D tensor. Here we need to split the tensor into original representations.
12 |
13 | Args:
14 | polys (list): a list (length = image num) of 1-D tensors
15 | poly_lens (list): a list (length = image num) of poly length
16 | polys_per_mask (list): a list (length = image num) of poly number
17 | of each mask
18 |
19 | Returns:
20 | list: a list (length = image num) of list (length = mask num) of \
21 | list (length = poly num) of numpy array.
22 | """
23 | mask_polys_list = []
24 | for img_id in range(len(polys)):
25 | polys_single = polys[img_id]
26 | polys_lens_single = poly_lens[img_id].tolist()
27 | polys_per_mask_single = polys_per_mask[img_id].tolist()
28 |
29 | split_polys = mmcv.slice_list(polys_single, polys_lens_single)
30 | mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single)
31 | mask_polys_list.append(mask_polys)
32 | return mask_polys_list
33 |
34 |
35 | # TODO: move this function to more proper place
36 | def encode_mask_results(mask_results):
37 | """Encode bitmap mask to RLE code.
38 |
39 | Args:
40 | mask_results (list | tuple[list]): bitmap mask results.
41 | In mask scoring rcnn, mask_results is a tuple of (segm_results,
42 | segm_cls_score).
43 |
44 | Returns:
45 | list | tuple: RLE encoded mask.
46 | """
47 | if isinstance(mask_results, tuple): # mask scoring
48 | cls_segms, cls_mask_scores = mask_results
49 | else:
50 | cls_segms = mask_results
51 | num_classes = len(cls_segms)
52 | encoded_mask_results = [[] for _ in range(num_classes)]
53 | for i in range(len(cls_segms)):
54 | for cls_segm in cls_segms[i]:
55 | encoded_mask_results[i].append(
56 | mask_util.encode(
57 | np.array(
58 | cls_segm[:, :, np.newaxis], order='F',
59 | dtype='uint8'))[0]) # encoded with RLE
60 | if isinstance(mask_results, tuple):
61 | return encoded_mask_results, cls_mask_scores
62 | else:
63 | return encoded_mask_results
64 |
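``encode_mask_results`` turns each binary mask into a COCO run-length dict with ``size`` and ``counts`` keys. A minimal sketch with one empty 4 x 4 mask for the first of two classes:

    >>> import numpy as np
    >>> segms = [[np.zeros((4, 4), dtype=np.uint8)], []]
    >>> rles = encode_mask_results(segms)
    >>> sorted(rles[0][0].keys())
    ['counts', 'size']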
--------------------------------------------------------------------------------
/mmdet/core/post_processing/__init__.py:
--------------------------------------------------------------------------------
1 | from .bbox_nms import fast_nms, multiclass_nms, WeaklyMulticlassNMS
2 | from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
3 | merge_aug_proposals, merge_aug_scores)
4 |
5 | __all__ = [
6 | 'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
7 | 'merge_aug_scores', 'merge_aug_masks', 'fast_nms', 'WeaklyMulticlassNMS'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmdet/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dist_utils import DistOptimizerHook, allreduce_grads, reduce_mean
2 | from .misc import multi_apply, unmap
3 |
4 | __all__ = [
5 | 'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
6 | 'unmap'
7 | ]
8 |
--------------------------------------------------------------------------------
/mmdet/core/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import OrderedDict
3 |
4 | import torch.distributed as dist
5 | from mmcv.runner import OptimizerHook
6 | from torch._utils import (_flatten_dense_tensors, _take_tensors,
7 | _unflatten_dense_tensors)
8 |
9 |
10 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
11 | if bucket_size_mb > 0:
12 | bucket_size_bytes = bucket_size_mb * 1024 * 1024
13 | buckets = _take_tensors(tensors, bucket_size_bytes)
14 | else:
15 | buckets = OrderedDict()
16 | for tensor in tensors:
17 | tp = tensor.type()
18 | if tp not in buckets:
19 | buckets[tp] = []
20 | buckets[tp].append(tensor)
21 | buckets = buckets.values()
22 |
23 | for bucket in buckets:
24 | flat_tensors = _flatten_dense_tensors(bucket)
25 | dist.all_reduce(flat_tensors)
26 | flat_tensors.div_(world_size)
27 | for tensor, synced in zip(
28 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
29 | tensor.copy_(synced)
30 |
31 |
32 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
33 | """Allreduce gradients.
34 |
35 | Args:
36 | params (list[torch.Parameters]): List of parameters of a model
37 | coalesce (bool, optional): Whether allreduce parameters as a whole.
38 | Defaults to True.
39 | bucket_size_mb (int, optional): Size of bucket, the unit is MB.
40 | Defaults to -1.
41 | """
42 | grads = [
43 | param.grad.data for param in params
44 | if param.requires_grad and param.grad is not None
45 | ]
46 | world_size = dist.get_world_size()
47 | if coalesce:
48 | _allreduce_coalesced(grads, world_size, bucket_size_mb)
49 | else:
50 | for tensor in grads:
51 | dist.all_reduce(tensor.div_(world_size))
52 |
53 |
54 | class DistOptimizerHook(OptimizerHook):
55 | """Deprecated optimizer hook for distributed training."""
56 |
57 | def __init__(self, *args, **kwargs):
58 |         warnings.warn('"DistOptimizerHook" is deprecated, please switch '
59 |                       'to "mmcv.runner.OptimizerHook".')
60 | super().__init__(*args, **kwargs)
61 |
62 |
63 | def reduce_mean(tensor):
64 |     """Obtain the mean of a tensor across all GPUs."""
65 | if not (dist.is_available() and dist.is_initialized()):
66 | return tensor
67 | tensor = tensor.clone()
68 | dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
69 | return tensor
70 |
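``reduce_mean`` degrades gracefully outside distributed runs: with no initialized process group it returns its input unchanged, so the same training code also works on a single GPU:

    >>> import torch
    >>> reduce_mean(torch.tensor([4.0]))  # no process group: identity
    tensor([4.])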
--------------------------------------------------------------------------------
/mmdet/core/utils/misc.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from six.moves import map, zip
5 |
6 |
7 | def multi_apply(func, *args, **kwargs):
8 | """Apply function to a list of arguments.
9 |
10 |     Note:
11 |         This function applies ``func`` to multiple inputs and maps the
12 |         multiple outputs of ``func`` into separate lists. Each list
13 |         contains the same type of output corresponding to different
14 |         inputs.
15 | 
16 |     Args:
17 |         func (Function): A function that will be applied to a list of
18 |             arguments.
19 | 
20 |     Returns:
21 |         tuple(list): A tuple containing multiple lists, where each list \
22 |             contains one kind of result returned by the function.
23 |     """
24 | pfunc = partial(func, **kwargs) if kwargs else func
25 | map_results = map(pfunc, *args)
26 | return tuple(map(list, zip(*map_results)))
27 |
28 |
29 | def unmap(data, count, inds, fill=0):
30 |     """Unmap a subset of items (data) back to the original set of items
31 |     (of size count)."""
32 | if data.dim() == 1:
33 | ret = data.new_full((count, ), fill)
34 | ret[inds.type(torch.bool)] = data
35 | else:
36 | new_size = (count, ) + data.size()[1:]
37 | ret = data.new_full(new_size, fill)
38 | ret[inds.type(torch.bool), :] = data
39 | return ret
40 |
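``multi_apply`` is the workhorse for per-level and per-image computation in the heads: it calls ``func`` element-wise and transposes the tuple outputs into one list per output. A small sketch of both helpers:

    >>> import torch
    >>> def square_and_cube(x):
    ...     return x ** 2, x ** 3
    >>> multi_apply(square_and_cube, [1, 2, 3])
    ([1, 4, 9], [1, 8, 27])
    >>> unmap(torch.tensor([10., 20.]), 4, torch.tensor([0, 1, 1, 0]))
    tensor([ 0., 10., 20.,  0.])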
--------------------------------------------------------------------------------
/mmdet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
2 | from .cityscapes import CityscapesDataset
3 | from .coco import CocoDataset
4 | from .custom import CustomDataset
5 | from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
6 | RepeatDataset)
7 | from .deepfashion import DeepFashionDataset
8 | from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
9 | from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
10 | from .utils import replace_ImageToTensor
11 | from .voc import VOCDataset
12 | from .voc_ss import VOCSSDataset
13 | from .wider_face import WIDERFaceDataset
14 | from .xml_style import XMLDataset
15 | 
16 | 
17 | __all__ = [
18 | 'CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
19 | 'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
20 | 'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
21 | 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
22 | 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
23 |     'build_dataset', 'replace_ImageToTensor', 'VOCSSDataset'
25 | ]
26 |
--------------------------------------------------------------------------------
/mmdet/datasets/deepfashion.py:
--------------------------------------------------------------------------------
1 | from .builder import DATASETS
2 | from .coco import CocoDataset
3 |
4 |
5 | @DATASETS.register_module()
6 | class DeepFashionDataset(CocoDataset):
7 |
8 | CLASSES = ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants', 'bag',
9 | 'neckwear', 'headwear', 'eyeglass', 'belt', 'footwear', 'hair',
10 | 'skin', 'face')
11 |
--------------------------------------------------------------------------------
/mmdet/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform,
2 | ContrastTransform, EqualizeTransform, Rotate, Shear,
3 | Translate)
4 | from .compose import Compose
5 | from .formating import (Collect, DefaultFormatBundle, ImageToTensor,
6 | ToDataContainer, ToTensor, Transpose, to_tensor)
7 | from .instaboost import InstaBoost
8 | from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam,
9 | LoadMultiChannelImageFromFiles, LoadProposals)
10 | from .test_time_aug import MultiScaleFlipAug
11 | from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize,
12 | Pad, PhotoMetricDistortion, RandomCenterCropPad,
13 | RandomCrop, RandomFlip, Resize, SegRescale)
14 |
15 | __all__ = [
16 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
17 | 'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations',
18 | 'LoadImageFromFile', 'LoadImageFromWebcam',
19 | 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'MultiScaleFlipAug',
20 | 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale',
21 | 'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu',
22 | 'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear',
23 | 'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform',
24 | 'ContrastTransform', 'Translate'
25 | ]
26 |
--------------------------------------------------------------------------------
/mmdet/datasets/pipelines/compose.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | from mmcv.utils import build_from_cfg
4 |
5 | from ..builder import PIPELINES
6 |
7 |
8 | @PIPELINES.register_module()
9 | class Compose(object):
10 | """Compose multiple transforms sequentially.
11 |
12 | Args:
13 | transforms (Sequence[dict | callable]): Sequence of transform object or
14 | config dict to be composed.
15 | """
16 |
17 | def __init__(self, transforms):
18 | assert isinstance(transforms, collections.abc.Sequence)
19 | self.transforms = []
20 | for transform in transforms:
21 | if isinstance(transform, dict):
22 | transform = build_from_cfg(transform, PIPELINES)
23 | self.transforms.append(transform)
24 | elif callable(transform):
25 | self.transforms.append(transform)
26 | else:
27 | raise TypeError('transform must be callable or a dict')
28 |
29 | def __call__(self, data):
30 | """Call function to apply transforms sequentially.
31 |
32 | Args:
33 | data (dict): A result dict contains the data to transform.
34 |
35 | Returns:
36 | dict: Transformed data.
37 | """
38 |
39 | for t in self.transforms:
40 | data = t(data)
41 | if data is None:
42 | return None
43 | return data
44 |
45 | def __repr__(self):
46 | format_string = self.__class__.__name__ + '('
47 | for t in self.transforms:
48 | format_string += '\n'
49 | format_string += f' {t}'
50 | format_string += '\n)'
51 | return format_string
52 |
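Because ``Compose`` accepts raw callables alongside config dicts, a custom step can be spliced into a pipeline without registering it. A minimal sketch with a plain function:

    >>> def tag(results):
    ...     results['tagged'] = True
    ...     return results
    >>> pipeline = Compose([tag])
    >>> pipeline({'img': None})
    {'img': None, 'tagged': True}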
--------------------------------------------------------------------------------
/mmdet/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed_sampler import DistributedSampler
2 | from .group_sampler import DistributedGroupSampler, GroupSampler
3 |
4 | __all__ = ['DistributedSampler', 'DistributedGroupSampler', 'GroupSampler']
5 |
--------------------------------------------------------------------------------
/mmdet/datasets/samplers/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DistributedSampler as _DistributedSampler
3 |
4 |
5 | class DistributedSampler(_DistributedSampler):
6 |
7 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
8 | super().__init__(dataset, num_replicas=num_replicas, rank=rank)
9 | self.shuffle = shuffle
10 |
11 | def __iter__(self):
12 | # deterministically shuffle based on epoch
13 | if self.shuffle:
14 | g = torch.Generator()
15 | g.manual_seed(self.epoch)
16 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
17 | else:
18 | indices = torch.arange(len(self.dataset)).tolist()
19 |
20 | # add extra samples to make it evenly divisible
21 | indices += indices[:(self.total_size - len(indices))]
22 | assert len(indices) == self.total_size
23 |
24 | # subsample
25 | indices = indices[self.rank:self.total_size:self.num_replicas]
26 | assert len(indices) == self.num_samples
27 |
28 | return iter(indices)
29 |
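The final slicing step assigns each rank an interleaved slice of the index list. For example, with two replicas and shuffling disabled, rank 0 receives the even positions and rank 1 the odd ones (passing ``num_replicas`` and ``rank`` explicitly avoids needing an initialized process group):

    >>> sampler = DistributedSampler(list(range(10)), num_replicas=2, rank=0,
    ...                              shuffle=False)
    >>> list(sampler)
    [0, 2, 4, 6, 8]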
--------------------------------------------------------------------------------
/mmdet/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import warnings
3 |
4 |
5 | def replace_ImageToTensor(pipelines):
6 |     """Replace the ImageToTensor transform in a data pipeline with
7 |     DefaultFormatBundle, which is normally useful for batch inference.
8 |
9 | Args:
10 | pipelines (list[dict]): Data pipeline configs.
11 |
12 | Returns:
13 | list: The new pipeline list with all ImageToTensor replaced by
14 | DefaultFormatBundle.
15 |
16 | Examples:
17 | >>> pipelines = [
18 | ... dict(type='LoadImageFromFile'),
19 | ... dict(
20 | ... type='MultiScaleFlipAug',
21 | ... img_scale=(1333, 800),
22 | ... flip=False,
23 | ... transforms=[
24 | ... dict(type='Resize', keep_ratio=True),
25 | ... dict(type='RandomFlip'),
26 | ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
27 | ... dict(type='Pad', size_divisor=32),
28 | ... dict(type='ImageToTensor', keys=['img']),
29 | ... dict(type='Collect', keys=['img']),
30 | ... ])
31 | ... ]
32 | >>> expected_pipelines = [
33 | ... dict(type='LoadImageFromFile'),
34 | ... dict(
35 | ... type='MultiScaleFlipAug',
36 | ... img_scale=(1333, 800),
37 | ... flip=False,
38 | ... transforms=[
39 | ... dict(type='Resize', keep_ratio=True),
40 | ... dict(type='RandomFlip'),
41 | ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
42 | ... dict(type='Pad', size_divisor=32),
43 | ... dict(type='DefaultFormatBundle'),
44 | ... dict(type='Collect', keys=['img']),
45 | ... ])
46 | ... ]
47 | >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
48 | """
49 | pipelines = copy.deepcopy(pipelines)
50 | for i, pipeline in enumerate(pipelines):
51 | if pipeline['type'] == 'MultiScaleFlipAug':
52 | assert 'transforms' in pipeline
53 | pipeline['transforms'] = replace_ImageToTensor(
54 | pipeline['transforms'])
55 | elif pipeline['type'] == 'ImageToTensor':
56 | warnings.warn(
57 | '"ImageToTensor" pipeline is replaced by '
58 | '"DefaultFormatBundle" for batch inference. It is '
59 | 'recommended to manually replace it in the test '
60 | 'data pipeline in your config file.', UserWarning)
61 | pipelines[i] = {'type': 'DefaultFormatBundle'}
62 | return pipelines
63 |
--------------------------------------------------------------------------------
/mmdet/datasets/voc_ss.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import xml.etree.ElementTree as ET
3 | 
4 | import mmcv
5 | from PIL import Image
6 | 
7 | from .builder import DATASETS
8 | from .voc import VOCDataset
8 |
9 | @DATASETS.register_module()
10 | class VOCSSDataset(VOCDataset):
11 |     """VOC dataset with per-image superpixel maps."""
12 |
13 | def __init__(self, **kwargs):
14 | super(VOCSSDataset, self).__init__(**kwargs)
15 |
16 | def load_annotations(self, ann_file):
17 | data_infos = []
18 | img_ids = mmcv.list_from_file(ann_file)
19 | for img_id in img_ids:
20 | filename = f'JPEGImages/{img_id}.jpg'
21 | ssname = f'SuperPixels/{img_id}.jpg'
22 | xml_path = osp.join(self.img_prefix, 'Annotations',
23 | f'{img_id}.xml')
24 | tree = ET.parse(xml_path)
25 | root = tree.getroot()
26 | size = root.find('size')
27 | width = 0
28 | height = 0
29 | if size is not None:
30 | width = int(size.find('width').text)
31 | height = int(size.find('height').text)
32 | else:
33 |                 img_path = osp.join(self.img_prefix, 'JPEGImages',
34 |                                     f'{img_id}.jpg')
35 | img = Image.open(img_path)
36 | width, height = img.size
37 | data_infos.append(
38 |                 dict(id=img_id, filename=filename, width=width,
39 |                      height=height, ssname=ssname))
39 |
40 | return data_infos
41 |
42 |
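A hypothetical config snippet for this dataset (the paths assume the standard VOC layout; the ``SuperPixels`` directory must sit next to ``JPEGImages`` and ``Annotations`` under ``img_prefix``):

    dataset = dict(
        type='VOCSSDataset',
        ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
        img_prefix='data/VOCdevkit/VOC2007/')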
--------------------------------------------------------------------------------
/mmdet/datasets/wider_face.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import xml.etree.ElementTree as ET
3 |
4 | import mmcv
5 |
6 | from .builder import DATASETS
7 | from .xml_style import XMLDataset
8 |
9 |
10 | @DATASETS.register_module()
11 | class WIDERFaceDataset(XMLDataset):
12 | """Reader for the WIDER Face dataset in PASCAL VOC format.
13 |
14 | Conversion scripts can be found in
15 | https://github.com/sovrasov/wider-face-pascal-voc-annotations
16 | """
17 | CLASSES = ('face', )
18 |
19 | def __init__(self, **kwargs):
20 | super(WIDERFaceDataset, self).__init__(**kwargs)
21 |
22 | def load_annotations(self, ann_file):
23 | """Load annotation from WIDERFace XML style annotation file.
24 |
25 | Args:
26 | ann_file (str): Path of XML file.
27 |
28 | Returns:
29 | list[dict]: Annotation info from XML file.
30 | """
31 |
32 | data_infos = []
33 | img_ids = mmcv.list_from_file(ann_file)
34 | for img_id in img_ids:
35 | filename = f'{img_id}.jpg'
36 | xml_path = osp.join(self.img_prefix, 'Annotations',
37 | f'{img_id}.xml')
38 | tree = ET.parse(xml_path)
39 | root = tree.getroot()
40 | size = root.find('size')
41 | width = int(size.find('width').text)
42 | height = int(size.find('height').text)
43 | folder = root.find('folder').text
44 | data_infos.append(
45 | dict(
46 | id=img_id,
47 | filename=osp.join(folder, filename),
48 | width=width,
49 | height=height))
50 |
51 | return data_infos
52 |
--------------------------------------------------------------------------------
/mmdet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbones import * # noqa: F401,F403
2 | from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
3 | ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
4 | build_detector, build_head, build_loss, build_neck,
5 | build_roi_extractor, build_shared_head)
6 | from .dense_heads import * # noqa: F401,F403
7 | from .detectors import * # noqa: F401,F403
8 | from .losses import * # noqa: F401,F403
9 | from .necks import * # noqa: F401,F403
10 | from .roi_heads import * # noqa: F401,F403
11 |
12 | __all__ = [
13 | 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
14 | 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
15 | 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
16 | ]
17 |
--------------------------------------------------------------------------------
/mmdet/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .darknet import Darknet
2 | from .detectors_resnet import DetectoRS_ResNet
3 | from .detectors_resnext import DetectoRS_ResNeXt
4 | from .hourglass import HourglassNet
5 | from .hrnet import HRNet
6 | from .regnet import RegNet
7 | from .res2net import Res2Net
8 | from .resnest import ResNeSt
9 | from .resnet import ResNet, ResNetV1d
10 | from .resnext import ResNeXt
11 | from .ssd_vgg import SSDVGG
12 | from .vgg import VGG16
13 |
14 | __all__ = [
15 | 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
16 | 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
17 | 'ResNeSt', 'VGG16'
18 | ]
19 |
--------------------------------------------------------------------------------
/mmdet/models/backbones/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | from ..builder import BACKBONES
5 |
6 | @BACKBONES.register_module()
7 | class VGG16(nn.Module):
8 |     """VGG-16 backbone with dilated conv5 and no pool4/pool5 (WSDDN-style)."""
9 | 
10 |     def __init__(self):
11 |         super(VGG16, self).__init__()
12 | 
11 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
12 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
13 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
14 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)
15 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)
16 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
17 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)
18 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)
19 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)
20 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
21 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)
22 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)
23 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)
24 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=True)
25 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=True)
26 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=True)
27 |
28 |         # Freeze conv1-conv2; requires_grad only takes effect on parameters.
29 |         for layer in (self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2):
30 |             for param in layer.parameters():
31 |                 param.requires_grad = False
32 |
33 | def forward(self, x):
34 | x = F.relu(self.conv1_1(x), inplace=True)
35 | x = F.relu(self.conv1_2(x), inplace=True)
36 | x = self.pool1(x)
37 | x = F.relu(self.conv2_1(x), inplace=True)
38 | x = F.relu(self.conv2_2(x), inplace=True)
39 | x = self.pool2(x)
40 | x = F.relu(self.conv3_1(x), inplace=True)
41 | x = F.relu(self.conv3_2(x), inplace=True)
42 | x = F.relu(self.conv3_3(x), inplace=True)
43 | x = self.pool3(x)
44 | x = F.relu(self.conv4_1(x), inplace=True)
45 | x = F.relu(self.conv4_2(x), inplace=True)
46 | x = F.relu(self.conv4_3(x), inplace=True)
47 | x = F.relu(self.conv5_1(x), inplace=True)
48 | x = F.relu(self.conv5_2(x), inplace=True)
49 | x = F.relu(self.conv5_3(x), inplace=True)
50 | return [x]
51 |
52 | def init_weights(self, pretrained=None):
53 | pass
54 |
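Since only three 2 x 2 poolings are applied and conv5 relies on dilation instead of a fourth pooling, the backbone has an output stride of 8. A quick shape check:

    >>> import torch
    >>> feats = VGG16()(torch.randn(1, 3, 224, 224))
    >>> feats[0].shape  # 224 / 8 = 28
    torch.Size([1, 512, 28, 28])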
--------------------------------------------------------------------------------
/mmdet/models/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import Registry, build_from_cfg
2 | from torch import nn
3 |
4 | BACKBONES = Registry('backbone')
5 | NECKS = Registry('neck')
6 | ROI_EXTRACTORS = Registry('roi_extractor')
7 | SHARED_HEADS = Registry('shared_head')
8 | HEADS = Registry('head')
9 | LOSSES = Registry('loss')
10 | DETECTORS = Registry('detector')
11 |
12 |
13 | def build(cfg, registry, default_args=None):
14 | """Build a module.
15 |
16 | Args:
17 |         cfg (dict, list[dict]): The config of modules; it is either a dict
18 | or a list of configs.
19 | registry (:obj:`Registry`): A registry the module belongs to.
20 | default_args (dict, optional): Default arguments to build the module.
21 | Defaults to None.
22 |
23 | Returns:
24 | nn.Module: A built nn module.
25 | """
26 | if isinstance(cfg, list):
27 | modules = [
28 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
29 | ]
30 | return nn.Sequential(*modules)
31 | else:
32 | return build_from_cfg(cfg, registry, default_args)
33 |
34 |
35 | def build_backbone(cfg):
36 | """Build backbone."""
37 | return build(cfg, BACKBONES)
38 |
39 |
40 | def build_neck(cfg):
41 | """Build neck."""
42 | return build(cfg, NECKS)
43 |
44 |
45 | def build_roi_extractor(cfg):
46 | """Build roi extractor."""
47 | return build(cfg, ROI_EXTRACTORS)
48 |
49 |
50 | def build_shared_head(cfg):
51 | """Build shared head."""
52 | return build(cfg, SHARED_HEADS)
53 |
54 |
55 | def build_head(cfg):
56 | """Build head."""
57 | return build(cfg, HEADS)
58 |
59 |
60 | def build_loss(cfg):
61 | """Build loss."""
62 | return build(cfg, LOSSES)
63 |
64 |
65 | def build_detector(cfg, train_cfg=None, test_cfg=None):
66 | """Build detector."""
67 | return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
68 |
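Usage in brief: a config dict's `type` key selects the registered class and the remaining keys become constructor kwargs; a list of dicts is chained into an `nn.Sequential` by `build` above. A small sketch (assuming `mmdet.models` has been imported so the registration decorators have run):

    from mmdet.models.builder import build_backbone

    # 'VGG16' was registered above via @BACKBONES.register_module();
    # VGG16.__init__ takes no extra arguments, so the dict has only `type`.
    backbone = build_backbone(dict(type='VGG16'))
    assert type(backbone).__name__ == 'VGG16'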
--------------------------------------------------------------------------------
/mmdet/models/dense_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .anchor_free_head import AnchorFreeHead
2 | from .anchor_head import AnchorHead
3 | from .atss_head import ATSSHead
4 | from .centripetal_head import CentripetalHead
5 | from .corner_head import CornerHead
6 | from .fcos_head import FCOSHead
7 | from .fovea_head import FoveaHead
8 | from .free_anchor_retina_head import FreeAnchorRetinaHead
9 | from .fsaf_head import FSAFHead
10 | from .ga_retina_head import GARetinaHead
11 | from .ga_rpn_head import GARPNHead
12 | from .gfl_head import GFLHead
13 | from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
14 | from .nasfcos_head import NASFCOSHead
15 | from .paa_head import PAAHead
16 | from .pisa_retinanet_head import PISARetinaHead
17 | from .pisa_ssd_head import PISASSDHead
18 | from .reppoints_head import RepPointsHead
19 | from .retina_head import RetinaHead
20 | from .retina_sepbn_head import RetinaSepBNHead
21 | from .rpn_head import RPNHead
22 | from .sabl_retina_head import SABLRetinaHead
23 | from .ssd_head import SSDHead
24 | from .vfnet_head import VFNetHead
25 | from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
26 | from .yolo_head import YOLOV3Head
27 |
28 | __all__ = [
29 | 'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
30 | 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
31 | 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
32 | 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
33 | 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
34 | 'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
35 | 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead'
36 | ]
37 |
--------------------------------------------------------------------------------
/mmdet/models/dense_heads/base_dense_head.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class BaseDenseHead(nn.Module, metaclass=ABCMeta):
7 | """Base class for DenseHeads."""
8 |
9 | def __init__(self):
10 | super(BaseDenseHead, self).__init__()
11 |
12 | @abstractmethod
13 | def loss(self, **kwargs):
14 | """Compute losses of the head."""
15 | pass
16 |
17 | @abstractmethod
18 | def get_bboxes(self, **kwargs):
19 | """Transform network output for a batch into bbox predictions."""
20 | pass
21 |
22 | def forward_train(self,
23 | x,
24 | img_metas,
25 | gt_bboxes,
26 | gt_labels=None,
27 | gt_bboxes_ignore=None,
28 | proposal_cfg=None,
29 | **kwargs):
30 | """
31 | Args:
32 | x (list[Tensor]): Features from FPN.
33 | img_metas (list[dict]): Meta information of each image, e.g.,
34 | image size, scaling factor, etc.
35 | gt_bboxes (Tensor): Ground truth bboxes of the image,
36 | shape (num_gts, 4).
37 | gt_labels (Tensor): Ground truth labels of each box,
38 | shape (num_gts,).
39 | gt_bboxes_ignore (Tensor): Ground truth bboxes to be
40 | ignored, shape (num_ignored_gts, 4).
41 | proposal_cfg (mmcv.Config): Test / postprocessing configuration;
42 | if None, ``test_cfg`` is used.
43 |
44 | Returns:
45 | tuple:
46 | losses (dict[str, Tensor]): A dictionary of loss components.
47 | proposal_list (list[Tensor]): Proposals of each image.
48 | """
49 | outs = self(x)
50 | if gt_labels is None:
51 | loss_inputs = outs + (gt_bboxes, img_metas)
52 | else:
53 | loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
54 | losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
55 | if proposal_cfg is None:
56 | return losses
57 | else:
58 | proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
59 | return losses, proposal_list
60 |
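A minimal sketch of a subclass, just to show how `forward_train` composes `forward`, `loss` and `get_bboxes` (a toy head for illustration, not a real mmdet module):

    import torch.nn as nn

    class ToyDenseHead(BaseDenseHead):
        """Illustrative only: one 1x1 conv producing per-pixel class scores."""

        def __init__(self, in_channels=256, num_classes=80):
            super(ToyDenseHead, self).__init__()
            self.cls_conv = nn.Conv2d(in_channels, num_classes, 1)

        def forward(self, feats):
            # must return a tuple so forward_train can do `outs + (...)`
            return ([self.cls_conv(x) for x in feats], )

        def loss(self, cls_scores, gt_bboxes, gt_labels, img_metas,
                 gt_bboxes_ignore=None):
            # placeholder objective; real heads return dicts keyed per loss term
            return dict(loss_cls=sum(s.abs().mean() for s in cls_scores))

        def get_bboxes(self, cls_scores, img_metas, cfg=None):
            # placeholder: an empty (n, 5) proposal tensor per image
            return [cls_scores[0].new_zeros(0, 5) for _ in img_metas]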
--------------------------------------------------------------------------------
/mmdet/models/dense_heads/nasfcos_head.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch.nn as nn
4 | from mmcv.cnn import (ConvModule, Scale, bias_init_with_prob,
5 | caffe2_xavier_init, normal_init)
6 |
7 | from mmdet.models.dense_heads.fcos_head import FCOSHead
8 | from ..builder import HEADS
9 |
10 |
11 | @HEADS.register_module()
12 | class NASFCOSHead(FCOSHead):
13 | """Anchor-free head used in `NASFCOS `_.
14 |
15 | It is quite similar to the FCOS head, except that the classification
16 | and bbox regression branches use a searched structure of
17 | "dconv3x3, conv3x3, dconv3x3, conv1x1" instead.
18 | """
19 |
20 | def _init_layers(self):
21 | """Initialize layers of the head."""
22 | dconv3x3_config = dict(
23 | type='DCNv2',
24 | kernel_size=3,
25 | use_bias=True,
26 | deform_groups=2,
27 | padding=1)
28 | conv3x3_config = dict(type='Conv', kernel_size=3, padding=1)
29 | conv1x1_config = dict(type='Conv', kernel_size=1)
30 |
31 | self.arch_config = [
32 | dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config
33 | ]
34 | self.cls_convs = nn.ModuleList()
35 | self.reg_convs = nn.ModuleList()
36 | for i, op_ in enumerate(self.arch_config):
37 | op = copy.deepcopy(op_)
38 | chn = self.in_channels if i == 0 else self.feat_channels
39 | assert isinstance(op, dict)
40 | use_bias = op.pop('use_bias', False)
41 | padding = op.pop('padding', 0)
42 | kernel_size = op.pop('kernel_size')
43 | module = ConvModule(
44 | chn,
45 | self.feat_channels,
46 | kernel_size,
47 | stride=1,
48 | padding=padding,
49 | norm_cfg=self.norm_cfg,
50 | bias=use_bias,
51 | conv_cfg=op)
52 |
53 | self.cls_convs.append(copy.deepcopy(module))
54 | self.reg_convs.append(copy.deepcopy(module))
55 |
56 | self.conv_cls = nn.Conv2d(
57 | self.feat_channels, self.cls_out_channels, 3, padding=1)
58 | self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
59 | self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
60 |
61 | self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
62 |
63 | def init_weights(self):
64 | """Initialize weights of the head."""
65 | # retinanet_bias_init
66 | bias_cls = bias_init_with_prob(0.01)
67 | normal_init(self.conv_reg, std=0.01)
68 | normal_init(self.conv_centerness, std=0.01)
69 | normal_init(self.conv_cls, std=0.01, bias=bias_cls)
70 |
71 | for branch in [self.cls_convs, self.reg_convs]:
72 | for module in branch.modules():
73 | if isinstance(module, ConvModule) \
74 | and isinstance(module.conv, nn.Conv2d):
75 | caffe2_xavier_init(module.conv)
76 |
--------------------------------------------------------------------------------
/mmdet/models/dense_heads/rpn_test_mixin.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from mmdet.core import merge_aug_proposals
4 |
5 | if sys.version_info >= (3, 7):
6 | from mmdet.utils.contextmanagers import completed
7 |
8 |
9 | class RPNTestMixin(object):
10 | """Test methods of RPN."""
11 |
12 | if sys.version_info >= (3, 7):
13 |
14 | async def async_simple_test_rpn(self, x, img_metas):
15 | sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025)
16 | async with completed(
17 | __name__, 'rpn_head_forward',
18 | sleep_interval=sleep_interval):
19 | rpn_outs = self(x)
20 |
21 | proposal_list = self.get_bboxes(*rpn_outs, img_metas)
22 | return proposal_list
23 |
24 | def simple_test_rpn(self, x, img_metas):
25 | """Test without augmentation.
26 |
27 | Args:
28 | x (tuple[Tensor]): Features from the upstream network, each is
29 | a 4D-tensor.
30 | img_metas (list[dict]): Meta info of each image.
31 |
32 | Returns:
33 | list[Tensor]: Proposals of each image.
34 | """
35 | rpn_outs = self(x)
36 | proposal_list = self.get_bboxes(*rpn_outs, img_metas)
37 | return proposal_list
38 |
39 | def aug_test_rpn(self, feats, img_metas):
40 | samples_per_gpu = len(img_metas[0])
41 | aug_proposals = [[] for _ in range(samples_per_gpu)]
42 | for x, img_meta in zip(feats, img_metas):
43 | proposal_list = self.simple_test_rpn(x, img_meta)
44 | for i, proposals in enumerate(proposal_list):
45 | aug_proposals[i].append(proposals)
46 | # reorganize the order of 'img_metas' to match the dimensions
47 | # of 'aug_proposals'
48 | aug_img_metas = []
49 | for i in range(samples_per_gpu):
50 | aug_img_meta = []
51 | for j in range(len(img_metas)):
52 | aug_img_meta.append(img_metas[j][i])
53 | aug_img_metas.append(aug_img_meta)
54 | # after merging, proposals will be rescaled to the original image size
55 | merged_proposals = [
56 | merge_aug_proposals(proposals, aug_img_meta, self.test_cfg)
57 | for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
58 | ]
59 | return merged_proposals
60 |
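The double loop that reorders `img_metas` above is just a transpose of a list of lists; a sketch of the equivalence:

    # img_metas is indexed [aug][image]; aug_proposals is indexed [image][aug],
    # so the loop builds aug_img_metas[i][j] == img_metas[j][i]:
    img_metas = [['m00', 'm01'], ['m10', 'm11']]  # 2 augs x 2 images
    aug_img_metas = [list(metas) for metas in zip(*img_metas)]
    assert aug_img_metas == [['m00', 'm10'], ['m01', 'm11']]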
--------------------------------------------------------------------------------
/mmdet/models/detectors/__init__.py:
--------------------------------------------------------------------------------
1 | from .atss import ATSS
2 | from .base import BaseDetector
3 | from .cascade_rcnn import CascadeRCNN
4 | from .cornernet import CornerNet
5 | from .fast_rcnn import FastRCNN
6 | from .faster_rcnn import FasterRCNN
7 | from .fcos import FCOS
8 | from .fovea import FOVEA
9 | from .fsaf import FSAF
10 | from .gfl import GFL
11 | from .grid_rcnn import GridRCNN
12 | from .htc import HybridTaskCascade
13 | from .mask_rcnn import MaskRCNN
14 | from .mask_scoring_rcnn import MaskScoringRCNN
15 | from .nasfcos import NASFCOS
16 | from .paa import PAA
17 | from .point_rend import PointRend
18 | from .reppoints_detector import RepPointsDetector
19 | from .retinanet import RetinaNet
20 | from .rpn import RPN
21 | from .single_stage import SingleStageDetector
22 | from .two_stage import TwoStageDetector
23 | from .vfnet import VFNet
24 | from .yolact import YOLACT
25 | from .yolo import YOLOV3
26 |
27 | from .weak_rcnn import WeakRCNN
28 |
29 | __all__ = [
30 | 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
31 | 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
32 | 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
33 | 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
34 | 'YOLOV3', 'YOLACT', 'VFNet',
35 | 'WeakRCNN'
36 | ]
37 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/atss.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class ATSS(SingleStageDetector):
7 | """Implementation of `ATSS `_."""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | bbox_head,
13 | train_cfg=None,
14 | test_cfg=None,
15 | pretrained=None):
16 | super(ATSS, self).__init__(backbone, neck, bbox_head, train_cfg,
17 | test_cfg, pretrained)
18 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/cascade_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class CascadeRCNN(TwoStageDetector):
7 | r"""Implementation of `Cascade R-CNN: Delving into High Quality Object
8 | Detection <https://arxiv.org/abs/1712.00726>`_"""
9 |
10 | def __init__(self,
11 | backbone,
12 | neck=None,
13 | rpn_head=None,
14 | roi_head=None,
15 | train_cfg=None,
16 | test_cfg=None,
17 | pretrained=None):
18 | super(CascadeRCNN, self).__init__(
19 | backbone=backbone,
20 | neck=neck,
21 | rpn_head=rpn_head,
22 | roi_head=roi_head,
23 | train_cfg=train_cfg,
24 | test_cfg=test_cfg,
25 | pretrained=pretrained)
26 |
27 | def show_result(self, data, result, **kwargs):
28 | """Show prediction results of the detector."""
29 | if self.with_mask:
30 | ms_bbox_result, ms_segm_result = result
31 | if isinstance(ms_bbox_result, dict):
32 | result = (ms_bbox_result['ensemble'],
33 | ms_segm_result['ensemble'])
34 | else:
35 | if isinstance(result, dict):
36 | result = result['ensemble']
37 | return super(CascadeRCNN, self).show_result(data, result, **kwargs)
38 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/fast_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class FastRCNN(TwoStageDetector):
7 | """Implementation of `Fast R-CNN `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | roi_head,
12 | train_cfg,
13 | test_cfg,
14 | neck=None,
15 | pretrained=None):
16 | super(FastRCNN, self).__init__(
17 | backbone=backbone,
18 | neck=neck,
19 | roi_head=roi_head,
20 | train_cfg=train_cfg,
21 | test_cfg=test_cfg,
22 | pretrained=pretrained)
23 |
24 | def forward_test(self, imgs, img_metas, proposals, **kwargs):
25 | """
26 | Args:
27 | imgs (List[Tensor]): the outer list indicates test-time
28 | augmentations and inner Tensor should have a shape NxCxHxW,
29 | which contains all images in the batch.
30 | img_metas (List[List[dict]]): the outer list indicates test-time
31 | augs (multiscale, flip, etc.) and the inner list indicates
32 | images in a batch.
33 | proposals (List[List[Tensor]]): the outer list indicates test-time
34 | augs (multiscale, flip, etc.) and the inner list indicates
35 | images in a batch. The Tensor should have a shape Px4, where
36 | P is the number of proposals.
37 | """
38 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
39 | if not isinstance(var, list):
40 | raise TypeError(f'{name} must be a list, but got {type(var)}')
41 |
42 | num_augs = len(imgs)
43 | if num_augs != len(img_metas):
44 | raise ValueError(f'num of augmentations ({len(imgs)}) '
45 | f'!= num of image meta ({len(img_metas)})')
46 |
47 | if num_augs == 1:
48 | return self.simple_test(imgs[0], img_metas[0], proposals[0],
49 | **kwargs)
50 | else:
51 | # TODO: support test-time augmentation
52 | raise NotImplementedError
53 |
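To make the nesting in the docstring concrete, here is the argument shape for plain single-scale testing (toy values; real `img_metas` dicts carry keys such as `img_shape` and `scale_factor`):

    import torch

    imgs = [torch.randn(2, 3, 224, 224)]        # 1 augmentation, batch of 2
    img_metas = [[dict(idx=0), dict(idx=1)]]    # one meta dict per image
    proposals = [[torch.randn(100, 4),          # Px4 proposals per image
                  torch.randn(80, 4)]]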
--------------------------------------------------------------------------------
/mmdet/models/detectors/faster_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class FasterRCNN(TwoStageDetector):
7 | """Implementation of `Faster R-CNN `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | rpn_head,
12 | roi_head,
13 | train_cfg,
14 | test_cfg,
15 | neck=None,
16 | pretrained=None):
17 | super(FasterRCNN, self).__init__(
18 | backbone=backbone,
19 | neck=neck,
20 | rpn_head=rpn_head,
21 | roi_head=roi_head,
22 | train_cfg=train_cfg,
23 | test_cfg=test_cfg,
24 | pretrained=pretrained)
25 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/fcos.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class FCOS(SingleStageDetector):
7 | """Implementation of `FCOS `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | bbox_head,
13 | train_cfg=None,
14 | test_cfg=None,
15 | pretrained=None):
16 | super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
17 | test_cfg, pretrained)
18 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/fovea.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class FOVEA(SingleStageDetector):
7 | """Implementation of `FoveaBox `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | bbox_head,
13 | train_cfg=None,
14 | test_cfg=None,
15 | pretrained=None):
16 | super(FOVEA, self).__init__(backbone, neck, bbox_head, train_cfg,
17 | test_cfg, pretrained)
18 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/fsaf.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class FSAF(SingleStageDetector):
7 | """Implementation of `FSAF `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | bbox_head,
13 | train_cfg=None,
14 | test_cfg=None,
15 | pretrained=None):
16 | super(FSAF, self).__init__(backbone, neck, bbox_head, train_cfg,
17 | test_cfg, pretrained)
18 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/gfl.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class GFL(SingleStageDetector):
7 |
8 | def __init__(self,
9 | backbone,
10 | neck,
11 | bbox_head,
12 | train_cfg=None,
13 | test_cfg=None,
14 | pretrained=None):
15 | super(GFL, self).__init__(backbone, neck, bbox_head, train_cfg,
16 | test_cfg, pretrained)
17 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/grid_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class GridRCNN(TwoStageDetector):
7 | """Grid R-CNN.
8 |
9 | This detector is the implementation of:
10 | - Grid R-CNN (https://arxiv.org/abs/1811.12030)
11 | - Grid R-CNN Plus: Faster and Better (https://arxiv.org/abs/1906.05688)
12 | """
13 |
14 | def __init__(self,
15 | backbone,
16 | rpn_head,
17 | roi_head,
18 | train_cfg,
19 | test_cfg,
20 | neck=None,
21 | pretrained=None):
22 | super(GridRCNN, self).__init__(
23 | backbone=backbone,
24 | neck=neck,
25 | rpn_head=rpn_head,
26 | roi_head=roi_head,
27 | train_cfg=train_cfg,
28 | test_cfg=test_cfg,
29 | pretrained=pretrained)
30 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/htc.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .cascade_rcnn import CascadeRCNN
3 |
4 |
5 | @DETECTORS.register_module()
6 | class HybridTaskCascade(CascadeRCNN):
7 | """Implementation of `HTC `_"""
8 |
9 | def __init__(self, **kwargs):
10 | super(HybridTaskCascade, self).__init__(**kwargs)
11 |
12 | @property
13 | def with_semantic(self):
14 | """bool: whether the detector has a semantic head"""
15 | return self.roi_head.with_semantic
16 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/mask_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class MaskRCNN(TwoStageDetector):
7 | """Implementation of `Mask R-CNN `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | rpn_head,
12 | roi_head,
13 | train_cfg,
14 | test_cfg,
15 | neck=None,
16 | pretrained=None):
17 | super(MaskRCNN, self).__init__(
18 | backbone=backbone,
19 | neck=neck,
20 | rpn_head=rpn_head,
21 | roi_head=roi_head,
22 | train_cfg=train_cfg,
23 | test_cfg=test_cfg,
24 | pretrained=pretrained)
25 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/mask_scoring_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class MaskScoringRCNN(TwoStageDetector):
7 | """Mask Scoring RCNN.
8 |
9 | https://arxiv.org/abs/1903.00241
10 | """
11 |
12 | def __init__(self,
13 | backbone,
14 | rpn_head,
15 | roi_head,
16 | train_cfg,
17 | test_cfg,
18 | neck=None,
19 | pretrained=None):
20 | super(MaskScoringRCNN, self).__init__(
21 | backbone=backbone,
22 | neck=neck,
23 | rpn_head=rpn_head,
24 | roi_head=roi_head,
25 | train_cfg=train_cfg,
26 | test_cfg=test_cfg,
27 | pretrained=pretrained)
28 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/nasfcos.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class NASFCOS(SingleStageDetector):
7 | """NAS-FCOS: Fast Neural Architecture Search for Object Detection.
8 |
9 | https://arxiv.org/abs/1906.04423
10 | """
11 |
12 | def __init__(self,
13 | backbone,
14 | neck,
15 | bbox_head,
16 | train_cfg=None,
17 | test_cfg=None,
18 | pretrained=None):
19 | super(NASFCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
20 | test_cfg, pretrained)
21 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/paa.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class PAA(SingleStageDetector):
7 | """Implementation of `PAA `_."""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | bbox_head,
13 | train_cfg=None,
14 | test_cfg=None,
15 | pretrained=None):
16 | super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg,
17 | test_cfg, pretrained)
18 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/point_rend.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class PointRend(TwoStageDetector):
7 | """PointRend: Image Segmentation as Rendering
8 |
9 | This detector is the implementation of
10 | `PointRend <https://arxiv.org/abs/1912.08193>`_.
11 |
12 | """
13 |
14 | def __init__(self,
15 | backbone,
16 | rpn_head,
17 | roi_head,
18 | train_cfg,
19 | test_cfg,
20 | neck=None,
21 | pretrained=None):
22 | super(PointRend, self).__init__(
23 | backbone=backbone,
24 | neck=neck,
25 | rpn_head=rpn_head,
26 | roi_head=roi_head,
27 | train_cfg=train_cfg,
28 | test_cfg=test_cfg,
29 | pretrained=pretrained)
30 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/reppoints_detector.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class RepPointsDetector(SingleStageDetector):
7 | """RepPoints: Point Set Representation for Object Detection.
8 |
9 | This detector is the implementation of:
10 | - RepPoints detector (https://arxiv.org/pdf/1904.11490)
11 | """
12 |
13 | def __init__(self,
14 | backbone,
15 | neck,
16 | bbox_head,
17 | train_cfg=None,
18 | test_cfg=None,
19 | pretrained=None):
20 | super(RepPointsDetector,
21 | self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
22 | pretrained)
23 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/retinanet.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class RetinaNet(SingleStageDetector):
7 | """Implementation of `RetinaNet `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | bbox_head,
13 | train_cfg=None,
14 | test_cfg=None,
15 | pretrained=None):
16 | super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg,
17 | test_cfg, pretrained)
18 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/vfnet.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .single_stage import SingleStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class VFNet(SingleStageDetector):
7 | """Implementation of `VarifocalNet
8 | (VFNet).`_"""
9 |
10 | def __init__(self,
11 | backbone,
12 | neck,
13 | bbox_head,
14 | train_cfg=None,
15 | test_cfg=None,
16 | pretrained=None):
17 | super(VFNet, self).__init__(backbone, neck, bbox_head, train_cfg,
18 | test_cfg, pretrained)
19 |
--------------------------------------------------------------------------------
/mmdet/models/detectors/weak_rcnn.py:
--------------------------------------------------------------------------------
1 | from ..builder import DETECTORS
2 | from .two_stage import TwoStageDetector
3 |
4 |
5 | @DETECTORS.register_module()
6 | class WeakRCNN(TwoStageDetector):
7 | """Implementation of `Fast R-CNN `_"""
8 |
9 | def __init__(self,
10 | backbone,
11 | neck,
12 | roi_head,
13 | train_cfg,
14 | test_cfg,
15 | pretrained=None):
16 | super(WeakRCNN, self).__init__(
17 | backbone=backbone,
18 | neck=neck,
19 | roi_head=roi_head,
20 | train_cfg=train_cfg,
21 | test_cfg=test_cfg,
22 | pretrained=pretrained)
23 |
24 | def forward_train(self,
25 | img,
26 | img_metas,
27 | gt_labels,
28 | proposals=None,
29 | **kwargs):
30 |
31 | x = self.extract_feat(img)
32 |
33 | losses = dict()
34 |
35 | proposal_list = proposals
36 |
37 | roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
38 | gt_labels, **kwargs)
39 | losses.update(roi_losses)
40 | return losses
41 |
42 |
43 | def forward_test(self, imgs, img_metas, proposals, **kwargs):
44 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
45 | if not isinstance(var, list):
46 | raise TypeError(f'{name} must be a list, but got {type(var)}')
47 |
48 | num_augs = len(imgs)
49 | if num_augs != len(img_metas):
50 | raise ValueError(f'num of augmentations ({len(imgs)}) '
51 | f'!= num of image meta ({len(img_metas)})')
52 |
53 | if num_augs == 1:
54 | return self.simple_test(imgs[0], img_metas[0], proposals[0],
55 | **kwargs)
56 | else:
57 | assert imgs[0].size(0) == 1, 'aug test does not support ' \
58 | 'inference with batch size ' \
59 | f'{imgs[0].size(0)}'
60 | return self.aug_test(imgs, img_metas, proposals, **kwargs)
61 |
62 |
63 | def aug_test(self, imgs, img_metas, proposal_list, rescale=False):
64 | """Test with augmentations.
65 |
66 | If rescale is False, then returned bboxes and masks will fit the scale
67 | of imgs[0].
68 | """
69 | x = self.extract_feats(imgs)
70 | return self.roi_head.aug_test(
71 | x, proposal_list, img_metas, rescale=rescale)
72 |
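Unlike the fully supervised detectors above, `forward_train` here takes no `gt_bboxes`: supervision is image-level labels plus precomputed proposals (cf. `voc_ss.py` under datasets). A sketch of the training call (argument names are placeholders):

    def weak_rcnn_train_step(model, img, img_metas, gt_labels, proposal_list):
        # img: (N, C, H, W) batch; gt_labels: image-level class labels only;
        # proposal_list: precomputed boxes (e.g. selective search), one
        # Tensor per image. No gt_bboxes and no RPN are involved.
        return model.forward_train(img, img_metas, gt_labels,
                                   proposals=proposal_list)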
--------------------------------------------------------------------------------
/mmdet/models/detectors/yolo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Western Digital Corporation or its affiliates.
2 |
3 | from ..builder import DETECTORS
4 | from .single_stage import SingleStageDetector
5 |
6 |
7 | @DETECTORS.register_module()
8 | class YOLOV3(SingleStageDetector):
9 |
10 | def __init__(self,
11 | backbone,
12 | neck,
13 | bbox_head,
14 | train_cfg=None,
15 | test_cfg=None,
16 | pretrained=None):
17 | super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
18 | test_cfg, pretrained)
19 |
--------------------------------------------------------------------------------
/mmdet/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .accuracy import Accuracy, accuracy
2 | from .ae_loss import AssociativeEmbeddingLoss
3 | from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
5 | cross_entropy, mask_cross_entropy)
6 | from .focal_loss import FocalLoss, sigmoid_focal_loss
7 | from .gaussian_focal_loss import GaussianFocalLoss
8 | from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
9 | from .ghm_loss import GHMC, GHMR
10 | from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss,
11 | bounded_iou_loss, iou_loss)
12 | from .mse_loss import MSELoss, mse_loss
13 | from .pisa_loss import carl_loss, isr_p
14 | from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
15 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss
16 | from .varifocal_loss import VarifocalLoss
17 |
18 | __all__ = [
19 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
20 | 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
21 | 'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
22 | 'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
23 | 'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 'GHMC',
24 | 'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
25 | 'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
26 | 'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
27 | 'VarifocalLoss'
28 | ]
29 |
--------------------------------------------------------------------------------
/mmdet/models/losses/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def accuracy(pred, target, topk=1, thresh=None):
5 | """Calculate accuracy according to the prediction and target.
6 |
7 | Args:
8 | pred (torch.Tensor): The model prediction, shape (N, num_class)
9 | target (torch.Tensor): The target of each prediction, shape (N, )
10 | topk (int | tuple[int], optional): If the target is within the top
11 | ``topk`` predictions, the prediction is regarded as correct.
12 | Defaults to 1.
13 | thresh (float, optional): If not None, predictions with scores under
14 | this threshold are considered incorrect. Defaults to None.
15 |
16 | Returns:
17 | float | tuple[float]: If the input ``topk`` is a single integer,
18 | the function will return a single float as accuracy. If
19 | ``topk`` is a tuple containing multiple integers, the
20 | function will return a tuple containing accuracies of
21 | each ``topk`` number.
22 | """
23 | assert isinstance(topk, (int, tuple))
24 | if isinstance(topk, int):
25 | topk = (topk, )
26 | return_single = True
27 | else:
28 | return_single = False
29 |
30 | maxk = max(topk)
31 | if pred.size(0) == 0:
32 | accu = [pred.new_tensor(0.) for _ in range(len(topk))]
33 | return accu[0] if return_single else accu
34 | assert pred.ndim == 2 and target.ndim == 1
35 | assert pred.size(0) == target.size(0)
36 | assert maxk <= pred.size(1), \
37 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
38 | pred_value, pred_label = pred.topk(maxk, dim=1)
39 | pred_label = pred_label.t() # transpose to shape (maxk, N)
40 | correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
41 | if thresh is not None:
42 | # Only prediction values larger than thresh are counted as correct
43 | correct = correct & (pred_value > thresh).t()
44 | res = []
45 | for k in topk:
46 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
47 | res.append(correct_k.mul_(100.0 / pred.size(0)))
48 | return res[0] if return_single else res
49 |
50 |
51 | class Accuracy(nn.Module):
52 |
53 | def __init__(self, topk=(1, ), thresh=None):
54 | """Module to calculate the accuracy.
55 |
56 | Args:
57 | topk (tuple, optional): The criterion used to calculate the
58 | accuracy. Defaults to (1,).
59 | thresh (float, optional): If not None, predictions with scores
60 | under this threshold are considered incorrect. Defaults to None.
61 | """
62 | super().__init__()
63 | self.topk = topk
64 | self.thresh = thresh
65 |
66 | def forward(self, pred, target):
67 | """Forward function to calculate accuracy.
68 |
69 | Args:
70 | pred (torch.Tensor): Prediction of models.
71 | target (torch.Tensor): Target for each prediction.
72 |
73 | Returns:
74 | tuple[float]: The accuracies under different topk criterions.
75 | """
76 | return accuracy(pred, target, self.topk, self.thresh)
77 |
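A quick worked example of `accuracy` (values chosen so half the top-1 predictions are hits):

    import torch

    pred = torch.tensor([[0.1, 0.9],
                         [0.8, 0.2]])
    target = torch.tensor([1, 1])
    accuracy(pred, target)               # tensor([50.]): row 0 hits, row 1 misses
    accuracy(pred, target, topk=(1, 2))  # [tensor([50.]), tensor([100.])]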
--------------------------------------------------------------------------------
/mmdet/models/losses/gaussian_focal_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from ..builder import LOSSES
4 | from .utils import weighted_loss
5 |
6 |
7 | @weighted_loss
8 | def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
9 | """`Focal Loss `_ for targets in gaussian
10 | distribution.
11 |
12 | Args:
13 | pred (torch.Tensor): The prediction.
14 | gaussian_target (torch.Tensor): The learning target of the prediction
15 | in gaussian distribution.
16 | alpha (float, optional): A balanced form for Focal Loss.
17 | Defaults to 2.0.
18 | gamma (float, optional): The gamma for calculating the modulating
19 | factor. Defaults to 4.0.
20 | """
21 | eps = 1e-12
22 | pos_weights = gaussian_target.eq(1)
23 | neg_weights = (1 - gaussian_target).pow(gamma)
24 | pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
25 | neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
26 | return pos_loss + neg_loss
27 |
28 |
29 | @LOSSES.register_module()
30 | class GaussianFocalLoss(nn.Module):
31 | """GaussianFocalLoss is a variant of focal loss.
32 |
33 | More details can be found in the `paper
34 | <https://arxiv.org/abs/1808.01244>`_.
35 | Code is modified from `kp_utils.py
36 | <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py>`_  # noqa: E501
37 | Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
38 | not 0/1 binary target.
39 |
40 | Args:
41 | alpha (float): Power of prediction.
42 | gamma (float): Power of target for negative samples.
43 | reduction (str): Options are "none", "mean" and "sum".
44 | loss_weight (float): Loss weight of current loss.
45 | """
46 |
47 | def __init__(self,
48 | alpha=2.0,
49 | gamma=4.0,
50 | reduction='mean',
51 | loss_weight=1.0):
52 | super(GaussianFocalLoss, self).__init__()
53 | self.alpha = alpha
54 | self.gamma = gamma
55 | self.reduction = reduction
56 | self.loss_weight = loss_weight
57 |
58 | def forward(self,
59 | pred,
60 | target,
61 | weight=None,
62 | avg_factor=None,
63 | reduction_override=None):
64 | """Forward function.
65 |
66 | Args:
67 | pred (torch.Tensor): The prediction.
68 | target (torch.Tensor): The learning target of the prediction
69 | in gaussian distribution.
70 | weight (torch.Tensor, optional): The weight of loss for each
71 | prediction. Defaults to None.
72 | avg_factor (int, optional): Average factor that is used to average
73 | the loss. Defaults to None.
74 | reduction_override (str, optional): The reduction method used to
75 | override the original reduction method of the loss.
76 | Defaults to None.
77 | """
78 | assert reduction_override in (None, 'none', 'mean', 'sum')
79 | reduction = (
80 | reduction_override if reduction_override else self.reduction)
81 | loss_reg = self.loss_weight * gaussian_focal_loss(
82 | pred,
83 | target,
84 | weight,
85 | alpha=self.alpha,
86 | gamma=self.gamma,
87 | reduction=reduction,
88 | avg_factor=avg_factor)
89 | return loss_reg
90 |
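A small numeric check of the elementwise loss above (before the `@weighted_loss` reduction, both elements happen to contribute the same value here):

    import torch

    pred = torch.tensor([0.9, 0.1])
    target = torch.tensor([1.0, 0.0])  # one heatmap peak, one pure background
    # element 0 (positive): -log(0.9) * (1 - 0.9)**2            ~= 0.00105
    # element 1 (negative): -log(1 - 0.1) * 0.1**2 * (1 - 0)**4 ~= 0.00105
    loss = gaussian_focal_loss(pred, target)  # mean reduction -> ~0.00105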
--------------------------------------------------------------------------------
/mmdet/models/losses/mse_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from ..builder import LOSSES
5 | from .utils import weighted_loss
6 |
7 |
8 | @weighted_loss
9 | def mse_loss(pred, target):
10 | """Warpper of mse loss."""
11 | return F.mse_loss(pred, target, reduction='none')
12 |
13 |
14 | @LOSSES.register_module()
15 | class MSELoss(nn.Module):
16 | """MSELoss.
17 |
18 | Args:
19 | reduction (str, optional): The method that reduces the loss to a
20 | scalar. Options are "none", "mean" and "sum".
21 | loss_weight (float, optional): The weight of the loss. Defaults to 1.0
22 | """
23 |
24 | def __init__(self, reduction='mean', loss_weight=1.0):
25 | super().__init__()
26 | self.reduction = reduction
27 | self.loss_weight = loss_weight
28 |
29 | def forward(self, pred, target, weight=None, avg_factor=None):
30 | """Forward function of loss.
31 |
32 | Args:
33 | pred (torch.Tensor): The prediction.
34 | target (torch.Tensor): The learning target of the prediction.
35 | weight (torch.Tensor, optional): Weight of the loss for each
36 | prediction. Defaults to None.
37 | avg_factor (int, optional): Average factor that is used to average
38 | the loss. Defaults to None.
39 |
40 | Returns:
41 | torch.Tensor: The calculated loss
42 | """
43 | loss = self.loss_weight * mse_loss(
44 | pred,
45 | target,
46 | weight,
47 | reduction=self.reduction,
48 | avg_factor=avg_factor)
49 | return loss
50 |
--------------------------------------------------------------------------------
/mmdet/models/losses/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn.functional as F
4 |
5 |
6 | def reduce_loss(loss, reduction):
7 | """Reduce loss as specified.
8 |
9 | Args:
10 | loss (Tensor): Elementwise loss tensor.
11 | reduction (str): Options are "none", "mean" and "sum".
12 |
13 | Returns:
14 | Tensor: Reduced loss tensor.
15 | """
16 | reduction_enum = F._Reduction.get_enum(reduction)
17 | # none: 0, elementwise_mean:1, sum: 2
18 | if reduction_enum == 0:
19 | return loss
20 | elif reduction_enum == 1:
21 | return loss.mean()
22 | elif reduction_enum == 2:
23 | return loss.sum()
24 |
25 |
26 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
27 | """Apply element-wise weight and reduce loss.
28 |
29 | Args:
30 | loss (Tensor): Element-wise loss.
31 | weight (Tensor): Element-wise weights.
32 | reduction (str): Same as built-in losses of PyTorch.
33 | avg_factor (float): Average factor when computing the mean of losses.
34 |
35 | Returns:
36 | Tensor: Processed loss values.
37 | """
38 | # if weight is specified, apply element-wise weight
39 | if weight is not None:
40 | loss = loss * weight
41 |
42 | # if avg_factor is not specified, just reduce the loss
43 | if avg_factor is None:
44 | loss = reduce_loss(loss, reduction)
45 | else:
46 | # if reduction is mean, then average the loss by avg_factor
47 | if reduction == 'mean':
48 | loss = loss.sum() / avg_factor
49 | # if reduction is 'none', then do nothing, otherwise raise an error
50 | elif reduction != 'none':
51 | raise ValueError('avg_factor can not be used with reduction="sum"')
52 | return loss
53 |
54 |
55 | def weighted_loss(loss_func):
56 | """Create a weighted version of a given loss function.
57 |
58 | To use this decorator, the loss function must have the signature like
59 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
60 | element-wise loss without any reduction. This decorator will add weight
61 | and reduction arguments to the function. The decorated function will have
62 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
63 | avg_factor=None, **kwargs)`.
64 |
65 | :Example:
66 |
67 | >>> import torch
68 | >>> @weighted_loss
69 | >>> def l1_loss(pred, target):
70 | >>> return (pred - target).abs()
71 |
72 | >>> pred = torch.Tensor([0, 2, 3])
73 | >>> target = torch.Tensor([1, 1, 1])
74 | >>> weight = torch.Tensor([1, 0, 1])
75 |
76 | >>> l1_loss(pred, target)
77 | tensor(1.3333)
78 | >>> l1_loss(pred, target, weight)
79 | tensor(1.)
80 | >>> l1_loss(pred, target, reduction='none')
81 | tensor([1., 1., 2.])
82 | >>> l1_loss(pred, target, weight, avg_factor=2)
83 | tensor(1.5000)
84 | """
85 |
86 | @functools.wraps(loss_func)
87 | def wrapper(pred,
88 | target,
89 | weight=None,
90 | reduction='mean',
91 | avg_factor=None,
92 | **kwargs):
93 | # get element-wise loss
94 | loss = loss_func(pred, target, **kwargs)
95 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
96 | return loss
97 |
98 | return wrapper
99 |
--------------------------------------------------------------------------------
/mmdet/models/necks/__init__.py:
--------------------------------------------------------------------------------
1 | from .bfp import BFP
2 | from .channel_mapper import ChannelMapper
3 | from .fpn import FPN
4 | from .fpn_carafe import FPN_CARAFE
5 | from .hrfpn import HRFPN
6 | from .nas_fpn import NASFPN
7 | from .nasfcos_fpn import NASFCOS_FPN
8 | from .pafpn import PAFPN
9 | from .rfp import RFP
10 | from .yolo_neck import YOLOV3Neck
11 |
12 | __all__ = [
13 | 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
14 | 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck'
15 | ]
16 |
--------------------------------------------------------------------------------
/mmdet/models/necks/channel_mapper.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from mmcv.cnn import ConvModule, xavier_init
3 |
4 | from ..builder import NECKS
5 |
6 |
7 | @NECKS.register_module()
8 | class ChannelMapper(nn.Module):
9 | r"""Channel Mapper to reduce/increase channels of backbone features.
10 |
11 | It applies one ConvModule per scale to map each input feature map to the same number of output channels.
12 |
13 | Args:
14 | in_channels (List[int]): Number of input channels per scale.
15 | out_channels (int): Number of output channels (used at each scale).
16 | kernel_size (int, optional): kernel_size for reducing channels (used
17 | at each scale). Default: 3.
18 | conv_cfg (dict, optional): Config dict for convolution layer.
19 | Default: None.
20 | norm_cfg (dict, optional): Config dict for normalization layer.
21 | Default: None.
22 | act_cfg (dict, optional): Config dict for activation layer in
23 | ConvModule. Default: dict(type='ReLU').
24 |
25 | Example:
26 | >>> import torch
27 | >>> in_channels = [2, 3, 5, 7]
28 | >>> scales = [340, 170, 84, 43]
29 | >>> inputs = [torch.rand(1, c, s, s)
30 | ... for c, s in zip(in_channels, scales)]
31 | >>> self = ChannelMapper(in_channels, 11, 3).eval()
32 | >>> outputs = self.forward(inputs)
33 | >>> for i in range(len(outputs)):
34 | ... print(f'outputs[{i}].shape = {outputs[i].shape}')
35 | outputs[0].shape = torch.Size([1, 11, 340, 340])
36 | outputs[1].shape = torch.Size([1, 11, 170, 170])
37 | outputs[2].shape = torch.Size([1, 11, 84, 84])
38 | outputs[3].shape = torch.Size([1, 11, 43, 43])
39 | """
40 |
41 | def __init__(self,
42 | in_channels,
43 | out_channels,
44 | kernel_size=3,
45 | conv_cfg=None,
46 | norm_cfg=None,
47 | act_cfg=dict(type='ReLU')):
48 | super(ChannelMapper, self).__init__()
49 | assert isinstance(in_channels, list)
50 |
51 | self.convs = nn.ModuleList()
52 | for in_channel in in_channels:
53 | self.convs.append(
54 | ConvModule(
55 | in_channel,
56 | out_channels,
57 | kernel_size,
58 | padding=(kernel_size - 1) // 2,
59 | conv_cfg=conv_cfg,
60 | norm_cfg=norm_cfg,
61 | act_cfg=act_cfg))
62 |
63 | # default init_weights for conv(msra) and norm in ConvModule
64 | def init_weights(self):
65 | """Initialize the weights of ChannelMapper module."""
66 | for m in self.modules():
67 | if isinstance(m, nn.Conv2d):
68 | xavier_init(m, distribution='uniform')
69 |
70 | def forward(self, inputs):
71 | """Forward function."""
72 | assert len(inputs) == len(self.convs)
73 | outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
74 | return tuple(outs)
75 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_roi_head import BaseRoIHead
2 | from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead,
3 | Shared2FCBBoxHead, Shared4Conv1FCBBoxHead)
4 | from .cascade_roi_head import CascadeRoIHead
5 | from .double_roi_head import DoubleHeadRoIHead
6 | from .dynamic_roi_head import DynamicRoIHead
7 | from .grid_roi_head import GridRoIHead
8 | from .htc_roi_head import HybridTaskCascadeRoIHead
9 | from .mask_heads import (CoarseMaskHead, FCNMaskHead, FusedSemanticHead,
10 | GridHead, HTCMaskHead, MaskIoUHead, MaskPointHead)
11 | from .mask_scoring_roi_head import MaskScoringRoIHead
12 | from .pisa_roi_head import PISARoIHead
13 | from .point_rend_roi_head import PointRendRoIHead
14 | from .roi_extractors import SingleRoIExtractor
15 | from .shared_heads import ResLayer
16 | from .standard_roi_head import StandardRoIHead
17 |
18 | from .wsddn_roi_head import WSDDNRoIHead
19 | from .oicr_roi_head import OICRRoIHead
20 | from .wsod2_roi_head import WSOD2RoIHead
21 |
22 | __all__ = [
23 | 'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead',
24 | 'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead',
25 | 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'StandardRoIHead',
26 | 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'FCNMaskHead',
27 | 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 'MaskIoUHead',
28 | 'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead',
29 | 'CoarseMaskHead', 'DynamicRoIHead',
30 | 'WSDDNRoIHead', 'OICRRoIHead', 'WSOD2RoIHead'
31 | ]
32 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/base_roi_head.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import torch.nn as nn
4 |
5 | from ..builder import build_shared_head
6 |
7 |
8 | class BaseRoIHead(nn.Module, metaclass=ABCMeta):
9 | """Base class for RoIHeads."""
10 |
11 | def __init__(self,
12 | bbox_roi_extractor=None,
13 | bbox_head=None,
14 | mask_roi_extractor=None,
15 | mask_head=None,
16 | shared_head=None,
17 | train_cfg=None,
18 | test_cfg=None):
19 | super(BaseRoIHead, self).__init__()
20 | self.train_cfg = train_cfg
21 | self.test_cfg = test_cfg
22 | if shared_head is not None:
23 | self.shared_head = build_shared_head(shared_head)
24 |
25 | if bbox_head is not None:
26 | self.init_bbox_head(bbox_roi_extractor, bbox_head)
27 |
28 | if mask_head is not None:
29 | self.init_mask_head(mask_roi_extractor, mask_head)
30 |
31 | self.init_assigner_sampler()
32 |
33 | @property
34 | def with_bbox(self):
35 | """bool: whether the RoI head contains a `bbox_head`"""
36 | return hasattr(self, 'bbox_head') and self.bbox_head is not None
37 |
38 | @property
39 | def with_mask(self):
40 | """bool: whether the RoI head contains a `mask_head`"""
41 | return hasattr(self, 'mask_head') and self.mask_head is not None
42 |
43 | @property
44 | def with_shared_head(self):
45 | """bool: whether the RoI head contains a `shared_head`"""
46 | return hasattr(self, 'shared_head') and self.shared_head is not None
47 |
48 | @abstractmethod
49 | def init_weights(self, pretrained):
50 | """Initialize the weights in head.
51 |
52 | Args:
53 | pretrained (str, optional): Path to pre-trained weights.
54 | Defaults to None.
55 | """
56 | pass
57 |
58 | @abstractmethod
59 | def init_bbox_head(self):
60 | """Initialize ``bbox_head``"""
61 | pass
62 |
63 | @abstractmethod
64 | def init_mask_head(self):
65 | """Initialize ``mask_head``"""
66 | pass
67 |
68 | @abstractmethod
69 | def init_assigner_sampler(self):
70 | """Initialize assigner and sampler."""
71 | pass
72 |
73 | @abstractmethod
74 | def forward_train(self,
75 | x,
76 | img_meta,
77 | proposal_list,
78 | gt_bboxes,
79 | gt_labels,
80 | gt_bboxes_ignore=None,
81 | gt_masks=None,
82 | **kwargs):
83 | """Forward function during training."""
84 | pass
85 |
86 | async def async_simple_test(self, x, img_meta, **kwargs):
87 | """Asynchronized test function."""
88 | raise NotImplementedError
89 |
90 | def simple_test(self,
91 | x,
92 | proposal_list,
93 | img_meta,
94 | proposals=None,
95 | rescale=False,
96 | **kwargs):
97 | """Test without augmentation."""
98 | pass
99 |
100 | def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
101 | """Test with augmentations.
102 |
103 | If rescale is False, then returned bboxes and masks will fit the scale
104 | of imgs[0].
105 | """
106 | pass
107 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/bbox_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .bbox_head import BBoxHead
2 | from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead,
3 | Shared4Conv1FCBBoxHead)
4 | from .double_bbox_head import DoubleConvFCBBoxHead
5 | from .sabl_head import SABLHead
6 |
7 | from .wsddn_head import WSDDNHead
8 | from .oicr_head import OICRHead
9 |
10 | __all__ = [
11 | 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
12 | 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead',
13 | 'WSDDNHead', 'OICRHead'
14 | ]
15 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/double_roi_head.py:
--------------------------------------------------------------------------------
1 | from ..builder import HEADS
2 | from .standard_roi_head import StandardRoIHead
3 |
4 |
5 | @HEADS.register_module()
6 | class DoubleHeadRoIHead(StandardRoIHead):
7 | """RoI head for Double Head RCNN.
8 |
9 | https://arxiv.org/abs/1904.06493
10 | """
11 |
12 | def __init__(self, reg_roi_scale_factor, **kwargs):
13 | super(DoubleHeadRoIHead, self).__init__(**kwargs)
14 | self.reg_roi_scale_factor = reg_roi_scale_factor
15 |
16 | def _bbox_forward(self, x, rois):
17 | """Box head forward function used in both training and testing time."""
18 | bbox_cls_feats = self.bbox_roi_extractor(
19 | x[:self.bbox_roi_extractor.num_inputs], rois)
20 | bbox_reg_feats = self.bbox_roi_extractor(
21 | x[:self.bbox_roi_extractor.num_inputs],
22 | rois,
23 | roi_scale_factor=self.reg_roi_scale_factor)
24 | if self.with_shared_head:
25 | bbox_cls_feats = self.shared_head(bbox_cls_feats)
26 | bbox_reg_feats = self.shared_head(bbox_reg_feats)
27 | cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
28 |
29 | bbox_results = dict(
30 | cls_score=cls_score,
31 | bbox_pred=bbox_pred,
32 | bbox_feats=bbox_cls_feats)
33 | return bbox_results
34 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/mask_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .coarse_mask_head import CoarseMaskHead
2 | from .fcn_mask_head import FCNMaskHead
3 | from .fused_semantic_head import FusedSemanticHead
4 | from .grid_head import GridHead
5 | from .htc_mask_head import HTCMaskHead
6 | from .mask_point_head import MaskPointHead
7 | from .maskiou_head import MaskIoUHead
8 |
9 | __all__ = [
10 | 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
11 | 'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from mmcv.cnn import ConvModule, Linear, constant_init, xavier_init
3 | from mmcv.runner import auto_fp16
4 |
5 | from mmdet.models.builder import HEADS
6 | from .fcn_mask_head import FCNMaskHead
7 |
8 |
9 | @HEADS.register_module()
10 | class CoarseMaskHead(FCNMaskHead):
11 | """Coarse mask head used in PointRend.
12 |
13 | Compared with the standard ``FCNMaskHead``, ``CoarseMaskHead`` downsamples
14 | the input feature map instead of upsampling it.
15 |
16 | Args:
17 | num_convs (int): Number of conv layers in the head. Default: 0.
18 | num_fcs (int): Number of fc layers in the head. Default: 2.
19 | fc_out_channels (int): Number of output channels of fc layer.
20 | Default: 1024.
21 | downsample_factor (int): The factor that feature map is downsampled by.
22 | Default: 2.
23 | """
24 |
25 | def __init__(self,
26 | num_convs=0,
27 | num_fcs=2,
28 | fc_out_channels=1024,
29 | downsample_factor=2,
30 | *args,
31 | **kwargs):
32 | super(CoarseMaskHead, self).__init__(
33 | *args, num_convs=num_convs, upsample_cfg=dict(type=None), **kwargs)
34 | self.num_fcs = num_fcs
35 | assert self.num_fcs > 0
36 | self.fc_out_channels = fc_out_channels
37 | self.downsample_factor = downsample_factor
38 | assert self.downsample_factor >= 1
39 | # remove conv_logit
40 | delattr(self, 'conv_logits')
41 |
42 | if downsample_factor > 1:
43 | downsample_in_channels = (
44 | self.conv_out_channels
45 | if self.num_convs > 0 else self.in_channels)
46 | self.downsample_conv = ConvModule(
47 | downsample_in_channels,
48 | self.conv_out_channels,
49 | kernel_size=downsample_factor,
50 | stride=downsample_factor,
51 | padding=0,
52 | conv_cfg=self.conv_cfg,
53 | norm_cfg=self.norm_cfg)
54 | else:
55 | self.downsample_conv = None
56 |
57 | self.output_size = (self.roi_feat_size[0] // downsample_factor,
58 | self.roi_feat_size[1] // downsample_factor)
59 | self.output_area = self.output_size[0] * self.output_size[1]
60 |
61 | last_layer_dim = self.conv_out_channels * self.output_area
62 |
63 | self.fcs = nn.ModuleList()
64 | for i in range(num_fcs):
65 | fc_in_channels = (
66 | last_layer_dim if i == 0 else self.fc_out_channels)
67 | self.fcs.append(Linear(fc_in_channels, self.fc_out_channels))
68 | last_layer_dim = self.fc_out_channels
69 | output_channels = self.num_classes * self.output_area
70 | self.fc_logits = Linear(last_layer_dim, output_channels)
71 |
72 | def init_weights(self):
73 | for m in self.fcs.modules():
74 | if isinstance(m, nn.Linear):
75 | xavier_init(m)
76 | constant_init(self.fc_logits, 0.001)
77 |
78 | @auto_fp16()
79 | def forward(self, x):
80 | for conv in self.convs:
81 | x = conv(x)
82 |
83 | if self.downsample_conv is not None:
84 | x = self.downsample_conv(x)
85 |
86 | x = x.flatten(1)
87 | for fc in self.fcs:
88 | x = self.relu(fc(x))
89 | mask_pred = self.fc_logits(x).view(
90 | x.size(0), self.num_classes, *self.output_size)
91 | return mask_pred
92 |
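A shape walk-through of the head above, assuming the defaults inherited from `FCNMaskHead` (`roi_feat_size=14`, `in_channels=256`, `conv_out_channels=256`) with `num_convs=0` and `downsample_factor=2`:

    # downsample_factor=2  -> output_size=(7, 7), output_area=49
    # downsample_conv      : (n, 256, 14, 14) -> (n, 256, 7, 7)
    # flatten + 2 fcs      : (n, 256 * 49) -> (n, 1024) -> (n, 1024)
    # fc_logits + view     : (n, num_classes * 49) -> (n, num_classes, 7, 7)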
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/mask_heads/htc_mask_head.py:
--------------------------------------------------------------------------------
1 | from mmcv.cnn import ConvModule
2 |
3 | from mmdet.models.builder import HEADS
4 | from .fcn_mask_head import FCNMaskHead
5 |
6 |
7 | @HEADS.register_module()
8 | class HTCMaskHead(FCNMaskHead):
9 |
10 | def __init__(self, with_conv_res=True, *args, **kwargs):
11 | super(HTCMaskHead, self).__init__(*args, **kwargs)
12 | self.with_conv_res = with_conv_res
13 | if self.with_conv_res:
14 | self.conv_res = ConvModule(
15 | self.conv_out_channels,
16 | self.conv_out_channels,
17 | 1,
18 | conv_cfg=self.conv_cfg,
19 | norm_cfg=self.norm_cfg)
20 |
21 | def init_weights(self):
22 | super(HTCMaskHead, self).init_weights()
23 | if self.with_conv_res:
24 | self.conv_res.init_weights()
25 |
26 | def forward(self, x, res_feat=None, return_logits=True, return_feat=True):
27 | if res_feat is not None:
28 | assert self.with_conv_res
29 | res_feat = self.conv_res(res_feat)
30 | x = x + res_feat
31 | for conv in self.convs:
32 | x = conv(x)
33 | res_feat = x
34 | outs = []
35 | if return_logits:
36 | x = self.upsample(x)
37 | if self.upsample_method == 'deconv':
38 | x = self.relu(x)
39 | mask_pred = self.conv_logits(x)
40 | outs.append(mask_pred)
41 | if return_feat:
42 | outs.append(res_feat)
43 | return outs if len(outs) > 1 else outs[0]
44 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/roi_extractors/__init__.py:
--------------------------------------------------------------------------------
1 | from .generic_roi_extractor import GenericRoIExtractor
2 | from .single_level_roi_extractor import SingleRoIExtractor
3 |
4 | __all__ = [
5 | 'SingleRoIExtractor',
6 | 'GenericRoIExtractor',
7 | ]
8 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import torch
4 | import torch.nn as nn
5 | from mmcv import ops
6 |
7 |
8 | class BaseRoIExtractor(nn.Module, metaclass=ABCMeta):
9 | """Base class for RoI extractor.
10 |
11 | Args:
12 | roi_layer (dict): Specify RoI layer type and arguments.
13 | out_channels (int): Output channels of RoI layers.
14 | featmap_strides (List[int]): Strides of input feature maps.
15 | """
16 |
17 | def __init__(self, roi_layer, out_channels, featmap_strides):
18 | super(BaseRoIExtractor, self).__init__()
19 | self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
20 | self.out_channels = out_channels
21 | self.featmap_strides = featmap_strides
22 | self.fp16_enabled = False
23 |
24 | @property
25 | def num_inputs(self):
26 | """int: Number of input feature maps."""
27 | return len(self.featmap_strides)
28 |
29 | def init_weights(self):
30 | pass
31 |
32 | def build_roi_layers(self, layer_cfg, featmap_strides):
33 | """Build RoI operator to extract feature from each level feature map.
34 |
35 | Args:
36 | layer_cfg (dict): Dictionary to construct and config RoI layer
37 | operation. Options are modules under ``mmcv/ops`` such as
38 | ``RoIAlign``.
39 | featmap_strides (List[int]): The strides of the input feature maps
40 | w.r.t. the original image size, used to scale RoI
41 | coordinates (original image coordinate system) to the feature
42 | coordinate system.
43 |
44 | Returns:
45 | nn.ModuleList: The RoI extractor modules for each level feature
46 | map.
47 | """
48 |
49 | cfg = layer_cfg.copy()
50 | layer_type = cfg.pop('type')
51 | assert hasattr(ops, layer_type)
52 | layer_cls = getattr(ops, layer_type)
53 | roi_layers = nn.ModuleList(
54 | [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
55 | return roi_layers
56 |
57 | def roi_rescale(self, rois, scale_factor):
58 | """Scale RoI coordinates by scale factor.
59 |
60 | Args:
61 | rois (torch.Tensor): RoI (Region of Interest), shape (n, 5)
62 | scale_factor (float): Scale factor that RoI will be multiplied by.
63 |
64 | Returns:
65 | torch.Tensor: Scaled RoI.
66 | """
67 |
68 | cx = (rois[:, 1] + rois[:, 3]) * 0.5
69 | cy = (rois[:, 2] + rois[:, 4]) * 0.5
70 | w = rois[:, 3] - rois[:, 1]
71 | h = rois[:, 4] - rois[:, 2]
72 | new_w = w * scale_factor
73 | new_h = h * scale_factor
74 | x1 = cx - new_w * 0.5
75 | x2 = cx + new_w * 0.5
76 | y1 = cy - new_h * 0.5
77 | y2 = cy + new_h * 0.5
78 | new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
79 | return new_rois
80 |
81 | @abstractmethod
82 | def forward(self, feats, rois, roi_scale_factor=None):
83 | pass
84 |
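A small numeric check of roi_rescale, restated standalone because BaseRoIExtractor is abstract; the box values are arbitrary. Scaling keeps the center fixed and multiplies width and height by the factor:

import torch


def roi_rescale(rois, scale_factor):
    # same arithmetic as BaseRoIExtractor.roi_rescale above
    cx = (rois[:, 1] + rois[:, 3]) * 0.5
    cy = (rois[:, 2] + rois[:, 4]) * 0.5
    w = (rois[:, 3] - rois[:, 1]) * scale_factor
    h = (rois[:, 4] - rois[:, 2]) * scale_factor
    return torch.stack((rois[:, 0], cx - w * 0.5, cy - h * 0.5,
                        cx + w * 0.5, cy + h * 0.5), dim=-1)


rois = torch.tensor([[0., 5., 10., 15., 30.]])  # (batch_ind, x1, y1, x2, y2)
# center (10, 20), w=10, h=20 -> w=20, h=40 after scaling by 2
assert torch.allclose(roi_rescale(rois, 2.0),
                      torch.tensor([[0., 0., 0., 20., 40.]]))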
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py:
--------------------------------------------------------------------------------
1 | from mmcv.cnn.bricks import build_plugin_layer
2 | from mmcv.runner import force_fp32
3 |
4 | from mmdet.models.builder import ROI_EXTRACTORS
5 | from .base_roi_extractor import BaseRoIExtractor
6 |
7 |
8 | @ROI_EXTRACTORS.register_module()
9 | class GenericRoIExtractor(BaseRoIExtractor):
10 | """Extract RoI features from all level feature maps levels.
11 |
12 | This is the implementation of `A novel Region of Interest Extraction Layer
13 | for Instance Segmentation <https://arxiv.org/abs/2004.13665>`_.
14 |
15 | Args:
16 | aggregation (str): The method to aggregate multiple feature maps.
17 | Options are 'sum', 'concat'. Default: 'sum'.
18 | pre_cfg (dict | None): Specify pre-processing modules. Default: None.
19 | post_cfg (dict | None): Specify post-processing modules. Default: None.
20 | kwargs (keyword arguments): Arguments that are the same
21 | as :class:`BaseRoIExtractor`.
22 | """
23 |
24 | def __init__(self,
25 | aggregation='sum',
26 | pre_cfg=None,
27 | post_cfg=None,
28 | **kwargs):
29 | super(GenericRoIExtractor, self).__init__(**kwargs)
30 |
31 | assert aggregation in ['sum', 'concat']
32 |
33 | self.aggregation = aggregation
34 | self.with_post = post_cfg is not None
35 | self.with_pre = pre_cfg is not None
36 | # build pre/post processing modules
37 | if self.with_post:
38 | self.post_module = build_plugin_layer(post_cfg, '_post_module')[1]
39 | if self.with_pre:
40 | self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1]
41 |
42 | @force_fp32(apply_to=('feats', ), out_fp16=True)
43 | def forward(self, feats, rois, roi_scale_factor=None):
44 | """Forward function."""
45 | if len(feats) == 1:
46 | return self.roi_layers[0](feats[0], rois)
47 |
48 | out_size = self.roi_layers[0].output_size
49 | num_levels = len(feats)
50 | roi_feats = feats[0].new_zeros(
51 | rois.size(0), self.out_channels, *out_size)
52 |
53 | # sometimes rois is an empty tensor
54 | if roi_feats.shape[0] == 0:
55 | return roi_feats
56 |
57 | if roi_scale_factor is not None:
58 | rois = self.roi_rescale(rois, roi_scale_factor)
59 |
60 | # mark the starting channels for concat mode
61 | start_channels = 0
62 | for i in range(num_levels):
63 | roi_feats_t = self.roi_layers[i](feats[i], rois)
64 | end_channels = start_channels + roi_feats_t.size(1)
65 | if self.with_pre:
66 | # apply pre-processing to a RoI extracted from each layer
67 | roi_feats_t = self.pre_module(roi_feats_t)
68 | if self.aggregation == 'sum':
69 | # and sum them all
70 | roi_feats += roi_feats_t
71 | else:
72 | # and concat them along channel dimension
73 | roi_feats[:, start_channels:end_channels] = roi_feats_t
74 | # update channels starting position
75 | start_channels = end_channels
76 | # check if concat channels match at the end
77 | if self.aggregation == 'concat':
78 | assert start_channels == self.out_channels
79 |
80 | if self.with_post:
82 | # apply post-processing before returning the result
82 | roi_feats = self.post_module(roi_feats)
83 | return roi_feats
84 |
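A hedged config sketch of how this extractor could be wired into an RoI head. The strides and channel counts mirror a typical FPN setup, and the pre/post modules use mmcv's registered ConvModule plugin; all values are illustrative, not taken from this repo's configs:

roi_extractor = dict(
    type='GenericRoIExtractor',
    aggregation='sum',
    roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=2),
    out_channels=256,
    featmap_strides=[4, 8, 16, 32],
    # optional per-level pre-processing (any mmcv plugin layer cfg)
    pre_cfg=dict(
        type='ConvModule',
        in_channels=256,
        out_channels=256,
        kernel_size=5,
        padding=2,
        inplace=False),
    # optional post-processing applied to the aggregated features
    post_cfg=dict(
        type='ConvModule',
        in_channels=256,
        out_channels=256,
        kernel_size=5,
        padding=2,
        inplace=False))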
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/shared_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .res_layer import ResLayer
2 |
3 | __all__ = ['ResLayer']
4 |
--------------------------------------------------------------------------------
/mmdet/models/roi_heads/shared_heads/res_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from mmcv.cnn import constant_init, kaiming_init
3 | from mmcv.runner import auto_fp16, load_checkpoint
4 |
5 | from mmdet.models.backbones import ResNet
6 | from mmdet.models.builder import SHARED_HEADS
7 | from mmdet.models.utils import ResLayer as _ResLayer
8 | from mmdet.utils import get_root_logger
9 |
10 |
11 | @SHARED_HEADS.register_module()
12 | class ResLayer(nn.Module):
13 |
14 | def __init__(self,
15 | depth,
16 | stage=3,
17 | stride=2,
18 | dilation=1,
19 | style='pytorch',
20 | norm_cfg=dict(type='BN', requires_grad=True),
21 | norm_eval=True,
22 | with_cp=False,
23 | dcn=None):
24 | super(ResLayer, self).__init__()
25 | self.norm_eval = norm_eval
26 | self.norm_cfg = norm_cfg
27 | self.stage = stage
28 | self.fp16_enabled = False
29 | block, stage_blocks = ResNet.arch_settings[depth]
30 | stage_block = stage_blocks[stage]
31 | planes = 64 * 2**stage
32 | inplanes = 64 * 2**(stage - 1) * block.expansion
33 |
34 | res_layer = _ResLayer(
35 | block,
36 | inplanes,
37 | planes,
38 | stage_block,
39 | stride=stride,
40 | dilation=dilation,
41 | style=style,
42 | with_cp=with_cp,
43 | norm_cfg=self.norm_cfg,
44 | dcn=dcn)
45 | self.add_module(f'layer{stage + 1}', res_layer)
46 |
47 | def init_weights(self, pretrained=None):
48 | """Initialize the weights in the module.
49 |
50 | Args:
51 | pretrained (str, optional): Path to pre-trained weights.
52 | Defaults to None.
53 | """
54 | if isinstance(pretrained, str):
55 | logger = get_root_logger()
56 | load_checkpoint(self, pretrained, strict=False, logger=logger)
57 | elif pretrained is None:
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | kaiming_init(m)
61 | elif isinstance(m, nn.BatchNorm2d):
62 | constant_init(m, 1)
63 | else:
64 | raise TypeError('pretrained must be a str or None')
65 |
66 | @auto_fp16()
67 | def forward(self, x):
68 | res_layer = getattr(self, f'layer{self.stage + 1}')
69 | out = res_layer(x)
70 | return out
71 |
72 | def train(self, mode=True):
73 | super(ResLayer, self).train(mode)
74 | if self.norm_eval:
75 | for m in self.modules():
76 | if isinstance(m, nn.BatchNorm2d):
77 | m.eval()
78 |
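With the defaults above (depth=50, stage=3), the constructor arithmetic resolves as follows; this is a standalone sketch of the values, not an extra code path:

from mmdet.models.backbones import ResNet

block, stage_blocks = ResNet.arch_settings[50]    # Bottleneck, (3, 4, 6, 3)
stage = 3
num_blocks = stage_blocks[stage]                  # 3 blocks in 'layer4'
planes = 64 * 2**stage                            # 512
inplanes = 64 * 2**(stage - 1) * block.expansion  # 256 * 4 = 1024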
--------------------------------------------------------------------------------
/mmdet/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .gaussian_target import gaussian_radius, gen_gaussian_target
2 | from .res_layer import ResLayer
3 |
4 | __all__ = ['ResLayer', 'gaussian_radius', 'gen_gaussian_target']
5 |
--------------------------------------------------------------------------------
/mmdet/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # This file is kept for backward compatibility, so that downstream
2 | # codebases can still use and import mmdet.ops.
3 |
4 | # yapf: disable
5 | from mmcv.ops import (ContextBlock, Conv2d, ConvTranspose2d, ConvWS2d,
6 | CornerPool, DeformConv, DeformConvPack, DeformRoIPooling,
7 | DeformRoIPoolingPack, GeneralizedAttention, Linear,
8 | MaskedConv2d, MaxPool2d, ModulatedDeformConv,
9 | ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack,
10 | NonLocal2D, RoIAlign, RoIPool, SAConv2d,
11 | SigmoidFocalLoss, SimpleRoIAlign, batched_nms,
12 | build_plugin_layer, conv_ws_2d, deform_conv,
13 | deform_roi_pooling, get_compiler_version,
14 | get_compiling_cuda_version, modulated_deform_conv, nms,
15 | nms_match, point_sample, rel_roi_point_to_rel_img_point,
16 | roi_align, roi_pool, sigmoid_focal_loss, soft_nms)
17 |
18 | # yapf: enable
19 |
20 | __all__ = [
21 | 'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
22 | 'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
23 | 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
24 | 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
25 | 'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss',
26 | 'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D',
27 | 'get_compiler_version', 'get_compiling_cuda_version', 'ConvWS2d',
28 | 'conv_ws_2d', 'build_plugin_layer', 'batched_nms', 'Conv2d',
29 | 'ConvTranspose2d', 'MaxPool2d', 'Linear', 'nms_match', 'CornerPool',
30 | 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
31 | 'SAConv2d'
32 | ]
33 |
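Since the shim above only re-imports from mmcv, both import paths resolve to the very same objects; a quick sanity check:

from mmcv import ops as mmcv_ops

import mmdet.ops as mmdet_ops

assert mmdet_ops.nms is mmcv_ops.nms
assert mmdet_ops.RoIAlign is mmcv_ops.RoIAlign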
--------------------------------------------------------------------------------
/mmdet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .collect_env import collect_env
2 | from .logger import get_root_logger
3 |
4 | __all__ = ['get_root_logger', 'collect_env']
5 |
--------------------------------------------------------------------------------
/mmdet/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import collect_env as collect_base_env
2 | from mmcv.utils import get_git_hash
3 |
4 | import mmdet
5 |
6 |
7 | def collect_env():
8 | """Collect the information of the running environments."""
9 | env_info = collect_base_env()
10 | env_info['MMDetection'] = mmdet.__version__ + '+' + get_git_hash()[:7]
11 | return env_info
12 |
13 |
14 | if __name__ == '__main__':
15 | for name, val in collect_env().items():
16 | print(f'{name}: {val}')
17 |
--------------------------------------------------------------------------------
/mmdet/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from mmcv.utils import get_logger
4 |
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO):
7 | """Get root logger.
8 |
9 | Args:
10 | log_file (str, optional): File path of log. Defaults to None.
11 | log_level (int, optional): The level of logger.
12 | Defaults to logging.INFO.
13 |
14 | Returns:
15 | :obj:`logging.Logger`: The obtained logger
16 | """
17 | logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)
18 |
19 | return logger
20 |
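A one-line usage sketch; the message text is illustrative:

import logging

from mmdet.utils import get_root_logger

logger = get_root_logger(log_level=logging.INFO)
logger.info('goes to the console, and to log_file when one is given')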
--------------------------------------------------------------------------------
/mmdet/utils/profiling.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import sys
3 | import time
4 |
5 | import torch
6 |
7 | if sys.version_info >= (3, 7):
8 |
9 | @contextlib.contextmanager
10 | def profile_time(trace_name,
11 | name,
12 | enabled=True,
13 | stream=None,
14 | end_stream=None):
15 | """Print time spent by CPU and GPU.
16 |
17 | Useful as a temporary context manager to find sweet spots of code
18 | suitable for async implementation.
19 | """
20 | if (not enabled) or not torch.cuda.is_available():
21 | yield
22 | return
23 | stream = stream if stream else torch.cuda.current_stream()
24 | end_stream = end_stream if end_stream else stream
25 | start = torch.cuda.Event(enable_timing=True)
26 | end = torch.cuda.Event(enable_timing=True)
27 | stream.record_event(start)
28 | try:
29 | cpu_start = time.monotonic()
30 | yield
31 | finally:
32 | cpu_end = time.monotonic()
33 | end_stream.record_event(end)
34 | end.synchronize()
35 | cpu_time = (cpu_end - cpu_start) * 1000
36 | gpu_time = start.elapsed_time(end)
37 | msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
38 | msg += f'gpu_time {gpu_time:.2f} ms stream {stream} end_stream {end_stream}'
39 | print(msg)
40 |
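A minimal usage sketch, assuming Python >= 3.7 and an available CUDA device (the context manager is a no-op otherwise); the trace name and the matrix are arbitrary:

import torch

from mmdet.utils.profiling import profile_time

if torch.cuda.is_available():
    x = torch.rand(1024, 1024, device='cuda')
    # prints one line such as: demo matmul cpu_time 0.45 ms gpu_time ...
    with profile_time('demo', 'matmul'):
        y = x @ x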
--------------------------------------------------------------------------------
/mmdet/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 |
3 | __version__ = '2.6.0'
4 | short_version = __version__
5 |
6 |
7 | def parse_version_info(version_str):
8 | version_info = []
9 | for x in version_str.split('.'):
10 | if x.isdigit():
11 | version_info.append(int(x))
12 | elif x.find('rc') != -1:
13 | patch_version = x.split('rc')
14 | version_info.append(int(patch_version[0]))
15 | version_info.append(f'rc{patch_version[1]}')
16 | return tuple(version_info)
17 |
18 |
19 | version_info = parse_version_info(__version__)
20 |
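Two worked examples of the parser above:

from mmdet.version import parse_version_info

assert parse_version_info('2.6.0') == (2, 6, 0)
assert parse_version_info('2.6.0rc1') == (2, 6, 0, 'rc1')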
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = --xdoctest --xdoctest-style=auto
3 | norecursedirs = .git ignore build __pycache__ data docker docs .eggs
4 |
5 | filterwarnings = default
6 | ignore:.*No cfgstr given in Cacher constructor or call.*:Warning
7 | ignore:.*Define the __nice__ method for.*:Warning
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/build.txt
2 | -r requirements/optional.txt
3 | -r requirements/runtime.txt
4 | -r requirements/tests.txt
5 |
--------------------------------------------------------------------------------
/requirements/build.txt:
--------------------------------------------------------------------------------
1 | # These must be installed before building mmdetection
2 | cython
3 | numpy
4 |
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | recommonmark
2 | sphinx
3 | sphinx_markdown_tables
4 | sphinx_rtd_theme
5 |
--------------------------------------------------------------------------------
/requirements/optional.txt:
--------------------------------------------------------------------------------
1 | albumentations>=0.3.2
2 | cityscapesscripts
3 | imagecorruptions
4 | mmlvis
5 |
--------------------------------------------------------------------------------
/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
1 | mmcv
2 | torch
3 | torchvision
4 |
--------------------------------------------------------------------------------
/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | mmpycocotools
3 | numpy
4 | six
5 | terminaltables
6 |
--------------------------------------------------------------------------------
/requirements/tests.txt:
--------------------------------------------------------------------------------
1 | asynctest
2 | codecov
3 | flake8
4 | interrogate
5 | isort==4.3.21
6 | # Note: used for kwarray.group_items, this may be ported to mmcv in the future.
7 | kwarray
8 | pytest
9 | ubelt
10 | xdoctest>=0.10.0
11 | yapf
12 |
--------------------------------------------------------------------------------
/resources/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/resources/architecture.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line_length = 79
3 | multi_line_output = 0
4 | known_standard_library = setuptools
5 | known_first_party = mmdet
6 | known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch
7 | no_lines_before = STDLIB,LOCALFOLDER
8 | default_section = THIRDPARTY
9 |
10 | [yapf]
11 | BASED_ON_STYLE = pep8
12 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
13 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
14 |
--------------------------------------------------------------------------------
/tests/async_benchmark.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import shutil
4 | import urllib.request
5 |
6 | import mmcv
7 | import torch
8 |
9 | from mmdet.apis import (async_inference_detector, inference_detector,
10 | init_detector)
11 | from mmdet.utils.contextmanagers import concurrent
12 | from mmdet.utils.profiling import profile_time
13 |
14 |
15 | async def main():
16 | """Benchmark between async and synchronous inference interfaces.
17 |
18 | Sample runs for 20 demo images on K80 GPU, model - mask_rcnn_r50_fpn_1x:
19 |
20 | async sync
21 |
22 | 7981.79 ms 9660.82 ms
23 | 8074.52 ms 9660.94 ms
24 | 7976.44 ms 9406.83 ms
25 |
26 | Async variant takes about 0.83-0.85 of the time of the synchronous
27 | interface.
28 | """
29 | project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
30 |
31 | config_file = os.path.join(project_dir,
32 | 'configs/mask_rcnn_r50_fpn_1x_coco.py')
33 | checkpoint_file = os.path.join(
34 | project_dir, 'checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
35 |
36 | if not os.path.exists(checkpoint_file):
37 | url = ('https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection'
38 | '/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
39 | print(f'Downloading {url} ...')
40 | local_filename, _ = urllib.request.urlretrieve(url)
41 | os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
42 | shutil.move(local_filename, checkpoint_file)
43 | print(f'Saved as {checkpoint_file}')
44 | else:
45 | print(f'Using existing checkpoint {checkpoint_file}')
46 |
47 | device = 'cuda:0'
48 | model = init_detector(
49 | config_file, checkpoint=checkpoint_file, device=device)
50 |
51 | # queue is used for concurrent inference of multiple images
52 | streamqueue = asyncio.Queue()
53 | # queue size defines concurrency level
54 | streamqueue_size = 4
55 |
56 | for _ in range(streamqueue_size):
57 | streamqueue.put_nowait(torch.cuda.Stream(device=device))
58 |
59 | # test a single image and show the results
60 | img = mmcv.imread(os.path.join(project_dir, 'demo/demo.jpg'))
61 |
62 | # warmup
63 | await async_inference_detector(model, img)
64 |
65 | async def detect(img):
66 | async with concurrent(streamqueue):
67 | return await async_inference_detector(model, img)
68 |
69 | num_of_images = 20
70 | with profile_time('benchmark', 'async'):
71 | tasks = [
72 | asyncio.create_task(detect(img)) for _ in range(num_of_images)
73 | ]
74 | async_results = await asyncio.gather(*tasks)
75 |
76 | with torch.cuda.stream(torch.cuda.default_stream()):
77 | with profile_time('benchmark', 'sync'):
78 | sync_results = [
79 | inference_detector(model, img) for _ in range(num_of_images)
80 | ]
81 |
82 | result_dir = os.path.join(project_dir, 'demo')
83 | # mmdet.apis no longer provides show_result in 2.x;
84 | # use the detector's own show_result method instead
85 | model.show_result(
86 | img,
87 | async_results[0],
88 | score_thr=0.5,
89 | show=False,
90 | out_file=os.path.join(result_dir, 'result_async.jpg'))
91 | model.show_result(
92 | img,
93 | sync_results[0],
94 | score_thr=0.5,
95 | show=False,
96 | out_file=os.path.join(result_dir, 'result_sync.jpg'))
97 |
98 |
99 | if __name__ == '__main__':
100 | asyncio.run(main())
101 |
--------------------------------------------------------------------------------
/tests/data/coco_sample.json:
--------------------------------------------------------------------------------
1 | {
2 | "images": [
3 | {
4 | "file_name": "fake1.jpg",
5 | "height": 800,
6 | "width": 800,
7 | "id": 0
8 | },
9 | {
10 | "file_name": "fake2.jpg",
11 | "height": 800,
12 | "width": 800,
13 | "id": 1
14 | },
15 | {
16 | "file_name": "fake3.jpg",
17 | "height": 800,
18 | "width": 800,
19 | "id": 2
20 | }
21 | ],
22 | "annotations": [
23 | {
24 | "bbox": [
25 | 0,
26 | 0,
27 | 20,
28 | 20
29 | ],
30 | "area": 400.00,
31 | "score": 1.0,
32 | "category_id": 1,
33 | "id": 1,
34 | "image_id": 0
35 | },
36 | {
37 | "bbox": [
38 | 0,
39 | 0,
40 | 20,
41 | 20
42 | ],
43 | "area": 400.00,
44 | "score": 1.0,
45 | "category_id": 2,
46 | "id": 2,
47 | "image_id": 0
48 | },
49 | {
50 | "bbox": [
51 | 0,
52 | 0,
53 | 20,
54 | 20
55 | ],
56 | "area": 400.00,
57 | "score": 1.0,
58 | "category_id": 1,
59 | "id": 3,
60 | "image_id": 1
61 | }
62 | ],
63 | "categories": [
64 | {
65 | "id": 1,
66 | "name": "bus",
67 | "supercategory": "none"
68 | },
69 | {
70 | "id": 2,
71 | "name": "car",
72 | "supercategory": "none"
73 | }
74 | ],
75 | "licenses": [],
76 | "info": null
77 | }
78 |
--------------------------------------------------------------------------------
/tests/data/color.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/tests/data/color.jpg
--------------------------------------------------------------------------------
/tests/data/gray.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/researchmm/WSOD2/fd6f99401013ed5a66e39cee71a6c2b35580008e/tests/data/gray.jpg
--------------------------------------------------------------------------------
/tests/test_async.py:
--------------------------------------------------------------------------------
1 | """Tests for async interface."""
2 |
3 | import asyncio
4 | import os
5 | import sys
6 |
7 | import asynctest
8 | import mmcv
9 | import torch
10 |
11 | from mmdet.apis import async_inference_detector, init_detector
12 |
13 | if sys.version_info >= (3, 7):
14 | from mmdet.utils.contextmanagers import concurrent
15 |
16 |
17 | class AsyncTestCase(asynctest.TestCase):
18 | use_default_loop = False
19 | forbid_get_event_loop = True
20 |
21 | TEST_TIMEOUT = int(os.getenv('ASYNCIO_TEST_TIMEOUT', '30'))
22 |
23 | def _run_test_method(self, method):
24 | result = method()
25 | if asyncio.iscoroutine(result):
26 | self.loop.run_until_complete(
27 | asyncio.wait_for(result, timeout=self.TEST_TIMEOUT))
28 |
29 |
30 | class MaskRCNNDetector:
31 |
32 | def __init__(self,
33 | model_config,
34 | checkpoint=None,
35 | streamqueue_size=3,
36 | device='cuda:0'):
37 |
38 | self.streamqueue_size = streamqueue_size
39 | self.device = device
40 | # build the model and load checkpoint
41 | self.model = init_detector(
42 | model_config, checkpoint=None, device=self.device)
43 | self.streamqueue = None
44 |
45 | async def init(self):
46 | self.streamqueue = asyncio.Queue()
47 | for _ in range(self.streamqueue_size):
48 | stream = torch.cuda.Stream(device=self.device)
49 | self.streamqueue.put_nowait(stream)
50 |
51 | if sys.version_info >= (3, 7):
52 |
53 | async def apredict(self, img):
54 | if isinstance(img, str):
55 | img = mmcv.imread(img)
56 | async with concurrent(self.streamqueue):
57 | result = await async_inference_detector(self.model, img)
58 | return result
59 |
60 |
61 | class AsyncInferenceTestCase(AsyncTestCase):
62 |
63 | if sys.version_info >= (3, 7):
64 |
65 | async def test_simple_inference(self):
66 | if not torch.cuda.is_available():
67 | import pytest
68 |
69 | pytest.skip('test requires GPU and torch+cuda')
70 |
71 | ori_grad_enabled = torch.is_grad_enabled()
72 | root_dir = os.path.dirname(os.path.dirname(__file__))
73 | model_config = os.path.join(
74 | root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
75 | detector = MaskRCNNDetector(model_config)
76 | await detector.init()
77 | img_path = os.path.join(root_dir, 'demo/demo.jpg')
78 | bboxes, _ = await detector.apredict(img_path)
79 | self.assertTrue(bboxes)
80 | # the async inference detector modifies grad_enabled,
81 | # so restore it here to avoid influencing other tests
82 | torch.set_grad_enabled(ori_grad_enabled)
83 |
--------------------------------------------------------------------------------
/tests/test_coder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mmdet.core.bbox.coder import YOLOBBoxCoder
4 |
5 |
6 | def test_yolo_bbox_coder():
7 | coder = YOLOBBoxCoder()
8 | bboxes = torch.Tensor([[-42., -29., 74., 61.], [-10., -29., 106., 61.],
9 | [22., -29., 138., 61.], [54., -29., 170., 61.]])
10 | pred_bboxes = torch.Tensor([[0.4709, 0.6152, 0.1690, -0.4056],
11 | [0.5399, 0.6653, 0.1162, -0.4162],
12 | [0.4654, 0.6618, 0.1548, -0.4301],
13 | [0.4786, 0.6197, 0.1896, -0.4479]])
14 | grid_size = 32
15 | expected_decode_bboxes = torch.Tensor(
16 | [[-53.6102, -10.3096, 83.7478, 49.6824],
17 | [-15.8700, -8.3901, 114.4236, 50.9693],
18 | [11.1822, -8.0924, 146.6034, 50.4476],
19 | [41.2068, -8.9232, 181.4236, 48.5840]])
20 | assert expected_decode_bboxes.allclose(
21 | coder.decode(bboxes, pred_bboxes, grid_size))
22 |
--------------------------------------------------------------------------------
/tests/test_data/test_formatting.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | from mmcv.utils import build_from_cfg
4 |
5 | from mmdet.datasets.builder import PIPELINES
6 |
7 |
8 | def test_default_format_bundle():
9 | results = dict(
10 | img_prefix=osp.join(osp.dirname(__file__), '../data'),
11 | img_info=dict(filename='color.jpg'))
12 | load = dict(type='LoadImageFromFile')
13 | load = build_from_cfg(load, PIPELINES)
14 | bundle = dict(type='DefaultFormatBundle')
15 | bundle = build_from_cfg(bundle, PIPELINES)
16 | results = load(results)
17 | assert 'pad_shape' not in results
18 | assert 'scale_factor' not in results
19 | assert 'img_norm_cfg' not in results
20 | results = bundle(results)
21 | assert 'pad_shape' in results
22 | assert 'scale_factor' in results
23 | assert 'img_norm_cfg' in results
24 |
--------------------------------------------------------------------------------
/tests/test_data/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from mmdet.datasets import replace_ImageToTensor
4 |
5 |
6 | def test_replace_ImageToTensor():
7 | # with MultiScaleFlipAug
8 | pipelines = [
9 | dict(type='LoadImageFromFile'),
10 | dict(
11 | type='MultiScaleFlipAug',
12 | img_scale=(1333, 800),
13 | flip=False,
14 | transforms=[
15 | dict(type='Resize', keep_ratio=True),
16 | dict(type='RandomFlip'),
17 | dict(type='Normalize'),
18 | dict(type='Pad', size_divisor=32),
19 | dict(type='ImageToTensor', keys=['img']),
20 | dict(type='Collect', keys=['img']),
21 | ])
22 | ]
23 | expected_pipelines = [
24 | dict(type='LoadImageFromFile'),
25 | dict(
26 | type='MultiScaleFlipAug',
27 | img_scale=(1333, 800),
28 | flip=False,
29 | transforms=[
30 | dict(type='Resize', keep_ratio=True),
31 | dict(type='RandomFlip'),
32 | dict(type='Normalize'),
33 | dict(type='Pad', size_divisor=32),
34 | dict(type='DefaultFormatBundle'),
35 | dict(type='Collect', keys=['img']),
36 | ])
37 | ]
38 | with pytest.warns(UserWarning):
39 | assert expected_pipelines == replace_ImageToTensor(pipelines)
40 |
41 | # without MultiScaleFlipAug
42 | pipelines = [
43 | dict(type='LoadImageFromFile'),
44 | dict(type='Resize', keep_ratio=True),
45 | dict(type='RandomFlip'),
46 | dict(type='Normalize'),
47 | dict(type='Pad', size_divisor=32),
48 | dict(type='ImageToTensor', keys=['img']),
49 | dict(type='Collect', keys=['img']),
50 | ]
51 | expected_pipelines = [
52 | dict(type='LoadImageFromFile'),
53 | dict(type='Resize', keep_ratio=True),
54 | dict(type='RandomFlip'),
55 | dict(type='Normalize'),
56 | dict(type='Pad', size_divisor=32),
57 | dict(type='DefaultFormatBundle'),
58 | dict(type='Collect', keys=['img']),
59 | ]
60 | with pytest.warns(UserWarning):
61 | assert expected_pipelines == replace_ImageToTensor(pipelines)
62 |
--------------------------------------------------------------------------------
/tests/test_version.py:
--------------------------------------------------------------------------------
1 | from mmdet import digit_version
2 |
3 |
4 | def test_version_check():
5 | assert digit_version('1.0.5') > digit_version('1.0.5rc0')
6 | assert digit_version('1.0.5') > digit_version('1.0.4rc0')
7 | assert digit_version('1.0.5') > digit_version('1.0rc0')
8 | assert digit_version('1.0.0') > digit_version('0.6.2')
9 | assert digit_version('1.0.0') > digit_version('0.2.16')
10 | assert digit_version('1.0.5rc0') > digit_version('1.0.0rc0')
11 | assert digit_version('1.0.0rc1') > digit_version('1.0.0rc0')
12 | assert digit_version('1.0.0rc2') > digit_version('1.0.0rc0')
13 | assert digit_version('1.0.0rc2') > digit_version('1.0.0rc1')
14 | assert digit_version('1.0.1rc1') > digit_version('1.0.0rc1')
15 | assert digit_version('1.0.0') > digit_version('1.0.0rc1')
16 |
--------------------------------------------------------------------------------
/tools/benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import torch
5 | from mmcv import Config
6 | from mmcv.cnn import fuse_conv_bn
7 | from mmcv.parallel import MMDataParallel
8 | from mmcv.runner import load_checkpoint, wrap_fp16_model
9 |
10 | from mmdet.datasets import (build_dataloader, build_dataset,
11 | replace_ImageToTensor)
12 | from mmdet.models import build_detector
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='MMDet benchmark a model')
17 | parser.add_argument('config', help='test config file path')
18 | parser.add_argument('checkpoint', help='checkpoint file')
19 | parser.add_argument(
20 | '--log-interval', default=50, type=int, help='interval of logging')
21 | parser.add_argument(
22 | '--fuse-conv-bn',
23 | action='store_true',
24 | help='Whether to fuse conv and bn; this will slightly increase '
25 | 'the inference speed')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def main():
31 | args = parse_args()
32 |
33 | cfg = Config.fromfile(args.config)
34 | # import modules from string list.
35 | if cfg.get('custom_imports', None):
36 | from mmcv.utils import import_modules_from_strings
37 | import_modules_from_strings(**cfg['custom_imports'])
38 | # set cudnn_benchmark
39 | if cfg.get('cudnn_benchmark', False):
40 | torch.backends.cudnn.benchmark = True
41 | cfg.model.pretrained = None
42 | cfg.data.test.test_mode = True
43 |
44 | # build the dataloader
45 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
46 | if samples_per_gpu > 1:
47 | # Replace 'ImageToTensor' to 'DefaultFormatBundle'
48 | cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
49 | dataset = build_dataset(cfg.data.test)
50 | data_loader = build_dataloader(
51 | dataset,
52 | samples_per_gpu=1,
53 | workers_per_gpu=cfg.data.workers_per_gpu,
54 | dist=False,
55 | shuffle=False)
56 |
57 | # build the model and load checkpoint
58 | model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
59 | fp16_cfg = cfg.get('fp16', None)
60 | if fp16_cfg is not None:
61 | wrap_fp16_model(model)
62 | load_checkpoint(model, args.checkpoint, map_location='cpu')
63 | if args.fuse_conv_bn:
64 | model = fuse_conv_bn(model)
65 |
66 | model = MMDataParallel(model, device_ids=[0])
67 |
68 | model.eval()
69 |
70 | # the first several iterations may be very slow so skip them
71 | num_warmup = 5
72 | pure_inf_time = 0
73 |
74 | # benchmark with 2000 images and take the average
75 | for i, data in enumerate(data_loader):
76 |
77 | torch.cuda.synchronize()
78 | start_time = time.perf_counter()
79 |
80 | with torch.no_grad():
81 | model(return_loss=False, rescale=True, **data)
82 |
83 | torch.cuda.synchronize()
84 | elapsed = time.perf_counter() - start_time
85 |
86 | if i >= num_warmup:
87 | pure_inf_time += elapsed
88 | if (i + 1) % args.log_interval == 0:
89 | fps = (i + 1 - num_warmup) / pure_inf_time
90 | print(f'Done image [{i + 1:<3}/ 2000], fps: {fps:.1f} img / s')
91 |
92 | if (i + 1) == 2000:
93 | # elapsed for this image was already accumulated above
94 | fps = (i + 1 - num_warmup) / pure_inf_time
95 | print(f'Overall fps: {fps:.1f} img / s')
96 | break
97 |
98 |
99 | if __name__ == '__main__':
100 | main()
101 |
--------------------------------------------------------------------------------
/tools/browse_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import mmcv
6 | from mmcv import Config
7 |
8 | from mmdet.datasets.builder import build_dataset
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(description='Browse a dataset')
13 | parser.add_argument('config', help='train config file path')
14 | parser.add_argument(
15 | '--skip-type',
16 | type=str,
17 | nargs='+',
18 | default=['DefaultFormatBundle', 'Normalize', 'Collect'],
19 | help='the pipeline types to be skipped when browsing')
20 | parser.add_argument(
21 | '--output-dir',
22 | default=None,
23 | type=str,
24 | help='directory to save the visualizations when there is no display')
25 | parser.add_argument('--not-show', default=False, action='store_true')
26 | parser.add_argument(
27 | '--show-interval',
28 | type=int,
29 | default=999,
30 | help='the interval of show (ms)')
31 | args = parser.parse_args()
32 | return args
33 |
34 |
35 | def retrieve_data_cfg(config_path, skip_type):
36 | cfg = Config.fromfile(config_path)
37 | train_data_cfg = cfg.data.train
38 | train_data_cfg['pipeline'] = [
39 | x for x in train_data_cfg.pipeline if x['type'] not in skip_type
40 | ]
41 |
42 | return cfg
43 |
44 |
45 | def main():
46 | args = parse_args()
47 | cfg = retrieve_data_cfg(args.config, args.skip_type)
48 |
49 | dataset = build_dataset(cfg.data.train)
50 |
51 | progress_bar = mmcv.ProgressBar(len(dataset))
52 | for item in dataset:
53 | filename = os.path.join(args.output_dir,
54 | Path(item['filename']).name
55 | ) if args.output_dir is not None else None
56 | mmcv.imshow_det_bboxes(
57 | item['img'],
58 | item['gt_bboxes'],
59 | item['gt_labels'],
60 | class_names=dataset.CLASSES,
61 | show=not args.not_show,
62 | out_file=filename,
63 | wait_time=args.show_interval)
64 | progress_bar.update()
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 |
8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
9 | /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
11 |
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 |
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
10 |
--------------------------------------------------------------------------------
/tools/eval_metric.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mmcv
4 | from mmcv import Config, DictAction
5 |
6 | from mmdet.datasets import build_dataset
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='Evaluate metric of the '
11 | 'results saved in pkl format')
12 | parser.add_argument('config', help='Config of the model')
13 | parser.add_argument('pkl_results', help='Results in pickle format')
14 | parser.add_argument(
15 | '--format-only',
16 | action='store_true',
17 | help='Format the output results without performing evaluation. It is '
18 | 'useful when you want to format the result to a specific format and '
19 | 'submit it to the test server')
20 | parser.add_argument(
21 | '--eval',
22 | type=str,
23 | nargs='+',
24 | help='Evaluation metrics, which depends on the dataset, e.g., "bbox",'
25 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
26 | parser.add_argument(
27 | '--cfg-options',
28 | nargs='+',
29 | action=DictAction,
30 | help='override some settings in the config in use; key-value pairs '
31 | 'in xxx=yyy format will be merged into the config.'
32 | parser.add_argument(
33 | '--eval-options',
34 | nargs='+',
35 | action=DictAction,
36 | help='custom options for evaluation, the key-value pair in xxx=yyy '
37 | 'format will be kwargs for dataset.evaluate() function')
38 | args = parser.parse_args()
39 | return args
40 |
41 |
42 | def main():
43 | args = parse_args()
44 |
45 | cfg = Config.fromfile(args.config)
46 | assert args.eval or args.format_only, (
47 | 'Please specify at least one operation (eval/format the results) with '
48 | 'the argument "--eval", "--format-only"')
49 | if args.eval and args.format_only:
50 | raise ValueError('--eval and --format_only cannot be both specified')
51 |
52 | if args.cfg_options is not None:
53 | cfg.merge_from_dict(args.cfg_options)
54 | cfg.data.test.test_mode = True
55 |
56 | dataset = build_dataset(cfg.data.test)
57 | outputs = mmcv.load(args.pkl_results)
58 |
59 | kwargs = {} if args.eval_options is None else args.eval_options
60 | if args.format_only:
61 | dataset.format_results(outputs, **kwargs)
62 | if args.eval:
63 | eval_kwargs = cfg.get('evaluation', {}).copy()
64 | # hard-code way to remove EvalHook args
65 | for key in [
66 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
67 | 'rule'
68 | ]:
69 | eval_kwargs.pop(key, None)
70 | eval_kwargs.update(dict(metric=args.eval, **kwargs))
71 | print(dataset.evaluate(outputs, **eval_kwargs))
72 |
73 |
74 | if __name__ == '__main__':
75 | main()
76 |
--------------------------------------------------------------------------------
/tools/get_flops.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from mmcv import Config
5 |
6 | from mmdet.models import build_detector
7 |
8 | try:
9 | from mmcv.cnn import get_model_complexity_info
10 | except ImportError:
11 | raise ImportError('Please upgrade mmcv to >0.6.2')
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(description='Get the FLOPs of a detector')
16 | parser.add_argument('config', help='train config file path')
17 | parser.add_argument(
18 | '--shape',
19 | type=int,
20 | nargs='+',
21 | default=[1280, 800],
22 | help='input image size')
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | def main():
28 |
29 | args = parse_args()
30 |
31 | if len(args.shape) == 1:
32 | input_shape = (3, args.shape[0], args.shape[0])
33 | elif len(args.shape) == 2:
34 | input_shape = (3, ) + tuple(args.shape)
35 | else:
36 | raise ValueError('invalid input shape')
37 |
38 | cfg = Config.fromfile(args.config)
39 | # import modules from string list.
40 | if cfg.get('custom_imports', None):
41 | from mmcv.utils import import_modules_from_strings
42 | import_modules_from_strings(**cfg['custom_imports'])
43 |
44 | model = build_detector(
45 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
46 | if torch.cuda.is_available():
47 | model.cuda()
48 | model.eval()
49 |
50 | if hasattr(model, 'forward_dummy'):
51 | model.forward = model.forward_dummy
52 | else:
53 | raise NotImplementedError(
54 | 'FLOPs counter is currently not supported with {}'.
55 | format(model.__class__.__name__))
56 |
57 | flops, params = get_model_complexity_info(model, input_shape)
58 | split_line = '=' * 30
59 | print(f'{split_line}\nInput shape: {input_shape}\n'
60 | f'Flops: {flops}\nParams: {params}\n{split_line}')
61 | print('!!!Please be cautious if you use the results in papers. '
62 | 'You may need to check if all ops are supported and verify that the '
63 | 'flops computation is correct.')
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
/tools/prepare.sh:
--------------------------------------------------------------------------------
1 | # download vgg16 imagenet pre-trained weights
2 | mkdir -p pretrain
3 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=10Vh2qFmGucO-9DZ3eY3HAvcAmtPFcFg2' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=10Vh2qFmGucO-9DZ3eY3HAvcAmtPFcFg2" -O pretrain/vgg16.pth && rm -rf /tmp/cookies.txt
4 |
5 | # download pascal voc selective search region proposals
6 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-1EJ3Mm7KoXwaurYx4zpPcerkKbIAqU-' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-1EJ3Mm7KoXwaurYx4zpPcerkKbIAqU-" -O data/voc_selective_search.zip && rm -rf /tmp/cookies.txt
7 | cd data && unzip voc_selective_search.zip && mv voc_selective_search/voc_2007* VOCdevkit/VOC2007/ && mv voc_selective_search/voc_2012* VOCdevkit/VOC2012/ && rm voc_selective_search.zip && rm -rf voc_selective_search && cd ..
8 |
9 | # download pascal voc super pixels
10 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1tItyPAUz16iXOyIHpVfAmWbKfwQzkzoU' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1tItyPAUz16iXOyIHpVfAmWbKfwQzkzoU" -O data/voc_super_pixels.zip && rm -rf /tmp/cookies.txt
11 | cd data && unzip voc_super_pixels.zip && mv voc_super_pixels/superpixel2007 VOCdevkit/VOC2007/SuperPixels && mv voc_super_pixels/superpixel2012 VOCdevkit/VOC2012/SuperPixels && rm voc_super_pixels.zip && rm -rf voc_super_pixels && cd ..
12 |
13 |
--------------------------------------------------------------------------------
/tools/print_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from mmcv import Config, DictAction
4 |
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser(description='Print the whole config')
8 | parser.add_argument('config', help='config file path')
9 | parser.add_argument(
10 | '--options', nargs='+', action=DictAction, help='arguments in dict')
11 | args = parser.parse_args()
12 |
13 | return args
14 |
15 |
16 | def main():
17 | args = parse_args()
18 |
19 | cfg = Config.fromfile(args.config)
20 | if args.options is not None:
21 | cfg.merge_from_dict(args.options)
22 | print(f'Config:\n{cfg.pretty_text}')
23 |
24 |
25 | if __name__ == '__main__':
26 | main()
27 |
--------------------------------------------------------------------------------
/tools/publish_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 |
4 | import torch
5 |
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser(
9 | description='Process a checkpoint to be published')
10 | parser.add_argument('in_file', help='input checkpoint filename')
11 | parser.add_argument('out_file', help='output checkpoint filename')
12 | args = parser.parse_args()
13 | return args
14 |
15 |
16 | def process_checkpoint(in_file, out_file):
17 | checkpoint = torch.load(in_file, map_location='cpu')
18 | # remove optimizer for smaller file size
19 | if 'optimizer' in checkpoint:
20 | del checkpoint['optimizer']
21 | # if it is necessary to remove some sensitive data in checkpoint['meta'],
22 | # add the code here.
23 | torch.save(checkpoint, out_file)
24 | sha = subprocess.check_output(['sha256sum', out_file]).decode()
25 | if out_file.endswith('.pth'):
26 | out_file_name = out_file[:-4]
27 | else:
28 | out_file_name = out_file
29 | final_file = out_file_name + f'-{sha[:8]}.pth'
30 | subprocess.Popen(['mv', out_file, final_file])
31 |
32 |
33 | def main():
34 | args = parse_args()
35 | process_checkpoint(args.in_file, args.out_file)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
/tools/regnet2mmdet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import OrderedDict
3 |
4 | import torch
5 |
6 |
7 | def convert_stem(model_key, model_weight, state_dict, converted_names):
8 | new_key = model_key.replace('stem.conv', 'conv1')
9 | new_key = new_key.replace('stem.bn', 'bn1')
10 | state_dict[new_key] = model_weight
11 | converted_names.add(model_key)
12 | print(f'Convert {model_key} to {new_key}')
13 |
14 |
15 | def convert_head(model_key, model_weight, state_dict, converted_names):
16 | new_key = model_key.replace('head.fc', 'fc')
17 | state_dict[new_key] = model_weight
18 | converted_names.add(model_key)
19 | print(f'Convert {model_key} to {new_key}')
20 |
21 |
22 | def convert_reslayer(model_key, model_weight, state_dict, converted_names):
23 | split_keys = model_key.split('.')
24 | layer, block, module = split_keys[:3]
25 | block_id = int(block[1:])
26 | layer_name = f'layer{int(layer[1:])}'
27 | block_name = f'{block_id - 1}'
28 |
29 | if block_id == 1 and module == 'bn':
30 | new_key = f'{layer_name}.{block_name}.downsample.1.{split_keys[-1]}'
31 | elif block_id == 1 and module == 'proj':
32 | new_key = f'{layer_name}.{block_name}.downsample.0.{split_keys[-1]}'
33 | elif module == 'f':
34 | if split_keys[3] == 'a_bn':
35 | module_name = 'bn1'
36 | elif split_keys[3] == 'b_bn':
37 | module_name = 'bn2'
38 | elif split_keys[3] == 'c_bn':
39 | module_name = 'bn3'
40 | elif split_keys[3] == 'a':
41 | module_name = 'conv1'
42 | elif split_keys[3] == 'b':
43 | module_name = 'conv2'
44 | elif split_keys[3] == 'c':
45 | module_name = 'conv3'
46 | new_key = f'{layer_name}.{block_name}.{module_name}.{split_keys[-1]}'
47 | else:
48 | raise ValueError(f'Unsupported conversion of key {model_key}')
49 | print(f'Convert {model_key} to {new_key}')
50 | state_dict[new_key] = model_weight
51 | converted_names.add(model_key)
52 |
53 |
54 | def convert(src, dst):
55 | """Convert keys in pycls pretrained RegNet models to mmdet style."""
56 | # load the pycls model
57 | regnet_model = torch.load(src)
58 | blobs = regnet_model['model_state']
59 | # convert to pytorch style
60 | state_dict = OrderedDict()
61 | converted_names = set()
62 | for key, weight in blobs.items():
63 | if 'stem' in key:
64 | convert_stem(key, weight, state_dict, converted_names)
65 | elif 'head' in key:
66 | convert_head(key, weight, state_dict, converted_names)
67 | elif key.startswith('s'):
68 | convert_reslayer(key, weight, state_dict, converted_names)
69 |
70 | # check if all layers are converted
71 | for key in blobs:
72 | if key not in converted_names:
73 | print(f'not converted: {key}')
74 | # save checkpoint
75 | checkpoint = dict()
76 | checkpoint['state_dict'] = state_dict
77 | torch.save(checkpoint, dst)
78 |
79 |
80 | def main():
81 | parser = argparse.ArgumentParser(description='Convert model keys')
82 | parser.add_argument('src', help='src pycls model path')
83 | parser.add_argument('dst', help='save path')
84 | args = parser.parse_args()
85 | convert(args.src, args.dst)
86 |
87 |
88 | if __name__ == '__main__':
89 | main()
90 |
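To document the naming scheme, here are a few translations the converters above would produce, using hypothetical pycls keys chosen only for illustration:

# convert_stem:      'stem.conv.weight'    -> 'conv1.weight'
# convert_head:      'head.fc.weight'      -> 'fc.weight'
# convert_reslayer:  's1.b1.proj.weight'   -> 'layer1.0.downsample.0.weight'
#                    's1.b1.bn.weight'     -> 'layer1.0.downsample.1.weight'
#                    's2.b3.f.a_bn.weight' -> 'layer2.2.bn1.weight'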
--------------------------------------------------------------------------------
/tools/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CHECKPOINT=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | WORK_DIR=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | SRUN_ARGS=${SRUN_ARGS:-""}
13 | PY_ARGS=${@:5}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------