├── .github ├── FUNDING.yml └── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ └── feature_request.md ├── .gitignore ├── LICENSE ├── README.md ├── avg_checkpoints.py ├── clean_checkpoint.py ├── distributed_train.sh ├── effdet ├── __init__.py ├── anchors.py ├── bench.py ├── config │ ├── __init__.py │ ├── config_utils.py │ ├── fpn_config.py │ ├── model_config.py │ └── train_config.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── dataset_config.py │ ├── dataset_factory.py │ ├── input_config.py │ ├── loader.py │ ├── parsers │ │ ├── __init__.py │ │ ├── parser.py │ │ ├── parser_coco.py │ │ ├── parser_config.py │ │ ├── parser_factory.py │ │ ├── parser_open_images.py │ │ └── parser_voc.py │ ├── random_erasing.py │ └── transforms.py ├── distributed.py ├── efficientdet.py ├── evaluation │ ├── README.md │ ├── __init__.py │ ├── detection_evaluator.py │ ├── fields.py │ ├── metrics.py │ ├── np_box_list.py │ ├── np_mask_list.py │ ├── object_detection_evaluation.py │ └── per_image_evaluation.py ├── evaluator.py ├── factory.py ├── helpers.py ├── loss.py ├── object_detection │ ├── README.md │ ├── __init__.py │ ├── argmax_matcher.py │ ├── box_coder.py │ ├── box_list.py │ ├── matcher.py │ ├── region_similarity_calculator.py │ └── target_assigner.py ├── soft_nms.py └── version.py ├── requirements.txt ├── setup.py ├── train.py └── validate.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: rwightman 4 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting 4 | features, I hope to activate the discussion forum soon for asking questions or seeking 5 | help from the community. 6 | title: "[BUG] Issue title..." 7 | labels: bug 8 | assignees: '' 9 | 10 | --- 11 | 12 | **Describe the bug** 13 | A clear and concise description of what the bug is. 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 1. 18 | 2. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Windows, Ubuntu 18.04] 28 | - This repository version [e.g. pypi 0.3.1 or branch/commit ref] 29 | - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Community Discussions 4 | url: https://github.com/rwightman/efficientdet-pytorch/discussions 5 | about: Issues are exclusively for concrete features and bugs. Please use Discussions for questions, sharing, other ideas. 
6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project. Issues are for reporting bugs or requesting 4 | features, I hope to activate the discussion forum soon for asking questions or seeking 5 | help from the community. 6 | title: "[FEATURE] Feature title..." 7 | labels: enhancement 8 | assignees: '' 9 | 10 | --- 11 | 12 | **Is your feature request related to a problem? Please describe.** 13 | A clear and concise description of what the problem is. 14 | 15 | **Describe the solution you'd like** 16 | A clear and concise description of what you want to happen. 17 | 18 | **Describe alternatives you've considered** 19 | A clear and concise description of any alternative solutions or features you've considered. 20 | 21 | **Additional context** 22 | Add any other context or screenshots about the feature request here. 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm 132 | .idea 133 | 134 | # PyTorch weights 135 | *.tar 136 | *.pth 137 | *.gz 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Ross Wightman 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /avg_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ Checkpoint Averaging Script 3 | 4 | This script averages all model weights for checkpoints in specified path that match 5 | the specified filter wildcard. All checkpoints must be from the exact same model. 6 | 7 | For any hope of decent results, the checkpoints should be from the same or child 8 | (via resumes) training session. 
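For orientation, a typical invocation might look like the following (paths here are hypothetical; the flags correspond to the argparse options defined further down, selecting the top 5 checkpoints by their stored eval metric):

    python avg_checkpoints.py --input ./output/train --filter '*.pth.tar' -n 5 --output ./avg.pth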
This can be viewed as similar to maintaining running 9 | EMA (exponential moving average) of the model weights or performing SWA (stochastic 10 | weight averaging), but post-training. 11 | 12 | Hacked together by Ross Wightman (https://github.com/rwightman) 13 | """ 14 | import torch 15 | import argparse 16 | import os 17 | import glob 18 | import hashlib 19 | from timm.models.helpers import load_state_dict 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') 22 | parser.add_argument('--input', default='', type=str, metavar='PATH', 23 | help='path to base input folder containing checkpoints') 24 | parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD', 25 | help='checkpoint filter (path wildcard)') 26 | parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH', 27 | help='output filename') 28 | parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', 29 | help='Force not using ema version of weights (if present)') 30 | parser.add_argument('--descending', dest='descending', action='store_true', 31 | help='Set if eval metric is descending (like loss)') 32 | parser.add_argument('--no-sort', dest='no_sort', action='store_true', 33 | help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant') 34 | parser.add_argument('-n', type=int, default=10, metavar='N', 35 | help='Number of checkpoints to average') 36 | 37 | 38 | def checkpoint_metric(checkpoint_path): 39 | if not checkpoint_path or not os.path.isfile(checkpoint_path): 40 | return {} 41 | print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path)) 42 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 43 | metric = None 44 | if 'metric' in checkpoint: 45 | metric = checkpoint['metric'] 46 | return metric 47 | 48 | 49 | def main(): 50 | args = parser.parse_args() 51 | # by default use the EMA weights (if present) 52 | args.use_ema = not args.no_use_ema 53 | # by default sort by checkpoint metric (if present) and avg top n checkpoints 54 | args.sort = not args.no_sort 55 | 56 | if os.path.exists(args.output): 57 | print("Error: Output filename ({}) already exists.".format(args.output)) 58 | exit(1) 59 | 60 | pattern = args.input 61 | if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep): 62 | pattern += os.path.sep 63 | pattern += args.filter 64 | checkpoints = glob.glob(pattern, recursive=True) 65 | if not checkpoints: 66 | print("Error: No checkpoints to average.") 67 | exit(1) 68 | 69 | if args.sort: 70 | checkpoint_metrics = [] 71 | for c in checkpoints: 72 | metric = checkpoint_metric(c) 73 | if metric is not None: 74 | checkpoint_metrics.append((metric, c)) 75 | checkpoint_metrics = list(sorted(checkpoint_metrics, reverse=not args.descending)) 76 | checkpoint_metrics = checkpoint_metrics[:args.n] 77 | print("Selected checkpoints:") 78 | [print(m, c) for m, c in checkpoint_metrics] 79 | avg_checkpoints = [c for m, c in checkpoint_metrics] 80 | else: 81 | avg_checkpoints = checkpoints 82 | print("Selected checkpoints:") 83 | [print(c) for c in checkpoints] 84 | 85 | avg_state_dict = {} 86 | avg_counts = {} 87 | for c in avg_checkpoints: 88 | new_state_dict = load_state_dict(c, args.use_ema) 89 | if not new_state_dict: 90 | print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) 91 | continue 92 | 93 | for k, v in new_state_dict.items(): 94 | if k not in avg_state_dict: 95 | avg_state_dict[k] = v.clone().to(dtype=torch.float64) 96 | 
avg_counts[k] = 1 97 | else: 98 | avg_state_dict[k] += v.to(dtype=torch.float64) 99 | avg_counts[k] += 1 100 | 101 | for k, v in avg_state_dict.items(): 102 | v.div_(avg_counts[k]) 103 | 104 | # float32 overflow seems unlikely based on weights seen to date, but who knows 105 | float32_info = torch.finfo(torch.float32) 106 | final_state_dict = {} 107 | for k, v in avg_state_dict.items(): 108 | v = v.clamp(float32_info.min, float32_info.max) 109 | final_state_dict[k] = v.to(dtype=torch.float32) 110 | 111 | try: 112 | torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False) 113 | except: 114 | torch.save(final_state_dict, args.output) 115 | 116 | with open(args.output, 'rb') as f: 117 | sha_hash = hashlib.sha256(f.read()).hexdigest() 118 | print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /clean_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ Checkpoint Cleaning Script 3 | 4 | Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc. 5 | and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256 6 | calculation for model zoo compatibility. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 9 | """ 10 | import torch 11 | import argparse 12 | import os 13 | import hashlib 14 | import shutil 15 | from collections import OrderedDict 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') 18 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 19 | help='path to latest checkpoint (default: none)') 20 | parser.add_argument('--output', default='', type=str, metavar='PATH', 21 | help='output path') 22 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 23 | help='use ema version of weights if present') 24 | parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', 25 | help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') 26 | 27 | _TEMP_NAME = './_checkpoint.pth' 28 | 29 | 30 | def main(): 31 | args = parser.parse_args() 32 | 33 | if os.path.exists(args.output): 34 | print("Error: Output filename ({}) already exists.".format(args.output)) 35 | exit(1) 36 | 37 | # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save 38 | if args.checkpoint and os.path.isfile(args.checkpoint): 39 | print("=> Loading checkpoint '{}'".format(args.checkpoint)) 40 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 41 | 42 | new_state_dict = OrderedDict() 43 | if isinstance(checkpoint, dict): 44 | state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict' 45 | if state_dict_key in checkpoint: 46 | state_dict = checkpoint[state_dict_key] 47 | else: 48 | state_dict = checkpoint 49 | else: 50 | assert False 51 | for k, v in state_dict.items(): 52 | if args.clean_aux_bn and 'aux_bn' in k: 53 | # If all aux_bn keys are removed, the SplitBN layers will end up as normal and 54 | # load with the unmodified model using BatchNorm2d. 
55 | continue 56 | name = k[7:] if k.startswith('module') else k 57 | new_state_dict[name] = v 58 | print("=> Loaded state_dict from '{}'".format(args.checkpoint)) 59 | 60 | try: 61 | torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) 62 | except: 63 | torch.save(new_state_dict, _TEMP_NAME) 64 | 65 | with open(_TEMP_NAME, 'rb') as f: 66 | sha_hash = hashlib.sha256(f.read()).hexdigest() 67 | 68 | if args.output: 69 | checkpoint_root, checkpoint_base = os.path.split(args.output) 70 | checkpoint_base = os.path.splitext(checkpoint_base)[0] 71 | else: 72 | checkpoint_root = '' 73 | checkpoint_base = os.path.splitext(args.checkpoint)[0] 74 | final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' 75 | shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) 76 | print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) 77 | else: 78 | print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | torchrun --nproc-per-node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /effdet/__init__.py: -------------------------------------------------------------------------------- 1 | from .efficientdet import EfficientDet 2 | from .bench import DetBenchPredict, DetBenchTrain, unwrap_bench 3 | from .data import create_dataset, create_loader, create_parser, DetectionDatset, SkipSubset 4 | from .evaluator import CocoEvaluator, PascalEvaluator, OpenImagesEvaluator, create_evaluator 5 | from .config import get_efficientdet_config, default_detection_model_configs 6 | from .factory import create_model, create_model_from_config 7 | from .helpers import load_checkpoint, load_pretrained 8 | -------------------------------------------------------------------------------- /effdet/bench.py: -------------------------------------------------------------------------------- 1 | """ PyTorch EfficientDet support benches 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | from typing import Optional, Dict, List 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .anchors import Anchors, AnchorLabeler, generate_detections 11 | from .loss import DetectionLoss 12 | 13 | # try: 14 | # torch.div(torch.ones(1), torch.ones(1), rounding_mode='floor') 15 | # has_rounding_mode = True 16 | # except TypeError: 17 | # has_rounding_mode = False 18 | 19 | 20 | def _post_process( 21 | cls_outputs: List[torch.Tensor], 22 | box_outputs: List[torch.Tensor], 23 | num_levels: int, 24 | num_classes: int, 25 | max_detection_points: int = 5000, 26 | ): 27 | """Selects top-k predictions. 28 | 29 | Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet 30 | and optimized for PyTorch. 31 | 32 | Args: 33 | cls_outputs: an OrderDict with keys representing levels and values 34 | representing logits in [batch_size, height, width, num_anchors]. 35 | 36 | box_outputs: an OrderDict with keys representing levels and values 37 | representing box regression targets in [batch_size, height, width, num_anchors * 4]. 
38 | 39 | num_levels (int): number of feature levels 40 | 41 | num_classes (int): number of output classes 42 | """ 43 | batch_size = cls_outputs[0].shape[0] 44 | cls_outputs_all = torch.cat([ 45 | cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes]) 46 | for level in range(num_levels)], 1) 47 | 48 | box_outputs_all = torch.cat([ 49 | box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4]) 50 | for level in range(num_levels)], 1) 51 | 52 | _, cls_topk_indices_all = torch.topk(cls_outputs_all.reshape(batch_size, -1), dim=1, k=max_detection_points) 53 | # FIXME change someday, will have to live with annoying warnings for a while as testing impl breaks torchscript 54 | # indices_all = torch.div(cls_topk_indices_all, num_classes, rounding_mode='trunc') 55 | indices_all = cls_topk_indices_all // num_classes 56 | classes_all = cls_topk_indices_all % num_classes 57 | 58 | box_outputs_all_after_topk = torch.gather( 59 | box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4)) 60 | 61 | cls_outputs_all_after_topk = torch.gather( 62 | cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes)) 63 | cls_outputs_all_after_topk = torch.gather( 64 | cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2)) 65 | 66 | return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all 67 | 68 | 69 | @torch.jit.script 70 | def _batch_detection( 71 | batch_size: int, class_out, box_out, anchor_boxes, indices, classes, 72 | img_scale: Optional[torch.Tensor] = None, 73 | img_size: Optional[torch.Tensor] = None, 74 | max_det_per_image: int = 100, 75 | soft_nms: bool = False, 76 | ): 77 | batch_detections = [] 78 | # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome 79 | for i in range(batch_size): 80 | img_scale_i = None if img_scale is None else img_scale[i] 81 | img_size_i = None if img_size is None else img_size[i] 82 | detections = generate_detections( 83 | class_out[i], 84 | box_out[i], 85 | anchor_boxes, 86 | indices[i], 87 | classes[i], 88 | img_scale_i, 89 | img_size_i, 90 | max_det_per_image=max_det_per_image, 91 | soft_nms=soft_nms, 92 | ) 93 | batch_detections.append(detections) 94 | return torch.stack(batch_detections, dim=0) 95 | 96 | 97 | class DetBenchPredict(nn.Module): 98 | def __init__(self, model): 99 | super(DetBenchPredict, self).__init__() 100 | self.model = model 101 | self.config = model.config # FIXME remove this when we can use @property (torchscript limitation) 102 | self.num_levels = model.config.num_levels 103 | self.num_classes = model.config.num_classes 104 | self.anchors = Anchors.from_config(model.config) 105 | self.max_detection_points = model.config.max_detection_points 106 | self.max_det_per_image = model.config.max_det_per_image 107 | self.soft_nms = model.config.soft_nms 108 | 109 | def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None): 110 | class_out, box_out = self.model(x) 111 | class_out, box_out, indices, classes = _post_process( 112 | class_out, 113 | box_out, 114 | num_levels=self.num_levels, 115 | num_classes=self.num_classes, 116 | max_detection_points=self.max_detection_points, 117 | ) 118 | if img_info is None: 119 | img_scale, img_size = None, None 120 | else: 121 | img_scale, img_size = img_info['img_scale'], img_info['img_size'] 122 | return _batch_detection( 123 | x.shape[0], 124 | class_out, 125 | box_out, 126 | self.anchors.boxes, 127 | indices, 128 | classes, 129 | img_scale, 130 | img_size, 131 | 
max_det_per_image=self.max_det_per_image, 132 | soft_nms=self.soft_nms, 133 | ) 134 | 135 | 136 | class DetBenchTrain(nn.Module): 137 | def __init__(self, model, create_labeler=True): 138 | super(DetBenchTrain, self).__init__() 139 | self.model = model 140 | self.config = model.config # FIXME remove this when we can use @property (torchscript limitation) 141 | self.num_levels = model.config.num_levels 142 | self.num_classes = model.config.num_classes 143 | self.anchors = Anchors.from_config(model.config) 144 | self.max_detection_points = model.config.max_detection_points 145 | self.max_det_per_image = model.config.max_det_per_image 146 | self.soft_nms = model.config.soft_nms 147 | self.anchor_labeler = None 148 | if create_labeler: 149 | self.anchor_labeler = AnchorLabeler( 150 | self.anchors, 151 | self.num_classes, 152 | match_threshold=0.5, 153 | ) 154 | self.loss_fn = DetectionLoss(model.config) 155 | 156 | def forward(self, x, target: Dict[str, torch.Tensor]): 157 | class_out, box_out = self.model(x) 158 | if self.anchor_labeler is None: 159 | # target should contain pre-computed anchor labels if labeler not present in bench 160 | assert 'label_num_positives' in target 161 | cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)] 162 | box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)] 163 | num_positives = target['label_num_positives'] 164 | else: 165 | cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors( 166 | target['bbox'], 167 | target['cls'], 168 | ) 169 | 170 | loss, class_loss, box_loss = self.loss_fn( 171 | class_out, 172 | box_out, 173 | cls_targets, 174 | box_targets, 175 | num_positives, 176 | ) 177 | output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss} 178 | if not self.training: 179 | # if eval mode, output detections for evaluation 180 | class_out_pp, box_out_pp, indices, classes = _post_process( 181 | class_out, 182 | box_out, 183 | num_levels=self.num_levels, 184 | num_classes=self.num_classes, 185 | max_detection_points=self.max_detection_points, 186 | ) 187 | output['detections'] = _batch_detection( 188 | x.shape[0], 189 | class_out_pp, 190 | box_out_pp, 191 | self.anchors.boxes, 192 | indices, 193 | classes, 194 | target['img_scale'], 195 | target['img_size'], 196 | max_det_per_image=self.max_det_per_image, 197 | soft_nms=self.soft_nms, 198 | ) 199 | return output 200 | 201 | 202 | def unwrap_bench(model): 203 | # Unwrap a model in support bench so that various other fns can access the weights and attribs of the 204 | # underlying model directly 205 | if hasattr(model, 'module'): # unwrap DDP or EMA 206 | return unwrap_bench(model.module) 207 | elif hasattr(model, 'model'): # unwrap Bench -> model 208 | return unwrap_bench(model.model) 209 | else: 210 | return model 211 | -------------------------------------------------------------------------------- /effdet/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_utils import set_config_readonly, set_config_writeable 2 | from .fpn_config import get_fpn_config 3 | from .model_config import get_efficientdet_config, default_detection_model_configs 4 | from .train_config import default_detection_train_config 5 | -------------------------------------------------------------------------------- /effdet/config/config_utils.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | 4 | def 
set_config_readonly(conf): 5 | OmegaConf.set_readonly(conf, True) 6 | 7 | 8 | def set_config_writeable(conf): 9 | OmegaConf.set_readonly(conf, False) 10 | -------------------------------------------------------------------------------- /effdet/config/fpn_config.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | from omegaconf import OmegaConf 4 | 5 | 6 | def bifpn_config(min_level, max_level, weight_method=None): 7 | """BiFPN config. 8 | Adapted from https://github.com/google/automl/blob/56815c9986ffd4b508fe1d68508e268d129715c1/efficientdet/keras/fpn_configs.py 9 | """ 10 | p = OmegaConf.create() 11 | weight_method = weight_method or 'fastattn' 12 | 13 | num_levels = max_level - min_level + 1 14 | node_ids = {min_level + i: [i] for i in range(num_levels)} 15 | 16 | level_last_id = lambda level: node_ids[level][-1] 17 | level_all_ids = lambda level: node_ids[level] 18 | id_cnt = itertools.count(num_levels) 19 | 20 | p.nodes = [] 21 | for i in range(max_level - 1, min_level - 1, -1): 22 | # top-down path. 23 | p.nodes.append({ 24 | 'feat_level': i, 25 | 'inputs_offsets': [level_last_id(i), level_last_id(i + 1)], 26 | 'weight_method': weight_method, 27 | }) 28 | node_ids[i].append(next(id_cnt)) 29 | 30 | for i in range(min_level + 1, max_level + 1): 31 | # bottom-up path. 32 | p.nodes.append({ 33 | 'feat_level': i, 34 | 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)], 35 | 'weight_method': weight_method, 36 | }) 37 | node_ids[i].append(next(id_cnt)) 38 | return p 39 | 40 | 41 | def panfpn_config(min_level, max_level, weight_method=None): 42 | """PAN FPN config. 43 | 44 | This defines FPN layout from Path Aggregation Networks as an alternate to 45 | BiFPN, it does not implement the full PAN spec. 46 | 47 | Paper: https://arxiv.org/abs/1803.01534 48 | """ 49 | p = OmegaConf.create() 50 | weight_method = weight_method or 'fastattn' 51 | 52 | num_levels = max_level - min_level + 1 53 | node_ids = {min_level + i: [i] for i in range(num_levels)} 54 | level_last_id = lambda level: node_ids[level][-1] 55 | id_cnt = itertools.count(num_levels) 56 | 57 | p.nodes = [] 58 | for i in range(max_level, min_level - 1, -1): 59 | # top-down path. 60 | offsets = [level_last_id(i), level_last_id(i + 1)] if i != max_level else [level_last_id(i)] 61 | p.nodes.append({ 62 | 'feat_level': i, 63 | 'inputs_offsets': offsets, 64 | 'weight_method': weight_method, 65 | }) 66 | node_ids[i].append(next(id_cnt)) 67 | 68 | for i in range(min_level, max_level + 1): 69 | # bottom-up path. 70 | offsets = [level_last_id(i), level_last_id(i - 1)] if i != min_level else [level_last_id(i)] 71 | p.nodes.append({ 72 | 'feat_level': i, 73 | 'inputs_offsets': offsets, 74 | 'weight_method': weight_method, 75 | }) 76 | node_ids[i].append(next(id_cnt)) 77 | 78 | return p 79 | 80 | 81 | def qufpn_config(min_level, max_level, weight_method=None): 82 | """A dynamic quad fpn config that can adapt to different min/max levels. 83 | 84 | It extends the idea of BiFPN, and has four paths: 85 | (up_down -> bottom_up) + (bottom_up -> up_down). 
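A config using this layout can be requested through the factory at the bottom of this module, e.g. (the levels shown are just the defaults):

    fpn_config = get_fpn_config('qufpn_fa', min_level=3, max_level=7)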
86 | 87 | Paper: https://ieeexplore.ieee.org/document/9225379 88 | Ref code: From contribution to TF EfficientDet 89 | https://github.com/google/automl/blob/eb74c6739382e9444817d2ad97c4582dbe9a9020/efficientdet/keras/fpn_configs.py 90 | """ 91 | p = OmegaConf.create() 92 | weight_method = weight_method or 'fastattn' 93 | quad_method = 'fastattn' 94 | num_levels = max_level - min_level + 1 95 | node_ids = {min_level + i: [i] for i in range(num_levels)} 96 | level_last_id = lambda level: node_ids[level][-1] 97 | level_all_ids = lambda level: node_ids[level] 98 | level_first_id = lambda level: node_ids[level][0] 99 | id_cnt = itertools.count(num_levels) 100 | 101 | p.nodes = [] 102 | for i in range(max_level - 1, min_level - 1, -1): 103 | # top-down path 1. 104 | p.nodes.append({ 105 | 'feat_level': i, 106 | 'inputs_offsets': [level_last_id(i), level_last_id(i + 1)], 107 | 'weight_method': weight_method 108 | }) 109 | node_ids[i].append(next(id_cnt)) 110 | node_ids[max_level].append(node_ids[max_level][-1]) 111 | 112 | for i in range(min_level + 1, max_level): 113 | # bottom-up path 2. 114 | p.nodes.append({ 115 | 'feat_level': i, 116 | 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)], 117 | 'weight_method': weight_method 118 | }) 119 | node_ids[i].append(next(id_cnt)) 120 | 121 | i = max_level 122 | p.nodes.append({ 123 | 'feat_level': i, 124 | 'inputs_offsets': [level_first_id(i)] + [level_last_id(i - 1)], 125 | 'weight_method': weight_method 126 | }) 127 | node_ids[i].append(next(id_cnt)) 128 | node_ids[min_level].append(node_ids[min_level][-1]) 129 | 130 | for i in range(min_level + 1, max_level + 1, 1): 131 | # bottom-up path 3. 132 | p.nodes.append({ 133 | 'feat_level': i, 134 | 'inputs_offsets': [ 135 | level_first_id(i), level_last_id(i - 1) if i != min_level + 1 else level_first_id(i - 1)], 136 | 'weight_method': weight_method 137 | }) 138 | node_ids[i].append(next(id_cnt)) 139 | node_ids[min_level].append(node_ids[min_level][-1]) 140 | 141 | for i in range(max_level - 1, min_level, -1): 142 | # top-down path 4. 143 | p.nodes.append({ 144 | 'feat_level': i, 145 | 'inputs_offsets': [node_ids[i][0]] + [node_ids[i][-1]] + [level_last_id(i + 1)], 146 | 'weight_method': weight_method 147 | }) 148 | node_ids[i].append(next(id_cnt)) 149 | i = min_level 150 | p.nodes.append({ 151 | 'feat_level': i, 152 | 'inputs_offsets': [node_ids[i][0]] + [level_last_id(i + 1)], 153 | 'weight_method': weight_method 154 | }) 155 | node_ids[i].append(next(id_cnt)) 156 | node_ids[max_level].append(node_ids[max_level][-1]) 157 | 158 | # NOTE: the order of the quad path is reversed from the original, my code expects the output of 159 | # each FPN repeat to be same as input from backbone, in order of increasing reductions 160 | for i in range(min_level, max_level + 1): 161 | # quad-add path. 
162 | p.nodes.append({ 163 | 'feat_level': i, 164 | 'inputs_offsets': [node_ids[i][2], node_ids[i][4]], 165 | 'weight_method': quad_method 166 | }) 167 | node_ids[i].append(next(id_cnt)) 168 | 169 | return p 170 | 171 | 172 | def get_fpn_config(fpn_name, min_level=3, max_level=7): 173 | if not fpn_name: 174 | fpn_name = 'bifpn_fa' 175 | name_to_config = { 176 | 'bifpn_sum': bifpn_config(min_level=min_level, max_level=max_level, weight_method='sum'), 177 | 'bifpn_attn': bifpn_config(min_level=min_level, max_level=max_level, weight_method='attn'), 178 | 'bifpn_fa': bifpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'), 179 | 'pan_sum': panfpn_config(min_level=min_level, max_level=max_level, weight_method='sum'), 180 | 'pan_fa': panfpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'), 181 | 'qufpn_sum': qufpn_config(min_level=min_level, max_level=max_level, weight_method='sum'), 182 | 'qufpn_fa': qufpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'), 183 | } 184 | return name_to_config[fpn_name] 185 | -------------------------------------------------------------------------------- /effdet/config/train_config.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | 4 | def default_detection_train_config(): 5 | # FIXME currently using args for train config, will revisit, perhaps move to Hydra 6 | h = OmegaConf.create() 7 | 8 | # dataset 9 | h.skip_crowd_during_training = True 10 | 11 | # augmentation 12 | h.input_rand_hflip = True 13 | h.train_scale_min = 0.1 14 | h.train_scale_max = 2.0 15 | h.autoaugment_policy = None 16 | 17 | # optimization 18 | h.momentum = 0.9 19 | h.learning_rate = 0.08 20 | h.lr_warmup_init = 0.008 21 | h.lr_warmup_epoch = 1.0 22 | h.first_lr_drop_epoch = 200.0 23 | h.second_lr_drop_epoch = 250.0 24 | h.clip_gradients_norm = 10.0 25 | h.num_epochs = 300 26 | 27 | # regularization l2 loss. 28 | h.weight_decay = 4e-5 29 | 30 | h.lr_decay_method = 'cosine' 31 | h.moving_average_decay = 0.9998 32 | h.ckpt_var_scope = None 33 | 34 | return h 35 | -------------------------------------------------------------------------------- /effdet/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_factory import create_dataset 2 | from .dataset import DetectionDatset, SkipSubset 3 | from .input_config import resolve_input_config 4 | from .loader import create_loader 5 | from .parsers import create_parser 6 | from .transforms import * 7 | -------------------------------------------------------------------------------- /effdet/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ Detection dataset 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import torch.utils.data as data 6 | import numpy as np 7 | 8 | from PIL import Image 9 | from .parsers import create_parser 10 | 11 | 12 | class DetectionDatset(data.Dataset): 13 | """`Object Detection Dataset. Use with parsers for COCO, VOC, and OpenImages. 14 | Args: 15 | parser (string, Parser): 16 | transform (callable, optional): A function/transform that takes in an PIL image 17 | and returns a transformed version. 
E.g, ``transforms.ToTensor`` 18 | 19 | """ 20 | 21 | def __init__(self, data_dir, parser=None, parser_kwargs=None, transform=None): 22 | super(DetectionDatset, self).__init__() 23 | parser_kwargs = parser_kwargs or {} 24 | self.data_dir = data_dir 25 | if isinstance(parser, str): 26 | self._parser = create_parser(parser, **parser_kwargs) 27 | else: 28 | assert parser is not None and len(parser.img_ids) 29 | self._parser = parser 30 | self._transform = transform 31 | 32 | def __getitem__(self, index): 33 | """ 34 | Args: 35 | index (int): Index 36 | Returns: 37 | tuple: Tuple (image, annotations (target)). 38 | """ 39 | img_info = self._parser.img_infos[index] 40 | target = dict(img_idx=index, img_size=(img_info['width'], img_info['height'])) 41 | if self._parser.has_labels: 42 | ann = self._parser.get_ann_info(index) 43 | target.update(ann) 44 | 45 | img_path = self.data_dir / img_info['file_name'] 46 | img = Image.open(img_path).convert('RGB') 47 | if self.transform is not None: 48 | img, target = self.transform(img, target) 49 | 50 | return img, target 51 | 52 | def __len__(self): 53 | return len(self._parser.img_ids) 54 | 55 | @property 56 | def parser(self): 57 | return self._parser 58 | 59 | @property 60 | def transform(self): 61 | return self._transform 62 | 63 | @transform.setter 64 | def transform(self, t): 65 | self._transform = t 66 | 67 | 68 | class SkipSubset(data.Dataset): 69 | r""" 70 | Subset of a dataset at specified indices. 71 | 72 | Arguments: 73 | dataset (Dataset): The whole Dataset 74 | n (int): skip rate (select every nth) 75 | """ 76 | def __init__(self, dataset, n=2): 77 | self.dataset = dataset 78 | assert n >= 1 79 | self.indices = np.arange(len(dataset))[::n] 80 | 81 | def __getitem__(self, idx): 82 | return self.dataset[self.indices[idx]] 83 | 84 | def __len__(self): 85 | return len(self.indices) 86 | 87 | @property 88 | def parser(self): 89 | return self.dataset.parser 90 | 91 | @property 92 | def transform(self): 93 | return self.dataset.transform 94 | 95 | @transform.setter 96 | def transform(self, t): 97 | self.dataset.transform = t 98 | -------------------------------------------------------------------------------- /effdet/data/dataset_config.py: -------------------------------------------------------------------------------- 1 | """ COCO, VOC, OpenImages dataset configurations 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | from dataclasses import dataclass, field 7 | from typing import Dict 8 | 9 | 10 | @dataclass 11 | class CocoCfg: 12 | variant: str = None 13 | parser: str = 'coco' 14 | num_classes: int = 80 15 | splits: Dict[str, dict] = None 16 | 17 | 18 | @dataclass 19 | class Coco2017Cfg(CocoCfg): 20 | variant: str = '2017' 21 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 22 | train=dict(ann_filename='annotations/instances_train2017.json', img_dir='train2017', has_labels=True), 23 | val=dict(ann_filename='annotations/instances_val2017.json', img_dir='val2017', has_labels=True), 24 | test=dict(ann_filename='annotations/image_info_test2017.json', img_dir='test2017', has_labels=False), 25 | testdev=dict(ann_filename='annotations/image_info_test-dev2017.json', img_dir='test2017', has_labels=False), 26 | )) 27 | 28 | 29 | @dataclass 30 | class Coco2014Cfg(CocoCfg): 31 | variant: str = '2014' 32 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 33 | train=dict(ann_filename='annotations/instances_train2014.json', img_dir='train2014', has_labels=True), 34 | 
val=dict(ann_filename='annotations/instances_val2014.json', img_dir='val2014', has_labels=True), 35 | test=dict(ann_filename='', img_dir='test2014', has_labels=False), 36 | )) 37 | 38 | 39 | @dataclass 40 | class VocCfg: 41 | variant: str = None 42 | parser: str = 'voc' 43 | num_classes: int = 80 44 | img_filename: str = '%s.jpg' 45 | splits: Dict[str, dict] = None 46 | 47 | 48 | @dataclass 49 | class Voc2007Cfg(VocCfg): 50 | variant: str = '2007' 51 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 52 | train=dict( 53 | split_filename='VOC2007/ImageSets/Main/train.txt', 54 | ann_filename='VOC2007/Annotations/%s.xml', 55 | img_dir='VOC2007/JPEGImages', ), 56 | val=dict( 57 | split_filename='VOC2007/ImageSets/Main/val.txt', 58 | ann_filename='VOC2007/Annotations/%s.xml', 59 | img_dir='VOC2007/JPEGImages'), 60 | #test=dict(img_dir='JPEGImages') 61 | )) 62 | 63 | 64 | @dataclass 65 | class Voc2012Cfg(VocCfg): 66 | variant: str = '2012' 67 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 68 | train=dict( 69 | split_filename='VOC2012/ImageSets/Main/train.txt', 70 | ann_filename='VOC2012/Annotations/%s.xml', 71 | img_dir='VOC2012/JPEGImages'), 72 | val=dict( 73 | split_filename='VOC2012/ImageSets/Main/val.txt', 74 | ann_filename='VOC2012/Annotations/%s.xml', 75 | img_dir='VOC2012/JPEGImages'), 76 | #test=dict(img_dir='JPEGImages', split_file=None) 77 | )) 78 | 79 | 80 | @dataclass 81 | class Voc0712Cfg(VocCfg): 82 | variant: str = '0712' 83 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 84 | train=dict( 85 | split_filename=['VOC2007/ImageSets/Main/trainval.txt', 'VOC2012/ImageSets/Main/trainval.txt'], 86 | ann_filename=['VOC2007/Annotations/%s.xml', 'VOC2012/Annotations/%s.xml'], 87 | img_dir=['VOC2007/JPEGImages', 'VOC2012/JPEGImages']), 88 | val=dict( 89 | split_filename='VOC2007/ImageSets/Main/test.txt', 90 | ann_filename='VOC2007/Annotations/%s.xml', 91 | img_dir='VOC2007/JPEGImages'), 92 | #test=dict(img_dir='JPEGImages', split_file=None) 93 | )) 94 | 95 | 96 | 97 | @dataclass 98 | class OpenImagesCfg: 99 | variant: str = None 100 | parser: str = 'openimages' 101 | num_classes: int = None 102 | img_filename = '%s.jpg' 103 | splits: Dict[str, dict] = None 104 | 105 | 106 | @dataclass 107 | class OpenImagesObjCfg(OpenImagesCfg): 108 | num_classes: int = 601 109 | categories_map: str = 'annotations/class-descriptions-boxable.csv' 110 | 111 | 112 | @dataclass 113 | class OpenImagesSegCfg(OpenImagesCfg): 114 | num_classes: int = 350 115 | categories_map: str = 'annotations/classes-segmentation.txt' 116 | 117 | 118 | @dataclass 119 | class OpenImagesObjV5Cfg(OpenImagesObjCfg): 120 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 121 | train=dict( 122 | img_dir='train', img_info='annotations/train-info.csv', has_labels=True, prefix_levels=1, 123 | ann_bbox='annotations/train-annotations-bbox.csv', 124 | ann_img_label='annotations/train-annotations-human-imagelabels-boxable.csv', 125 | ), 126 | val=dict( 127 | img_dir='validation', img_info='annotations/validation-info.csv', has_labels=True, prefix_levels=0, 128 | ann_bbox='annotations/validation-annotations-bbox.csv', 129 | ann_img_label='annotations/validation-annotations-human-imagelabels-boxable.csv', 130 | ), 131 | test=dict( 132 | img_dir='test', img_info='', has_labels=True, prefix_levels=0, 133 | ann_bbox='annotations/test-annotations-bbox.csv', 134 | ann_img_label='annotations/test-annotations-human-imagelabels-boxable.csv', 135 | ) 136 | )) 137 | 138 | 139 | @dataclass 140 | 
class OpenImagesObjChallenge2019Cfg(OpenImagesObjCfg): 141 | num_classes: int = 500 142 | categories_map: str = 'annotations/challenge-2019/challenge-2019-classes-description-500.csv' 143 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 144 | train=dict( 145 | img_dir='train', img_info='annotations/train-info.csv', has_labels=True, prefix_levels=1, 146 | ann_bbox='annotations/challenge-2019/challenge-2019-train-detection-bbox.csv', 147 | ann_img_label='annotations/challenge-2019/challenge-2019-train-detection-human-imagelabels.csv', 148 | ), 149 | val=dict( 150 | img_dir='validation', img_info='annotations/validation-info.csv', has_labels=True, prefix_levels=0, 151 | ann_bbox='annotations/challenge-2019/challenge-2019-validation-detection-bbox.csv', 152 | ann_img_label='annotations/challenge-2019/challenge-2019-validation-detection-human-imagelabels.csv', 153 | ), 154 | test=dict( 155 | img_dir='challenge2019', img_info='annotations/challenge-2019/challenge2019-info', prefix_levels=0, 156 | has_labels=False, ann_bbox='', ann_img_label='', 157 | ) 158 | )) 159 | 160 | 161 | @dataclass 162 | class OpenImagesSegV5Cfg(OpenImagesSegCfg): 163 | num_classes: int = 300 164 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 165 | train=dict(), 166 | val=dict(), 167 | test=dict() 168 | )) 169 | 170 | 171 | @dataclass 172 | class OpenImagesSegChallenge2019Cfg(OpenImagesSegCfg): 173 | num_classes: int = 300 174 | ann_class_map: str = 'annotations/challenge-2019/challenge-2019-classes-description-segmentable.csv' 175 | splits: Dict[str, dict] = field(default_factory=lambda: dict( 176 | train=dict(), 177 | val=dict(), 178 | test=dict() 179 | )) -------------------------------------------------------------------------------- /effdet/data/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """ Dataset factory 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | from collections import OrderedDict 7 | from pathlib import Path 8 | 9 | from .dataset_config import * 10 | from .parsers import * 11 | from .dataset import DetectionDatset 12 | from .parsers import create_parser 13 | 14 | 15 | def create_dataset(name, root, splits=('train', 'val')): 16 | if isinstance(splits, str): 17 | splits = (splits,) 18 | name = name.lower() 19 | root = Path(root) 20 | dataset_cls = DetectionDatset 21 | datasets = OrderedDict() 22 | if name.startswith('coco'): 23 | if 'coco2014' in name: 24 | dataset_cfg = Coco2014Cfg() 25 | else: 26 | dataset_cfg = Coco2017Cfg() 27 | for s in splits: 28 | if s not in dataset_cfg.splits: 29 | raise RuntimeError(f'{s} split not found in config') 30 | split_cfg = dataset_cfg.splits[s] 31 | ann_file = root / split_cfg['ann_filename'] 32 | parser_cfg = CocoParserCfg( 33 | ann_filename=ann_file, 34 | has_labels=split_cfg['has_labels'] 35 | ) 36 | datasets[s] = dataset_cls( 37 | data_dir=root / Path(split_cfg['img_dir']), 38 | parser=create_parser(dataset_cfg.parser, cfg=parser_cfg), 39 | ) 40 | elif name.startswith('voc'): 41 | if 'voc0712' in name: 42 | dataset_cfg = Voc0712Cfg() 43 | elif 'voc2007' in name: 44 | dataset_cfg = Voc2007Cfg() 45 | else: 46 | dataset_cfg = Voc2012Cfg() 47 | for s in splits: 48 | if s not in dataset_cfg.splits: 49 | raise RuntimeError(f'{s} split not found in config') 50 | split_cfg = dataset_cfg.splits[s] 51 | if isinstance(split_cfg['split_filename'], (tuple, list)): 52 | assert len(split_cfg['split_filename']) == len(split_cfg['ann_filename']) 53 | parser = None 54 | for sf, 
af, id in zip( 55 | split_cfg['split_filename'], split_cfg['ann_filename'], split_cfg['img_dir']): 56 | parser_cfg = VocParserCfg( 57 | split_filename=root / sf, 58 | ann_filename=os.path.join(root, af), 59 | img_filename=os.path.join(id, dataset_cfg.img_filename)) 60 | if parser is None: 61 | parser = create_parser(dataset_cfg.parser, cfg=parser_cfg) 62 | else: 63 | other_parser = create_parser(dataset_cfg.parser, cfg=parser_cfg) 64 | parser.merge(other=other_parser) 65 | else: 66 | parser_cfg = VocParserCfg( 67 | split_filename=root / split_cfg['split_filename'], 68 | ann_filename=os.path.join(root, split_cfg['ann_filename']), 69 | img_filename=os.path.join(split_cfg['img_dir'], dataset_cfg.img_filename), 70 | ) 71 | parser = create_parser(dataset_cfg.parser, cfg=parser_cfg) 72 | datasets[s] = dataset_cls(data_dir=root, parser=parser) 73 | elif name.startswith('openimages'): 74 | if 'challenge2019' in name: 75 | dataset_cfg = OpenImagesObjChallenge2019Cfg() 76 | else: 77 | dataset_cfg = OpenImagesObjV5Cfg() 78 | for s in splits: 79 | if s not in dataset_cfg.splits: 80 | raise RuntimeError(f'{s} split not found in config') 81 | split_cfg = dataset_cfg.splits[s] 82 | parser_cfg = OpenImagesParserCfg( 83 | categories_filename=root / dataset_cfg.categories_map, 84 | img_info_filename=root / split_cfg['img_info'], 85 | bbox_filename=root / split_cfg['ann_bbox'], 86 | img_label_filename=root / split_cfg['ann_img_label'], 87 | img_filename=dataset_cfg.img_filename, 88 | prefix_levels=split_cfg['prefix_levels'], 89 | has_labels=split_cfg['has_labels'], 90 | ) 91 | datasets[s] = dataset_cls( 92 | data_dir=root / Path(split_cfg['img_dir']), 93 | parser=create_parser(dataset_cfg.parser, cfg=parser_cfg) 94 | ) 95 | else: 96 | assert False, f'Unknown dataset parser ({name})' 97 | 98 | datasets = list(datasets.values()) 99 | return datasets if len(datasets) > 1 else datasets[0] 100 | -------------------------------------------------------------------------------- /effdet/data/input_config.py: -------------------------------------------------------------------------------- 1 | from .transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 2 | 3 | 4 | def resolve_input_config(args, model_config=None, model=None): 5 | if not isinstance(args, dict): 6 | args = vars(args) 7 | input_config = {} 8 | if not model_config and model is not None and hasattr(model, 'config'): 9 | model_config = model.config 10 | 11 | # Resolve input/image size 12 | in_chans = 3 13 | # if 'chans' in args and args['chans'] is not None: 14 | # in_chans = args['chans'] 15 | 16 | input_size = (in_chans, 512, 512) 17 | # if 'input_size' in args and args['input_size'] is not None: 18 | # assert isinstance(args['input_size'], (tuple, list)) 19 | # assert len(args['input_size']) == 3 20 | # input_size = tuple(args['input_size']) 21 | # in_chans = input_size[0] # input_size overrides in_chans 22 | # elif 'img_size' in args and args['img_size'] is not None: 23 | # assert isinstance(args['img_size'], int) 24 | # input_size = (in_chans, args['img_size'], args['img_size']) 25 | if 'input_size' in model_config: 26 | input_size = tuple(model_config['input_size']) 27 | elif 'image_size' in model_config: 28 | input_size = (in_chans,) + tuple(model_config['image_size']) 29 | assert isinstance(input_size, tuple) and len(input_size) == 3 30 | input_config['input_size'] = input_size 31 | 32 | # resolve interpolation method 33 | input_config['interpolation'] = 'bicubic' 34 | if 'interpolation' in args and args['interpolation']: 35 | 
input_config['interpolation'] = args['interpolation'] 36 | elif 'interpolation' in model_config: 37 | input_config['interpolation'] = model_config['interpolation'] 38 | 39 | # resolve dataset + model mean for normalization 40 | input_config['mean'] = IMAGENET_DEFAULT_MEAN 41 | if 'mean' in args and args['mean'] is not None: 42 | mean = tuple(args['mean']) 43 | if len(mean) == 1: 44 | mean = tuple(list(mean) * in_chans) 45 | else: 46 | assert len(mean) == in_chans 47 | input_config['mean'] = mean 48 | elif 'mean' in model_config: 49 | input_config['mean'] = model_config['mean'] 50 | 51 | # resolve dataset + model std deviation for normalization 52 | input_config['std'] = IMAGENET_DEFAULT_STD 53 | if 'std' in args and args['std'] is not None: 54 | std = tuple(args['std']) 55 | if len(std) == 1: 56 | std = tuple(list(std) * in_chans) 57 | else: 58 | assert len(std) == in_chans 59 | input_config['std'] = std 60 | elif 'std' in model_config: 61 | input_config['std'] = model_config['std'] 62 | 63 | # resolve letterbox fill color 64 | input_config['fill_color'] = 'mean' 65 | if 'fill_color' in args and args['fill_color'] is not None: 66 | input_config['fill_color'] = args['fill_color'] 67 | elif 'fill_color' in model_config: 68 | input_config['fill_color'] = model_config['fill_color'] 69 | 70 | return input_config 71 | -------------------------------------------------------------------------------- /effdet/data/loader.py: -------------------------------------------------------------------------------- 1 | """ Object detection loader/collate 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch.utils.data 6 | from .transforms import * 7 | from .random_erasing import RandomErasing 8 | from effdet.anchors import AnchorLabeler 9 | from timm.data.distributed_sampler import OrderedDistributedSampler 10 | import os 11 | 12 | MAX_NUM_INSTANCES = 100 13 | 14 | 15 | class DetectionFastCollate: 16 | """ A detection specific, optimized collate function w/ a bit of state. 17 | 18 | Optionally performs anchor labelling. Doing this here offloads some work from the 19 | GPU and the main training process thread and increases the load on the dataloader 20 | threads. 
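    Example (editorial sketch, not part of the original docstring; assumes a dataset that
    yields (uint8 CHW numpy image, annotation dict) pairs, as produced by the transforms
    in this package):

        collate = DetectionFastCollate(anchor_labeler=None)
        loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate)
        img_tensor, target = next(iter(loader))
        # img_tensor: uint8 tensor of shape (4, C, H, W); target: dict of per-instance
        # tensors padded/filled to max_instances, plus image-level metadata tensors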
21 | 22 | """ 23 | def __init__( 24 | self, 25 | instance_keys=None, 26 | instance_shapes=None, 27 | instance_fill=-1, 28 | max_instances=MAX_NUM_INSTANCES, 29 | anchor_labeler=None, 30 | ): 31 | instance_keys = instance_keys or {'bbox', 'bbox_ignore', 'cls'} 32 | instance_shapes = instance_shapes or dict( 33 | bbox=(max_instances, 4), bbox_ignore=(max_instances, 4), cls=(max_instances,)) 34 | self.instance_info = {k: dict(fill=instance_fill, shape=instance_shapes[k]) for k in instance_keys} 35 | self.max_instances = max_instances 36 | self.anchor_labeler = anchor_labeler 37 | 38 | def __call__(self, batch): 39 | batch_size = len(batch) 40 | target = dict() 41 | labeler_outputs = dict() 42 | img_tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 43 | for i in range(batch_size): 44 | img_tensor[i] += torch.from_numpy(batch[i][0]) 45 | labeler_inputs = {} 46 | for tk, tv in batch[i][1].items(): 47 | instance_info = self.instance_info.get(tk, None) 48 | if instance_info is not None: 49 | # target tensor is associated with a detection instance 50 | tv = torch.from_numpy(tv).to(dtype=torch.float32) 51 | if self.anchor_labeler is None: 52 | if i == 0: 53 | shape = (batch_size,) + instance_info['shape'] 54 | target_tensor = torch.full(shape, instance_info['fill'], dtype=torch.float32) 55 | target[tk] = target_tensor 56 | else: 57 | target_tensor = target[tk] 58 | num_elem = min(tv.shape[0], self.max_instances) 59 | target_tensor[i, 0:num_elem] = tv[0:num_elem] 60 | else: 61 | # no need to pass gt tensors through when labeler in use 62 | if tk in ('bbox', 'cls'): 63 | labeler_inputs[tk] = tv 64 | else: 65 | # target tensor is an image-level annotation / metadata 66 | if i == 0: 67 | # first batch elem, create destination tensors 68 | if isinstance(tv, (tuple, list)): 69 | # per batch elem sequence 70 | shape = (batch_size, len(tv)) 71 | dtype = torch.float32 if isinstance(tv[0], (float, np.floating)) else torch.int32 72 | else: 73 | # per batch elem scalar 74 | shape = batch_size, 75 | dtype = torch.float32 if isinstance(tv, (float, np.floating)) else torch.int64 76 | target_tensor = torch.zeros(shape, dtype=dtype) 77 | target[tk] = target_tensor 78 | else: 79 | target_tensor = target[tk] 80 | target_tensor[i] = torch.tensor(tv, dtype=target_tensor.dtype) 81 | 82 | if self.anchor_labeler is not None: 83 | cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors( 84 | labeler_inputs['bbox'], labeler_inputs['cls'], filter_valid=False) 85 | if i == 0: 86 | # first batch elem, create destination tensors, separate key per level 87 | for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)): 88 | labeler_outputs[f'label_cls_{j}'] = torch.zeros( 89 | (batch_size,) + ct.shape, dtype=torch.int64) 90 | labeler_outputs[f'label_bbox_{j}'] = torch.zeros( 91 | (batch_size,) + bt.shape, dtype=torch.float32) 92 | labeler_outputs['label_num_positives'] = torch.zeros(batch_size) 93 | for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)): 94 | labeler_outputs[f'label_cls_{j}'][i] = ct 95 | labeler_outputs[f'label_bbox_{j}'][i] = bt 96 | labeler_outputs['label_num_positives'][i] = num_positives 97 | if labeler_outputs: 98 | target.update(labeler_outputs) 99 | 100 | return img_tensor, target 101 | 102 | 103 | class PrefetchLoader: 104 | 105 | def __init__(self, 106 | loader, 107 | mean=IMAGENET_DEFAULT_MEAN, 108 | std=IMAGENET_DEFAULT_STD, 109 | re_prob=0., 110 | re_mode='pixel', 111 | re_count=1, 112 | device="cuda" 113 | ): 114 | self.loader = loader 115 | self.mean 
= torch.tensor([x * 255 for x in mean]).to(device).view(1, 3, 1, 1) 116 | self.std = torch.tensor([x * 255 for x in std]).to(device).view(1, 3, 1, 1) 117 | if re_prob > 0.: 118 | self.random_erasing = RandomErasing(probability=re_prob, mode=re_mode, max_count=re_count) 119 | else: 120 | self.random_erasing = None 121 | self.device = device 122 | 123 | def __iter__(self): 124 | first = True 125 | if self.device == "cuda": 126 | stream = torch.cuda.Stream() 127 | 128 | for next_input, next_target in self.loader: 129 | if self.device == "cuda": 130 | with torch.cuda.stream(stream): 131 | next_input = next_input.cuda(non_blocking=True) 132 | next_input = next_input.float().sub_(self.mean).div_(self.std) 133 | next_target = {k: v.cuda(non_blocking=True) for k, v in next_target.items()} 134 | if self.random_erasing is not None: 135 | next_input = self.random_erasing(next_input, next_target) 136 | if not first: 137 | yield input, target 138 | else: 139 | first = False 140 | torch.cuda.current_stream().wait_stream(stream) 141 | else: 142 | next_input = next_input.to(self.device) 143 | next_input = next_input.float().sub_(self.mean).div_(self.std) 144 | next_target = {k: v.to(self.device) for k, v in next_target.items()} 145 | if self.random_erasing is not None: 146 | next_input = self.random_erasing(next_input, next_target) 147 | 148 | if not first: 149 | yield input, target 150 | else: 151 | first = False 152 | 153 | input = next_input 154 | target = next_target 155 | 156 | yield input, target 157 | 158 | def __len__(self): 159 | return len(self.loader) 160 | 161 | @property 162 | def sampler(self): 163 | return self.loader.sampler 164 | 165 | @property 166 | def dataset(self): 167 | return self.loader.dataset 168 | 169 | 170 | def create_loader( 171 | dataset, 172 | input_size, 173 | batch_size, 174 | is_training=False, 175 | use_prefetcher=True, 176 | re_prob=0., 177 | re_mode='pixel', 178 | re_count=1, 179 | interpolation='bilinear', 180 | fill_color='mean', 181 | mean=IMAGENET_DEFAULT_MEAN, 182 | std=IMAGENET_DEFAULT_STD, 183 | num_workers=1, 184 | distributed=False, 185 | pin_mem=False, 186 | anchor_labeler=None, 187 | transform_fn=None, 188 | collate_fn=None, 189 | device="cuda" 190 | ): 191 | if isinstance(input_size, tuple): 192 | img_size = input_size[-2:] 193 | else: 194 | img_size = input_size 195 | 196 | if transform_fn is not None: 197 | # transform_fn should accept inputs (img, annotations) from the dataset and return a tuple 198 | # of img, annotations for the data loader collate function. 199 | # The valid types of img and annotations depend on the dataset and collate abstractions used. 200 | # The default dataset outputs PIL Image and dict of numpy ndarrays or python scalar annotations. 
201 | # The fast collate fn accepts ONLY numpy uint8 images and annotations dicts of ndarrays and python scalars 202 | transform = transform_fn 203 | else: 204 | if is_training: 205 | transform = transforms_coco_train( 206 | img_size, 207 | interpolation=interpolation, 208 | use_prefetcher=use_prefetcher, 209 | fill_color=fill_color, 210 | mean=mean, 211 | std=std) 212 | else: 213 | transform = transforms_coco_eval( 214 | img_size, 215 | interpolation=interpolation, 216 | use_prefetcher=use_prefetcher, 217 | fill_color=fill_color, 218 | mean=mean, 219 | std=std) 220 | dataset.transform = transform 221 | 222 | sampler = None 223 | if distributed: 224 | if is_training: 225 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 226 | else: 227 | # This will add extra duplicate entries to result in equal num 228 | # of samples per-process, will slightly alter validation results 229 | sampler = OrderedDistributedSampler(dataset) 230 | 231 | collate_fn = collate_fn or DetectionFastCollate(anchor_labeler=anchor_labeler) 232 | loader = torch.utils.data.DataLoader( 233 | dataset, 234 | batch_size=batch_size, 235 | shuffle=sampler is None and is_training, 236 | num_workers=num_workers, 237 | sampler=sampler, 238 | pin_memory=pin_mem, 239 | collate_fn=collate_fn, 240 | ) 241 | if use_prefetcher: 242 | if is_training: 243 | loader = PrefetchLoader(loader, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count, device=device) 244 | else: 245 | loader = PrefetchLoader(loader, mean=mean, std=std, device=device) 246 | 247 | return loader 248 | -------------------------------------------------------------------------------- /effdet/data/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser_config import OpenImagesParserCfg, CocoParserCfg, VocParserCfg 2 | from .parser_factory import create_parser 3 | -------------------------------------------------------------------------------- /effdet/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from numbers import Integral 2 | from typing import List, Union, Dict, Any 3 | 4 | 5 | class Parser: 6 | """ Parser base class. 7 | 8 | The attributes listed below make up a public interface common to all parsers. They can be accessed directly 9 | once the dataset is constructed and annotations are populated. 10 | 11 | Attributes: 12 | 13 | cat_names (list[str]): 14 | list of category (class) names, with background class at position 0. 
15 | cat_ids (list[union[str, int]): 16 | list of dataset specific, unique integer or string category ids, does not include background 17 | cat_id_to_label (dict): 18 | map from category id to integer 1-indexed class label 19 | 20 | img_ids (list): 21 | list of dataset specific, unique image ids corresponding to valid samples in dataset 22 | img_ids_invalid (list): 23 | list of image ids corresponding to invalid images, not used as samples 24 | img_infos (list[dict]): 25 | image info, list of info dicts with filename, width, height for each image sample 26 | """ 27 | def __init__( 28 | self, 29 | bbox_yxyx: bool = False, 30 | has_labels: bool = True, 31 | include_masks: bool = False, 32 | include_bboxes_ignore: bool = False, 33 | ignore_empty_gt: bool = False, 34 | min_img_size: int = 32, 35 | ): 36 | """ 37 | Args: 38 | yxyx (bool): output coords in yxyx format, otherwise xyxy 39 | has_labels (bool): dataset has labels (for training validation, False usually for test sets) 40 | include_masks (bool): include segmentation masks in target output (not supported yet for any dataset) 41 | include_bboxes_ignore (bool): include ignored bbox in target output 42 | ignore_empty_gt (bool): ignore images with no ground truth (no negative images) 43 | min_img_size (bool): ignore images with width or height smaller than this number 44 | sub_sample (int): sample every N images from the dataset 45 | """ 46 | # parser config, determines how dataset parsed and validated 47 | self.yxyx = bbox_yxyx 48 | self.has_labels = has_labels 49 | self.include_masks = include_masks 50 | self.include_bboxes_ignore = include_bboxes_ignore 51 | self.ignore_empty_gt = ignore_empty_gt 52 | self.min_img_size = min_img_size 53 | self.label_offset = 1 54 | 55 | # Category (class) metadata. Populated by _load_annotations() 56 | self.cat_names: List[str] = [] 57 | self.cat_ids: List[Union[str, Integral]] = [] 58 | self.cat_id_to_label: Dict[Union[str, Integral], Integral] = dict() 59 | 60 | # Image metadata. 
Populated by _load_annotations() 61 | self.img_ids: List[Union[str, Integral]] = [] 62 | self.img_ids_invalid: List[Union[str, Integral]] = [] 63 | self.img_infos: List[Dict[str, Any]] = [] 64 | 65 | @property 66 | def cat_dicts(self): 67 | """return category names and labels in format compatible with TF Models Evaluator 68 | list[dict(name=, id=)] 69 | """ 70 | return [ 71 | dict( 72 | name=name, 73 | id=cat_id if not self.cat_id_to_label else self.cat_id_to_label[cat_id] 74 | ) for name, cat_id in zip(self.cat_names, self.cat_ids)] 75 | 76 | @property 77 | def max_label(self): 78 | if self.cat_id_to_label: 79 | return max(self.cat_id_to_label.values()) 80 | else: 81 | assert len(self.cat_ids) and isinstance(self.cat_ids[0], Integral) 82 | return max(self.cat_ids) 83 | -------------------------------------------------------------------------------- /effdet/data/parsers/parser_coco.py: -------------------------------------------------------------------------------- 1 | """ COCO dataset parser 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import numpy as np 6 | from pycocotools.coco import COCO 7 | from .parser import Parser 8 | from .parser_config import CocoParserCfg 9 | 10 | 11 | class CocoParser(Parser): 12 | 13 | def __init__(self, cfg: CocoParserCfg): 14 | super().__init__( 15 | bbox_yxyx=cfg.bbox_yxyx, 16 | has_labels=cfg.has_labels, 17 | include_masks=cfg.include_masks, 18 | include_bboxes_ignore=cfg.include_bboxes_ignore, 19 | ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt, 20 | min_img_size=cfg.min_img_size 21 | ) 22 | self.cat_ids_as_labels = True # this is the default for original TF EfficientDet models 23 | self.coco = None 24 | self._load_annotations(cfg.ann_filename) 25 | 26 | def get_ann_info(self, idx): 27 | img_id = self.img_ids[idx] 28 | return self._parse_img_ann(img_id) 29 | 30 | def _load_annotations(self, ann_file): 31 | assert self.coco is None 32 | self.coco = COCO(ann_file) 33 | self.cat_ids = self.coco.getCatIds() 34 | self.cat_names = [c['name'] for c in self.coco.loadCats(ids=self.cat_ids)] 35 | if not self.cat_ids_as_labels: 36 | self.cat_id_to_label = {cat_id: i + self.label_offset for i, cat_id in enumerate(self.cat_ids)} 37 | img_ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) 38 | for img_id in sorted(self.coco.imgs.keys()): 39 | info = self.coco.loadImgs([img_id])[0] 40 | if (min(info['width'], info['height']) < self.min_img_size or 41 | (self.ignore_empty_gt and img_id not in img_ids_with_ann)): 42 | self.img_ids_invalid.append(img_id) 43 | continue 44 | self.img_ids.append(img_id) 45 | self.img_infos.append(info) 46 | 47 | def _parse_img_ann(self, img_id): 48 | ann_ids = self.coco.getAnnIds(imgIds=[img_id]) 49 | ann_info = self.coco.loadAnns(ann_ids) 50 | bboxes = [] 51 | bboxes_ignore = [] 52 | cls = [] 53 | 54 | for i, ann in enumerate(ann_info): 55 | if ann.get('ignore', False): 56 | continue 57 | x1, y1, w, h = ann['bbox'] 58 | if self.include_masks and ann['area'] <= 0: 59 | continue 60 | if w < 1 or h < 1: 61 | continue 62 | 63 | if self.yxyx: 64 | bbox = [y1, x1, y1 + h, x1 + w] 65 | else: 66 | bbox = [x1, y1, x1 + w, y1 + h] 67 | 68 | if ann.get('iscrowd', False): 69 | if self.include_bboxes_ignore: 70 | bboxes_ignore.append(bbox) 71 | else: 72 | bboxes.append(bbox) 73 | cls.append(self.cat_id_to_label[ann['category_id']] if self.cat_id_to_label else ann['category_id']) 74 | 75 | if bboxes: 76 | bboxes = np.array(bboxes, ndmin=2, dtype=np.float32) 77 | cls = np.array(cls, dtype=np.int64) 78 | else: 79 | bboxes = 
np.zeros((0, 4), dtype=np.float32) 80 | cls = np.array([], dtype=np.int64) 81 | 82 | if self.include_bboxes_ignore: 83 | if bboxes_ignore: 84 | bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32) 85 | else: 86 | bboxes_ignore = np.zeros((0, 4), dtype=np.float32) 87 | 88 | ann = dict(bbox=bboxes, cls=cls) 89 | 90 | if self.include_bboxes_ignore: 91 | ann['bbox_ignore'] = bboxes_ignore 92 | 93 | return ann 94 | -------------------------------------------------------------------------------- /effdet/data/parsers/parser_config.py: -------------------------------------------------------------------------------- 1 | """ Dataset parser configs 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | from dataclasses import dataclass 6 | 7 | __all__ = ['CocoParserCfg', 'OpenImagesParserCfg', 'VocParserCfg'] 8 | 9 | 10 | @dataclass 11 | class CocoParserCfg: 12 | ann_filename: str # absolute path 13 | include_masks: bool = False 14 | include_bboxes_ignore: bool = False 15 | has_labels: bool = True 16 | bbox_yxyx: bool = True 17 | min_img_size: int = 32 18 | ignore_empty_gt: bool = False 19 | 20 | 21 | @dataclass 22 | class VocParserCfg: 23 | split_filename: str 24 | ann_filename: str 25 | img_filename: str = '%.jpg' 26 | keep_difficult: bool = True 27 | classes: list = None 28 | add_background: bool = True 29 | has_labels: bool = True 30 | bbox_yxyx: bool = True 31 | min_img_size: int = 32 32 | ignore_empty_gt: bool = False 33 | 34 | 35 | @dataclass 36 | class OpenImagesParserCfg: 37 | categories_filename: str 38 | img_info_filename: str 39 | bbox_filename: str 40 | img_label_filename: str = '' 41 | masks_filename: str = '' 42 | img_filename: str = '%s.jpg' # relative to dataset img_dir 43 | task: str = 'obj' 44 | prefix_levels: int = 1 45 | add_background: bool = True 46 | has_labels: bool = True 47 | bbox_yxyx: bool = True 48 | min_img_size: int = 32 49 | ignore_empty_gt: bool = False 50 | -------------------------------------------------------------------------------- /effdet/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | """ Parser factory 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | from .parser_coco import CocoParser 6 | from .parser_voc import VocParser 7 | from .parser_open_images import OpenImagesParser 8 | 9 | 10 | def create_parser(name, **kwargs): 11 | if name == 'coco': 12 | parser = CocoParser(**kwargs) 13 | elif name == 'voc': 14 | parser = VocParser(**kwargs) 15 | elif name == 'openimages': 16 | parser = OpenImagesParser(**kwargs) 17 | else: 18 | assert False, f'Unknown dataset parser ({name})' 19 | return parser 20 | -------------------------------------------------------------------------------- /effdet/data/parsers/parser_open_images.py: -------------------------------------------------------------------------------- 1 | """ OpenImages dataset parser 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import numpy as np 6 | import os 7 | import logging 8 | 9 | from .parser import Parser 10 | from .parser_config import OpenImagesParserCfg 11 | 12 | _logger = logging.getLogger(__name__) 13 | 14 | 15 | class OpenImagesParser(Parser): 16 | 17 | def __init__(self, cfg: OpenImagesParserCfg): 18 | super().__init__( 19 | bbox_yxyx=cfg.bbox_yxyx, 20 | has_labels=cfg.has_labels, 21 | include_masks=False, # FIXME to support someday 22 | include_bboxes_ignore=False, 23 | ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt, 24 | min_img_size=cfg.min_img_size 25 | ) 26 | self.img_prefix_levels = cfg.prefix_levels 
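        # NOTE (editorial comment): prefix_levels > 0 means image files are sharded into
        # nested subdirectories named by the first N characters of the image id, e.g. with
        # prefix_levels=1, image id 'abc123' resolves to '<img_dir>/a/abc123.jpg'
        # (see _img_filename below).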
27 | self.mask_prefix_levels = 1 28 | self._anns = None # access via get_ann_info() 29 | self._img_to_ann = None 30 | self._load_annotations( 31 | categories_filename=cfg.categories_filename, 32 | img_info_filename=cfg.img_info_filename, 33 | img_filename=cfg.img_filename, 34 | masks_filename=cfg.masks_filename, 35 | bbox_filename=cfg.bbox_filename 36 | ) 37 | 38 | def _load_annotations( 39 | self, 40 | categories_filename: str, 41 | img_info_filename: str, 42 | img_filename: str, 43 | masks_filename: str, 44 | bbox_filename: str, 45 | ): 46 | import pandas as pd # For now, blow up on pandas req only when trying to load open images anno 47 | 48 | _logger.info('Loading categories...') 49 | classes_df = pd.read_csv(categories_filename, header=None) 50 | self.cat_ids = classes_df[0].tolist() 51 | self.cat_names = classes_df[1].tolist() 52 | self.cat_id_to_label = {c: i + self.label_offset for i, c in enumerate(self.cat_ids)} 53 | 54 | def _img_filename(img_id): 55 | # build image filenames that are relative to img_dir 56 | filename = img_filename % img_id 57 | if self.img_prefix_levels: 58 | levels = [c for c in img_id[:self.img_prefix_levels]] 59 | filename = os.path.join(*levels, filename) 60 | return filename 61 | 62 | def _mask_filename(mask_path): 63 | # FIXME finish 64 | if self.mask_prefix_levels: 65 | levels = [c for c in mask_path[:self.mask_prefix_levels]] 66 | mask_path = os.path.join(*levels, mask_path) 67 | return mask_path 68 | 69 | def _load_img_info(csv_file, select_img_ids=None): 70 | _logger.info('Read img_info csv...') 71 | img_info_df = pd.read_csv(csv_file, index_col='id') 72 | 73 | _logger.info('Filter images...') 74 | if select_img_ids is not None: 75 | img_info_df = img_info_df.loc[select_img_ids] 76 | img_info_df = img_info_df[ 77 | (img_info_df['width'] >= self.min_img_size) & (img_info_df['height'] >= self.min_img_size)] 78 | 79 | _logger.info('Mapping ids...') 80 | img_info_df['img_id'] = img_info_df.index 81 | img_info_df['file_name'] = img_info_df.index.map(lambda x: _img_filename(x)) 82 | img_info_df = img_info_df[['img_id', 'file_name', 'width', 'height']] 83 | img_sizes = img_info_df[['width', 'height']].values 84 | self.img_infos = img_info_df.to_dict('records') 85 | self.img_ids = img_info_df.index.values.tolist() 86 | img_id_to_idx = {img_id: idx for idx, img_id in enumerate(self.img_ids)} 87 | return img_sizes, img_id_to_idx 88 | 89 | if self.include_masks and self.has_labels: 90 | masks_df = pd.read_csv(masks_filename) 91 | 92 | # NOTE currently using dataset masks anno ImageIDs to form valid img_ids from the dataset 93 | anno_img_ids = sorted(masks_df['ImageID'].unique()) 94 | img_sizes, img_id_to_idx = _load_img_info(img_info_filename, select_img_ids=anno_img_ids) 95 | 96 | masks_df['ImageIdx'] = masks_df['ImageID'].map(img_id_to_idx) 97 | if np.issubdtype(masks_df.ImageIdx.dtype, np.floating): 98 | masks_df = masks_df.dropna(axis='rows') 99 | masks_df['ImageIdx'] = masks_df.ImageIdx.astype(np.int32) 100 | masks_df.sort_values('ImageIdx', inplace=True) 101 | ann_img_idx = masks_df['ImageIdx'].values 102 | img_sizes = img_sizes[ann_img_idx] 103 | masks_df['BoxXMin'] = masks_df['BoxXMin'] * img_sizes[:, 0] 104 | masks_df['BoxXMax'] = masks_df['BoxXMax'] * img_sizes[:, 0] 105 | masks_df['BoxYMin'] = masks_df['BoxYMin'] * img_sizes[:, 1] 106 | masks_df['BoxYMax'] = masks_df['BoxYMax'] * img_sizes[:, 1] 107 | masks_df['LabelIdx'] = masks_df['LabelName'].map(self.cat_id_to_label) 108 | # FIXME remap mask filename with _mask_filename 109 | 110 | self._anns = 
dict( 111 | bbox=masks_df[['BoxXMin', 'BoxYMin', 'BoxXMax', 'BoxYMax']].values.astype(np.float32), 112 | label=masks_df[['LabelIdx']].values.astype(np.int32), 113 | mask_path=masks_df[['MaskPath']].values 114 | ) 115 | _, ri, rc = np.unique(ann_img_idx, return_index=True, return_counts=True) 116 | self._img_to_ann = list(zip(ri, rc)) # index, count tuples 117 | elif self.has_labels: 118 | _logger.info('Loading bbox...') 119 | bbox_df = pd.read_csv(bbox_filename) 120 | 121 | # NOTE currently using dataset box anno ImageIDs to form valid img_ids from the larger dataset. 122 | # FIXME use *imagelabels.csv or imagelabels-boxable.csv for negative examples (without box?) 123 | anno_img_ids = sorted(bbox_df['ImageID'].unique()) 124 | img_sizes, img_id_to_idx = _load_img_info(img_info_filename, select_img_ids=anno_img_ids) 125 | 126 | _logger.info('Process bbox...') 127 | bbox_df['ImageIdx'] = bbox_df['ImageID'].map(img_id_to_idx) 128 | if np.issubdtype(bbox_df.ImageIdx.dtype, np.floating): 129 | bbox_df = bbox_df.dropna(axis='rows') 130 | bbox_df['ImageIdx'] = bbox_df.ImageIdx.astype(np.int32) 131 | bbox_df.sort_values('ImageIdx', inplace=True) 132 | ann_img_idx = bbox_df['ImageIdx'].values 133 | img_sizes = img_sizes[ann_img_idx] 134 | bbox_df['XMin'] = bbox_df['XMin'] * img_sizes[:, 0] 135 | bbox_df['XMax'] = bbox_df['XMax'] * img_sizes[:, 0] 136 | bbox_df['YMin'] = bbox_df['YMin'] * img_sizes[:, 1] 137 | bbox_df['YMax'] = bbox_df['YMax'] * img_sizes[:, 1] 138 | bbox_df['LabelIdx'] = bbox_df['LabelName'].map(self.cat_id_to_label).astype(np.int32) 139 | 140 | self._anns = dict( 141 | bbox=bbox_df[['XMin', 'YMin', 'XMax', 'YMax']].values.astype(np.float32), 142 | label=bbox_df[['LabelIdx', 'IsGroupOf']].values.astype(np.int32), 143 | ) 144 | _, ri, rc = np.unique(ann_img_idx, return_index=True, return_counts=True) 145 | self._img_to_ann = list(zip(ri, rc)) # index, count tuples 146 | else: 147 | _load_img_info(img_info_filename) 148 | 149 | _logger.info('Annotations loaded!') 150 | 151 | def get_ann_info(self, idx): 152 | if not self.has_labels: 153 | return dict() 154 | start_idx, num_ann = self._img_to_ann[idx] 155 | ann_keys = tuple(self._anns.keys()) 156 | ann_values = tuple(self._anns[k][start_idx:start_idx + num_ann] for k in ann_keys) 157 | return self._parse_ann_info(idx, ann_keys, ann_values) 158 | 159 | def _parse_ann_info(self, img_idx, ann_keys, ann_values): 160 | """ 161 | """ 162 | gt_bboxes = [] 163 | gt_labels = [] 164 | gt_bboxes_ignore = [] 165 | if self.include_masks: 166 | assert 'mask_path' in ann_keys 167 | gt_masks = [] 168 | 169 | for ann in zip(*ann_values): 170 | ann = dict(zip(ann_keys, ann)) 171 | x1, y1, x2, y2 = ann['bbox'] 172 | if x2 - x1 < 1 or y2 - y1 < 1: 173 | continue 174 | label = ann['label'][0] 175 | iscrowd = False 176 | if len(ann['label']) > 1: 177 | iscrowd = ann['label'][1] 178 | if self.yxyx: 179 | bbox = np.array([y1, x1, y2, x2], dtype=np.float32) 180 | else: 181 | bbox = ann['bbox'] 182 | if iscrowd: 183 | gt_bboxes_ignore.append(bbox) 184 | else: 185 | gt_bboxes.append(bbox) 186 | gt_labels.append(label) 187 | # if self.include_masks: 188 | # img_info = self.img_infos[img_idx] 189 | # mask_img = SegmentationMask(ann['mask_filename'], img_info['width'], img_info['height']) 190 | # gt_masks.append(mask_img) 191 | 192 | if gt_bboxes: 193 | gt_bboxes = np.array(gt_bboxes, ndmin=2, dtype=np.float32) 194 | gt_labels = np.array(gt_labels, dtype=np.int64) 195 | else: 196 | gt_bboxes = np.zeros((0, 4), dtype=np.float32) 197 | gt_labels = np.array([], 
dtype=np.int64) 198 | 199 | if self.include_bboxes_ignore: 200 | if gt_bboxes_ignore: 201 | gt_bboxes_ignore = np.array(gt_bboxes_ignore, ndmin=2, dtype=np.float32) 202 | else: 203 | gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) 204 | 205 | ann = dict(bbox=gt_bboxes, cls=gt_labels) 206 | 207 | if self.include_bboxes_ignore: 208 | ann.update(dict(bbox_ignore=gt_bboxes_ignore, cls_ignore=np.array([], dtype=np.int64))) 209 | if self.include_masks: 210 | ann['masks'] = gt_masks 211 | return ann 212 | -------------------------------------------------------------------------------- /effdet/data/parsers/parser_voc.py: -------------------------------------------------------------------------------- 1 | """ Pascal VOC dataset parser 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | import xml.etree.ElementTree as ET 7 | from collections import defaultdict 8 | import numpy as np 9 | 10 | from .parser import Parser 11 | from .parser_config import VocParserCfg 12 | 13 | 14 | class VocParser(Parser): 15 | 16 | DEFAULT_CLASSES = ( 17 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 18 | 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 19 | 'sheep', 'sofa', 'train', 'tvmonitor') 20 | 21 | def __init__(self, cfg: VocParserCfg): 22 | super().__init__( 23 | bbox_yxyx=cfg.bbox_yxyx, 24 | has_labels=cfg.has_labels, 25 | include_masks=False, # FIXME to support someday 26 | include_bboxes_ignore=False, 27 | ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt, 28 | min_img_size=cfg.min_img_size 29 | ) 30 | self.correct_bbox = 1 31 | self.keep_difficult = cfg.keep_difficult 32 | 33 | self.anns = None 34 | self.img_id_to_idx = {} 35 | self._load_annotations( 36 | split_filename=cfg.split_filename, 37 | img_filename=cfg.img_filename, 38 | ann_filename=cfg.ann_filename, 39 | classes=cfg.classes, 40 | ) 41 | 42 | def _load_annotations( 43 | self, 44 | split_filename: str, 45 | img_filename: str, 46 | ann_filename: str, 47 | classes=None, 48 | ): 49 | classes = classes or self.DEFAULT_CLASSES 50 | self.cat_names = list(classes) 51 | self.cat_ids = self.cat_names 52 | self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)} 53 | 54 | self.anns = [] 55 | 56 | with open(split_filename) as f: 57 | ids = f.readlines() 58 | for img_id in ids: 59 | img_id = img_id.strip("\n") 60 | filename = img_filename % img_id 61 | xml_path = ann_filename % img_id 62 | tree = ET.parse(xml_path) 63 | root = tree.getroot() 64 | size = root.find('size') 65 | width = int(size.find('width').text) 66 | height = int(size.find('height').text) 67 | if min(width, height) < self.min_img_size: 68 | continue 69 | 70 | anns = [] 71 | for obj_idx, obj in enumerate(root.findall('object')): 72 | name = obj.find('name').text 73 | label = self.cat_id_to_label[name] 74 | difficult = int(obj.find('difficult').text) 75 | bnd_box = obj.find('bndbox') 76 | bbox = [ 77 | int(bnd_box.find('xmin').text), 78 | int(bnd_box.find('ymin').text), 79 | int(bnd_box.find('xmax').text), 80 | int(bnd_box.find('ymax').text) 81 | ] 82 | anns.append(dict(label=label, bbox=bbox, difficult=difficult)) 83 | 84 | if not self.ignore_empty_gt or len(anns): 85 | self.anns.append(anns) 86 | self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height)) 87 | self.img_ids.append(img_id) 88 | else: 89 | self.img_ids_invalid.append(img_id) 90 | 91 | def merge(self, other): 92 | assert len(self.cat_ids) == len(other.cat_ids) 93 | 
self.img_ids.extend(other.img_ids) 94 | self.img_infos.extend(other.img_infos) 95 | self.anns.extend(other.anns) 96 | 97 | def get_ann_info(self, idx): 98 | return self._parse_ann_info(self.anns[idx]) 99 | 100 | def _parse_ann_info(self, ann_info): 101 | bboxes = [] 102 | labels = [] 103 | bboxes_ignore = [] 104 | labels_ignore = [] 105 | for ann in ann_info: 106 | ignore = False 107 | x1, y1, x2, y2 = ann['bbox'] 108 | label = ann['label'] 109 | w = x2 - x1 110 | h = y2 - y1 111 | if w < 1 or h < 1: 112 | ignore = True 113 | if self.yxyx: 114 | bbox = [y1, x1, y2, x2] 115 | else: 116 | bbox = ann['bbox'] 117 | if ignore or (ann['difficult'] and not self.keep_difficult): 118 | bboxes_ignore.append(bbox) 119 | labels_ignore.append(label) 120 | else: 121 | bboxes.append(bbox) 122 | labels.append(label) 123 | 124 | if not bboxes: 125 | bboxes = np.zeros((0, 4), dtype=np.float32) 126 | labels = np.zeros((0, ), dtype=np.float32) 127 | else: 128 | bboxes = np.array(bboxes, ndmin=2, dtype=np.float32) - self.correct_bbox 129 | labels = np.array(labels, dtype=np.float32) 130 | 131 | if self.include_bboxes_ignore: 132 | if not bboxes_ignore: 133 | bboxes_ignore = np.zeros((0, 4), dtype=np.float32) 134 | labels_ignore = np.zeros((0, ), dtype=np.float32) 135 | else: 136 | bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32) - self.correct_bbox 137 | labels_ignore = np.array(labels_ignore, dtype=np.float32) 138 | 139 | ann = dict( 140 | bbox=bboxes.astype(np.float32), 141 | cls=labels.astype(np.int64)) 142 | 143 | if self.include_bboxes_ignore: 144 | ann.update(dict( 145 | bbox_ignore=bboxes_ignore.astype(np.float32), 146 | cls_ignore=labels_ignore.astype(np.int64))) 147 | return ann 148 | 149 | -------------------------------------------------------------------------------- /effdet/data/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ Multi-Scale RandomErasing 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import random 6 | import math 7 | import torch 8 | 9 | 10 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 11 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 12 | # paths, flip the order so normal is run on CPU if this becomes a problem 13 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 14 | if per_pixel: 15 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 16 | elif rand_color: 17 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 18 | else: 19 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 20 | 21 | 22 | class RandomErasing: 23 | """ Randomly selects a rectangle region in an image and erases its pixels. 24 | 'Random Erasing Data Augmentation' by Zhong et al. 25 | See https://arxiv.org/pdf/1708.04896.pdf 26 | 27 | This variant of RandomErasing is tweaked for multi-scale obj detection training. 28 | Args: 29 | probability: Probability that the Random Erasing operation will be performed. 30 | min_area: Minimum percentage of erased area wrt input image area. 31 | max_area: Maximum percentage of erased area wrt input image area. 32 | min_aspect: Minimum aspect ratio of erased area. 
33 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 34 | 'const' - erase block is constant color of 0 for all channels 35 | 'rand' - erase block is same per-channel random (normal) color 36 | 'pixel' - erase block is per-pixel random (normal) color 37 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 38 | per-image count is randomly chosen between 1 and this value. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | probability=0.5, min_area=0.02, max_area=1/4, min_aspect=0.3, max_aspect=None, 44 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 45 | self.probability = probability 46 | self.min_area = min_area 47 | self.max_area = max_area 48 | max_aspect = max_aspect or 1 / min_aspect 49 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 50 | self.min_count = min_count 51 | self.max_count = max_count or min_count 52 | self.num_splits = num_splits 53 | mode = mode.lower() 54 | self.rand_color = False 55 | self.per_pixel = False 56 | if mode == 'rand': 57 | self.rand_color = True # per block random normal 58 | elif mode == 'pixel': 59 | self.per_pixel = True # per pixel random normal 60 | else: 61 | assert not mode or mode == 'const' 62 | self.device = device 63 | 64 | def _erase(self, img, chan, img_h, img_w, dtype): 65 | if random.random() > self.probability: 66 | return 67 | area = img_h * img_w 68 | count = self.min_count if self.min_count == self.max_count else \ 69 | random.randint(self.min_count, self.max_count) 70 | for _ in range(count): 71 | for attempt in range(10): 72 | target_area = random.uniform(self.min_area, self.max_area) * area / count 73 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 74 | h = int(round(math.sqrt(target_area * aspect_ratio))) 75 | w = int(round(math.sqrt(target_area / aspect_ratio))) 76 | if w < img_w and h < img_h: 77 | top = random.randint(0, img_h - h) 78 | left = random.randint(0, img_w - w) 79 | img[:, top:top + h, left:left + w] = _get_pixels( 80 | self.per_pixel, self.rand_color, (chan, h, w), 81 | dtype=dtype, device=self.device) 82 | break 83 | 84 | def __call__(self, input, target): 85 | batch_size, chan, input_h, input_w = input.shape 86 | img_scales = target['img_scale'] 87 | img_size = (target['img_size'] / img_scales.unsqueeze(1)).int() 88 | img_size[:, 0] = img_size[:, 0].clamp(max=input_w) 89 | img_size[:, 1] = img_size[:, 1].clamp(max=input_h) 90 | # skip first slice of batch if num_splits is set (for clean portion of samples) 91 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 92 | for i in range(batch_start, batch_size): 93 | self._erase(input[i], chan, img_size[i, 1], img_size[i, 0], input.dtype) 94 | return input 95 | -------------------------------------------------------------------------------- /effdet/data/transforms.py: -------------------------------------------------------------------------------- 1 | """ COCO transforms (quick and dirty) 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import random 6 | import math 7 | from copy import deepcopy 8 | 9 | from PIL import Image 10 | import numpy as np 11 | import torch 12 | 13 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 14 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 15 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 16 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 17 | 18 | 19 | class ImageToNumpy: 20 | 21 | def __call__(self, pil_img, annotations: dict): 22 | np_img = np.array(pil_img, dtype=np.uint8) 23 | if np_img.ndim < 3: 24 | np_img 
= np.expand_dims(np_img, axis=-1) 25 | np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW 26 | return np_img, annotations 27 | 28 | 29 | class ImageToTensor: 30 | 31 | def __init__(self, dtype=torch.float32): 32 | self.dtype = dtype 33 | 34 | def __call__(self, pil_img, annotations: dict): 35 | np_img = np.array(pil_img, dtype=np.uint8) 36 | if np_img.ndim < 3: 37 | np_img = np.expand_dims(np_img, axis=-1) 38 | np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW 39 | return torch.from_numpy(np_img).to(dtype=self.dtype), annotations 40 | 41 | 42 | def _pil_interp(method): 43 | if method == 'bicubic': 44 | return Image.BICUBIC 45 | elif method == 'lanczos': 46 | return Image.LANCZOS 47 | elif method == 'hamming': 48 | return Image.HAMMING 49 | else: 50 | # default bilinear, do we want to allow nearest? 51 | return Image.BILINEAR 52 | 53 | 54 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 55 | 56 | 57 | def clip_boxes_(boxes, img_size): 58 | height, width = img_size 59 | clip_upper = np.array([height, width] * 2, dtype=boxes.dtype) 60 | np.clip(boxes, 0, clip_upper, out=boxes) 61 | 62 | 63 | def clip_boxes(boxes, img_size): 64 | clipped_boxes = boxes.copy() 65 | clip_boxes_(clipped_boxes, img_size) 66 | return clipped_boxes 67 | 68 | 69 | def _size_tuple(size): 70 | if isinstance(size, int): 71 | return size, size 72 | else: 73 | assert len(size) == 2 74 | return size 75 | 76 | 77 | class ResizePad: 78 | 79 | def __init__(self, target_size: int, interpolation: str = 'bilinear', fill_color: tuple = (0, 0, 0)): 80 | self.target_size = _size_tuple(target_size) 81 | self.interpolation = interpolation 82 | self.fill_color = fill_color 83 | 84 | def __call__(self, img, anno: dict): 85 | width, height = img.size 86 | 87 | img_scale_y = self.target_size[0] / height 88 | img_scale_x = self.target_size[1] / width 89 | img_scale = min(img_scale_y, img_scale_x) 90 | scaled_h = int(height * img_scale) 91 | scaled_w = int(width * img_scale) 92 | 93 | new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color) 94 | interp_method = _pil_interp(self.interpolation) 95 | img = img.resize((scaled_w, scaled_h), interp_method) 96 | new_img.paste(img) # pastes at 0,0 (upper-left corner) 97 | 98 | if 'bbox' in anno: 99 | bbox = anno['bbox'] 100 | bbox[:, :4] *= img_scale 101 | bbox_bound = (min(scaled_h, self.target_size[0]), min(scaled_w, self.target_size[1])) 102 | clip_boxes_(bbox, bbox_bound) # crop to bounds of target image or letter-box, whichever is smaller 103 | valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1) 104 | anno['bbox'] = bbox[valid_indices, :] 105 | anno['cls'] = anno['cls'][valid_indices] 106 | 107 | anno['img_scale'] = 1. / img_scale # back to original 108 | 109 | return new_img, anno 110 | 111 | 112 | class RandomResizePad: 113 | 114 | def __init__(self, target_size: int, scale: tuple = (0.1, 2.0), interpolation: str = 'random', 115 | fill_color: tuple = (0, 0, 0)): 116 | self.target_size = _size_tuple(target_size) 117 | self.scale = scale 118 | if interpolation == 'random': 119 | self.interpolation = _RANDOM_INTERPOLATION 120 | else: 121 | self.interpolation = _pil_interp(interpolation) 122 | self.fill_color = fill_color 123 | 124 | def _get_params(self, img): 125 | # Select a random scale factor. 
126 | scale_factor = random.uniform(*self.scale) 127 | scaled_target_height = scale_factor * self.target_size[0] 128 | scaled_target_width = scale_factor * self.target_size[1] 129 | 130 | # Recompute the accurate scale_factor using rounded scaled image size. 131 | width, height = img.size 132 | img_scale_y = scaled_target_height / height 133 | img_scale_x = scaled_target_width / width 134 | img_scale = min(img_scale_y, img_scale_x) 135 | 136 | # Select non-zero random offset (x, y) if scaled image is larger than target size 137 | scaled_h = int(height * img_scale) 138 | scaled_w = int(width * img_scale) 139 | offset_y = scaled_h - self.target_size[0] 140 | offset_x = scaled_w - self.target_size[1] 141 | offset_y = int(max(0.0, float(offset_y)) * random.uniform(0, 1)) 142 | offset_x = int(max(0.0, float(offset_x)) * random.uniform(0, 1)) 143 | return scaled_h, scaled_w, offset_y, offset_x, img_scale 144 | 145 | def __call__(self, img, anno: dict): 146 | scaled_h, scaled_w, offset_y, offset_x, img_scale = self._get_params(img) 147 | 148 | if isinstance(self.interpolation, (tuple, list)): 149 | interpolation = random.choice(self.interpolation) 150 | else: 151 | interpolation = self.interpolation 152 | img = img.resize((scaled_w, scaled_h), interpolation) 153 | right, lower = min(scaled_w, offset_x + self.target_size[1]), min(scaled_h, offset_y + self.target_size[0]) 154 | img = img.crop((offset_x, offset_y, right, lower)) 155 | new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color) 156 | new_img.paste(img) # pastes at 0,0 (upper-left corner) 157 | 158 | if 'bbox' in anno: 159 | bbox = anno['bbox'] # for convenience, modifies in-place 160 | bbox[:, :4] *= img_scale 161 | box_offset = np.stack([offset_y, offset_x] * 2) 162 | bbox -= box_offset 163 | bbox_bound = (min(scaled_h, self.target_size[0]), min(scaled_w, self.target_size[1])) 164 | clip_boxes_(bbox, bbox_bound) # crop to bounds of target image or letter-box, whichever is smaller 165 | valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1) 166 | anno['bbox'] = bbox[valid_indices, :] 167 | anno['cls'] = anno['cls'][valid_indices] 168 | 169 | anno['img_scale'] = 1. 
/ img_scale # back to original 170 | 171 | return new_img, anno 172 | 173 | 174 | class RandomFlip: 175 | 176 | def __init__(self, horizontal=True, vertical=False, prob=0.5): 177 | self.horizontal = horizontal 178 | self.vertical = vertical 179 | self.prob = prob 180 | 181 | def _get_params(self): 182 | do_horizontal = random.random() < self.prob if self.horizontal else False 183 | do_vertical = random.random() < self.prob if self.vertical else False 184 | return do_horizontal, do_vertical 185 | 186 | def __call__(self, img, annotations: dict): 187 | do_horizontal, do_vertical = self._get_params() 188 | width, height = img.size 189 | 190 | def _fliph(bbox): 191 | x_max = width - bbox[:, 1] 192 | x_min = width - bbox[:, 3] 193 | bbox[:, 1] = x_min 194 | bbox[:, 3] = x_max 195 | 196 | def _flipv(bbox): 197 | y_max = height - bbox[:, 0] 198 | y_min = height - bbox[:, 2] 199 | bbox[:, 0] = y_min 200 | bbox[:, 2] = y_max 201 | 202 | if do_horizontal and do_vertical: 203 | img = img.transpose(Image.ROTATE_180) 204 | if 'bbox' in annotations: 205 | _fliph(annotations['bbox']) 206 | _flipv(annotations['bbox']) 207 | elif do_horizontal: 208 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 209 | if 'bbox' in annotations: 210 | _fliph(annotations['bbox']) 211 | elif do_vertical: 212 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 213 | if 'bbox' in annotations: 214 | _flipv(annotations['bbox']) 215 | 216 | return img, annotations 217 | 218 | 219 | def resolve_fill_color(fill_color, img_mean=IMAGENET_DEFAULT_MEAN): 220 | if isinstance(fill_color, tuple): 221 | assert len(fill_color) == 3 222 | fill_color = fill_color 223 | else: 224 | try: 225 | int_color = int(fill_color) 226 | fill_color = (int_color,) * 3 227 | except ValueError: 228 | assert fill_color == 'mean' 229 | fill_color = tuple([int(round(255 * x)) for x in img_mean]) 230 | return fill_color 231 | 232 | 233 | class Compose: 234 | 235 | def __init__(self, transforms: list): 236 | self.transforms = transforms 237 | 238 | def __call__(self, img, annotations: dict): 239 | for t in self.transforms: 240 | img, annotations = t(img, annotations) 241 | return img, annotations 242 | 243 | 244 | def transforms_coco_eval( 245 | img_size=224, 246 | interpolation='bilinear', 247 | use_prefetcher=False, 248 | fill_color='mean', 249 | mean=IMAGENET_DEFAULT_MEAN, 250 | std=IMAGENET_DEFAULT_STD): 251 | 252 | fill_color = resolve_fill_color(fill_color, mean) 253 | 254 | image_tfl = [ 255 | ResizePad( 256 | target_size=img_size, interpolation=interpolation, fill_color=fill_color), 257 | ImageToNumpy(), 258 | ] 259 | 260 | assert use_prefetcher, "Only supporting prefetcher usage right now" 261 | 262 | image_tf = Compose(image_tfl) 263 | return image_tf 264 | 265 | 266 | def transforms_coco_train( 267 | img_size=224, 268 | interpolation='random', 269 | use_prefetcher=False, 270 | fill_color='mean', 271 | mean=IMAGENET_DEFAULT_MEAN, 272 | std=IMAGENET_DEFAULT_STD): 273 | 274 | fill_color = resolve_fill_color(fill_color, mean) 275 | 276 | image_tfl = [ 277 | RandomFlip(horizontal=True, prob=0.5), 278 | RandomResizePad( 279 | target_size=img_size, interpolation=interpolation, fill_color=fill_color), 280 | ImageToNumpy(), 281 | ] 282 | 283 | assert use_prefetcher, "Only supporting prefetcher usage right now" 284 | 285 | image_tf = Compose(image_tfl) 286 | return image_tf 287 | -------------------------------------------------------------------------------- /effdet/distributed.py: -------------------------------------------------------------------------------- 1 | """ 
PyTorch distributed helpers 2 | 3 | Some of this lifted from Detectron2 with other fns added by myself. Some of the Detectron2 fns 4 | were intended for use with GLOO PG. I am using NCCL here with default PG so not everything will work 5 | as is -RW 6 | """ 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | _LOCAL_PROCESS_GROUP = None 15 | """ 16 | A torch process group which only includes processes that on the same machine as the current process. 17 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 18 | """ 19 | 20 | 21 | def get_world_size() -> int: 22 | if not dist.is_available(): 23 | return 1 24 | if not dist.is_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | 29 | def get_rank() -> int: 30 | if not dist.is_available(): 31 | return 0 32 | if not dist.is_initialized(): 33 | return 0 34 | return dist.get_rank() 35 | 36 | 37 | def get_local_rank() -> int: 38 | """ 39 | Returns: 40 | The rank of the current process within the local (per-machine) process group. 41 | """ 42 | if not dist.is_available(): 43 | return 0 44 | if not dist.is_initialized(): 45 | return 0 46 | assert _LOCAL_PROCESS_GROUP is not None 47 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 48 | 49 | 50 | def get_local_size() -> int: 51 | """ 52 | Returns: 53 | The size of the per-machine process group, 54 | i.e. the number of processes per machine. 55 | """ 56 | if not dist.is_available(): 57 | return 1 58 | if not dist.is_initialized(): 59 | return 1 60 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 61 | 62 | 63 | def is_main_process() -> bool: 64 | return get_rank() == 0 65 | 66 | 67 | def synchronize(): 68 | """ 69 | Helper function to synchronize (barrier) among all processes when 70 | using distributed training 71 | """ 72 | if not dist.is_available(): 73 | return 74 | if not dist.is_initialized(): 75 | return 76 | world_size = dist.get_world_size() 77 | if world_size == 1: 78 | return 79 | dist.barrier() 80 | 81 | 82 | @functools.lru_cache() 83 | def _get_global_gloo_group(): 84 | """ 85 | Return a process group based on gloo backend, containing all the ranks 86 | The result is cached. 87 | """ 88 | if dist.get_backend() == "nccl": 89 | return dist.new_group(backend="gloo") 90 | else: 91 | return dist.group.WORLD 92 | 93 | 94 | def _serialize_to_tensor(data, group): 95 | backend = dist.get_backend(group) 96 | assert backend in ["gloo", "nccl"] 97 | device = torch.device("cpu" if backend == "gloo" else "cuda") 98 | 99 | buffer = pickle.dumps(data) 100 | if len(buffer) > 1024 ** 3: 101 | logger = logging.getLogger(__name__) 102 | logger.warning( 103 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 104 | get_rank(), len(buffer) / (1024 ** 3), device 105 | ) 106 | ) 107 | storage = torch.ByteStorage.from_buffer(buffer) 108 | tensor = torch.ByteTensor(storage).to(device=device) 109 | return tensor 110 | 111 | 112 | def _pad_to_largest_tensor(tensor, group): 113 | """ 114 | Returns: 115 | list[int]: size of the tensor, on each rank 116 | Tensor: padded tensor that has the max size 117 | """ 118 | world_size = dist.get_world_size(group=group) 119 | assert ( 120 | world_size >= 1 121 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
122 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 123 | size_list = [ 124 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 125 | ] 126 | dist.all_gather(size_list, local_size, group=group) 127 | size_list = [int(size.item()) for size in size_list] 128 | 129 | max_size = max(size_list) 130 | 131 | # we pad the tensor because torch all_gather does not support 132 | # gathering tensors of different shapes 133 | if local_size != max_size: 134 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 135 | tensor = torch.cat((tensor, padding), dim=0) 136 | return size_list, tensor 137 | 138 | 139 | def all_gather(data, group=None): 140 | """ 141 | Run all_gather on arbitrary picklable data (not necessarily tensors). 142 | Args: 143 | data: any picklable object 144 | group: a torch process group. By default, will use a group which 145 | contains all ranks on gloo backend. 146 | Returns: 147 | list[data]: list of data gathered from each rank 148 | """ 149 | if get_world_size() == 1: 150 | return [data] 151 | if group is None: 152 | group = _get_global_gloo_group() 153 | if dist.get_world_size(group) == 1: 154 | return [data] 155 | 156 | tensor = _serialize_to_tensor(data, group) 157 | 158 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 159 | max_size = max(size_list) 160 | 161 | # receiving Tensor from all ranks 162 | tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list] 163 | dist.all_gather(tensor_list, tensor, group=group) 164 | 165 | data_list = [] 166 | for size, tensor in zip(size_list, tensor_list): 167 | buffer = tensor.cpu().numpy().tobytes()[:size] 168 | data_list.append(pickle.loads(buffer)) 169 | 170 | return data_list 171 | 172 | 173 | def gather(data, dst=0, group=None): 174 | """ 175 | Run gather on arbitrary picklable data (not necessarily tensors). 176 | Args: 177 | data: any picklable object 178 | dst (int): destination rank 179 | group: a torch process group. By default, will use a group which 180 | contains all ranks on gloo backend. 181 | Returns: 182 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 183 | an empty list. 184 | """ 185 | if get_world_size() == 1: 186 | return [data] 187 | if group is None: 188 | group = _get_global_gloo_group() 189 | if dist.get_world_size(group=group) == 1: 190 | return [data] 191 | rank = dist.get_rank(group=group) 192 | 193 | tensor = _serialize_to_tensor(data, group) 194 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 195 | 196 | # receiving Tensor from all ranks 197 | if rank == dst: 198 | max_size = max(size_list) 199 | tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list] 200 | dist.gather(tensor, tensor_list, dst=dst, group=group) 201 | 202 | data_list = [] 203 | for size, tensor in zip(size_list, tensor_list): 204 | buffer = tensor.cpu().numpy().tobytes()[:size] 205 | data_list.append(pickle.loads(buffer)) 206 | return data_list 207 | else: 208 | dist.gather(tensor, [], dst=dst, group=group) 209 | return [] 210 | 211 | 212 | def shared_random_seed(): 213 | """ 214 | Returns: 215 | int: a random number that is the same across all workers. 216 | If workers need a shared RNG, they can use this shared seed to 217 | create one. 218 | All workers must call this function, otherwise it will deadlock. 
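    Example (editorial sketch, not part of the original docstring):

        seed = shared_random_seed()
        rng = np.random.RandomState(seed)  # identical RNG stream on every rank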
219 | """ 220 | ints = np.random.randint(2 ** 31) 221 | all_ints = all_gather(ints) 222 | return all_ints[0] 223 | 224 | 225 | def reduce_dict(input_dict, average=True): 226 | """ 227 | Reduce the values in the dictionary from all processes so that process with rank 228 | 0 has the reduced results. 229 | Args: 230 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 231 | average (bool): whether to do average or sum 232 | Returns: 233 | a dict with the same keys as input_dict, after reduction. 234 | """ 235 | world_size = get_world_size() 236 | if world_size < 2: 237 | return input_dict 238 | with torch.no_grad(): 239 | names = [] 240 | values = [] 241 | # sort the keys so that they are consistent across processes 242 | for k in sorted(input_dict.keys()): 243 | names.append(k) 244 | values.append(input_dict[k]) 245 | values = torch.stack(values, dim=0) 246 | dist.reduce(values, dst=0) 247 | if dist.get_rank() == 0 and average: 248 | # only main process gets accumulated, so only divide by 249 | # world_size in this case 250 | values /= world_size 251 | reduced_dict = {k: v for k, v in zip(names, values)} 252 | return reduced_dict 253 | 254 | 255 | def all_gather_container(container, group=None, cat_dim=0): 256 | group = group or dist.group.WORLD 257 | world_size = dist.get_world_size(group) 258 | 259 | def _do_gather(tensor): 260 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] 261 | dist.all_gather(tensor_list, tensor, group=group) 262 | return torch.cat(tensor_list, dim=cat_dim) 263 | 264 | if isinstance(container, dict): 265 | gathered = dict() 266 | for k, v in container.items(): 267 | v = _do_gather(v) 268 | gathered[k] = v 269 | return gathered 270 | elif isinstance(container, (list, tuple)): 271 | gathered = [_do_gather(v) for v in container] 272 | if isinstance(container, tuple): 273 | gathered = tuple(gathered) 274 | return gathered 275 | else: 276 | # if not a dict, list, tuple, expect a singular tensor 277 | assert isinstance(container, torch.Tensor) 278 | return _do_gather(container) 279 | 280 | 281 | def gather_container(container, dst, group=None, cat_dim=0): 282 | group = group or dist.group.WORLD 283 | world_size = dist.get_world_size(group) 284 | this_rank = dist.get_rank(group) 285 | 286 | def _do_gather(tensor): 287 | if this_rank == dst: 288 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] 289 | else: 290 | tensor_list = None 291 | dist.gather(tensor, tensor_list, dst=dst, group=group) 292 | return torch.cat(tensor_list, dim=cat_dim) 293 | 294 | if isinstance(container, dict): 295 | gathered = dict() 296 | for k, v in container.items(): 297 | v = _do_gather(v) 298 | gathered[k] = v 299 | return gathered 300 | elif isinstance(container, (list, tuple)): 301 | gathered = [_do_gather(v) for v in container] 302 | if isinstance(container, tuple): 303 | gathered = tuple(gathered) 304 | return gathered 305 | else: 306 | # if not a dict, list, tuple, expect a singular tensor 307 | assert isinstance(container, torch.Tensor) 308 | return _do_gather(container) 309 | -------------------------------------------------------------------------------- /effdet/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Models Evaluation 2 | 3 | The code in this folder has been extracted and adapted from evaluation/evaluator code at https://github.com/tensorflow/models/tree/master/research/object_detection/utils 4 | 5 | Original code is licensed Apache 2.0, 
Copyright Google Inc. 6 | https://github.com/tensorflow/models/blob/master/LICENSE 7 | -------------------------------------------------------------------------------- /effdet/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rwightman/efficientdet-pytorch/c6dff775a36cea0bf9b76c58e59f936411c5ce01/effdet/evaluation/__init__.py -------------------------------------------------------------------------------- /effdet/evaluation/fields.py: -------------------------------------------------------------------------------- 1 | 2 | class InputDataFields(object): 3 | """Names for the input tensors. 4 | Holds the standard data field names to use for identifying input tensors. This 5 | should be used by the decoder to identify keys for the returned tensor_dict 6 | containing input tensors. And it should be used by the model to identify the 7 | tensors it needs. 8 | Attributes: 9 | image: image. 10 | image_additional_channels: additional channels. 11 | key: unique key corresponding to image. 12 | filename: original filename of the dataset (without common path). 13 | gt_image_classes: image-level class labels. 14 | gt_image_confidences: image-level class confidences. 15 | gt_labeled_classes: image-level annotation that indicates the 16 | classes for which an image has been labeled. 17 | gt_boxes: coordinates of the ground truth boxes in the image. 18 | gt_classes: box-level class labels. 19 | gt_confidences: box-level class confidences. The shape should be 20 | the same as the shape of gt_classes. 21 | gt_label_types: box-level label types (e.g. explicit negative). 22 | gt_is_crowd: [DEPRECATED, use gt_group_of instead] 23 | is the groundtruth a single object or a crowd. 24 | gt_area: area of a groundtruth segment. 25 | gt_difficult: is a `difficult` object 26 | gt_group_of: is a `group_of` objects, e.g. multiple objects of the 27 | same class, forming a connected group, where instances are heavily 28 | occluding each other. 29 | gt_instance_masks: ground truth instance masks. 30 | gt_instance_boundaries: ground truth instance boundaries. 31 | gt_instance_classes: instance mask-level class labels. 32 | gt_label_weights: groundtruth label weights. 33 | gt_weights: groundtruth weight factor for bounding boxes. 34 | image_height: height of images, used to decode 35 | image_width: width of images, used to decode 36 | """ 37 | image = 'image' 38 | key = 'image_id' 39 | filename = 'filename' 40 | gt_boxes = 'bbox' 41 | gt_classes = 'cls' 42 | gt_confidences = 'confidences' 43 | gt_label_types = 'label_types' 44 | gt_image_classes = 'img_cls' 45 | gt_image_confidences = 'img_confidences' 46 | gt_labeled_classes = 'labeled_cls' 47 | gt_is_crowd = 'is_crowd' 48 | gt_area = 'area' 49 | gt_difficult = 'difficult' 50 | gt_group_of = 'group_of' 51 | gt_instance_masks = 'instance_masks' 52 | gt_instance_boundaries = 'instance_boundaries' 53 | gt_instance_classes = 'instance_classes' 54 | image_height = 'img_height' 55 | image_width = 'img_width' 56 | image_size = 'img_size' 57 | 58 | 59 | class DetectionResultFields(object): 60 | """Naming conventions for storing the output of the detector. 61 | Attributes: 62 | source_id: source of the original image. 63 | key: unique key corresponding to image. 64 | detection_boxes: coordinates of the detection boxes in the image. 65 | detection_scores: detection scores for the detection boxes in the image. 
66 | detection_multiclass_scores: class score distribution (including background) 67 | for detection boxes in the image including background class. 68 | detection_classes: detection-level class labels. 69 | detection_masks: contains a segmentation mask for each detection box. 70 | """ 71 | 72 | key = 'image_id' 73 | detection_boxes = 'bbox' 74 | detection_scores = 'score' 75 | detection_classes = 'cls' 76 | detection_masks = 'masks' 77 | 78 | 79 | class BoxListFields(object): 80 | """Naming conventions for BoxLists. 81 | Attributes: 82 | boxes: bounding box coordinates. 83 | classes: classes per bounding box. 84 | scores: scores per bounding box. 85 | weights: sample weights per bounding box. 86 | objectness: objectness score per bounding box. 87 | masks: masks per bounding box. 88 | boundaries: boundaries per bounding box. 89 | keypoints: keypoints per bounding box. 90 | keypoint_heatmaps: keypoint heatmaps per bounding box. 91 | is_crowd: is_crowd annotation per bounding box. 92 | """ 93 | boxes = 'boxes' 94 | classes = 'classes' 95 | scores = 'scores' 96 | weights = 'weights' 97 | confidences = 'confidences' 98 | objectness = 'objectness' 99 | masks = 'masks' 100 | boundaries = 'boundaries' 101 | keypoints = 'keypoints' 102 | keypoint_visibilities = 'keypoint_visibilities' 103 | keypoint_heatmaps = 'keypoint_heatmaps' 104 | is_crowd = 'is_crowd' 105 | group_of = 'group_of' 106 | -------------------------------------------------------------------------------- /effdet/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_precision_recall(scores, labels, num_gt): 5 | """Compute precision and recall. 6 | Args: 7 | scores: A float numpy array representing detection score 8 | labels: A float numpy array representing weighted true/false positive labels 9 | num_gt: Number of ground truth instances 10 | Raises: 11 | ValueError: if the input is not of the correct format 12 | Returns: 13 | precision: Fraction of positive instances over detected ones. This value is 14 | None if no ground truth labels are present. 15 | recall: Fraction of detected positive instance over all positive instances. 16 | This value is None if no ground truth labels are present. 
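    Illustrative example (added for clarity; values are hypothetical, not from the original source):
        compute_precision_recall(np.array([.9, .8, .7, .6]), np.array([1., 0., 1., 0.]), num_gt=3)
        returns precision ~ [1.0, 0.5, 0.67, 0.5] and recall ~ [0.33, 0.33, 0.67, 0.67],
        i.e. cumulative TP/(TP+FP) and cumulative TP/num_gt after sorting by score.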
17 | """ 18 | if not isinstance(labels, np.ndarray) or len(labels.shape) != 1: 19 | raise ValueError("labels must be single dimension numpy array") 20 | 21 | if labels.dtype != np.float_ and labels.dtype != np.bool_: 22 | raise ValueError("labels type must be either bool or float") 23 | 24 | if not isinstance(scores, np.ndarray) or len(scores.shape) != 1: 25 | raise ValueError("scores must be single dimension numpy array") 26 | 27 | if num_gt < np.sum(labels): 28 | raise ValueError("Number of true positives must be smaller than num_gt.") 29 | 30 | if len(scores) != len(labels): 31 | raise ValueError("scores and labels must be of the same size.") 32 | 33 | if num_gt == 0: 34 | return None, None 35 | 36 | sorted_indices = np.argsort(scores) 37 | sorted_indices = sorted_indices[::-1] 38 | true_positive_labels = labels[sorted_indices] 39 | false_positive_labels = (true_positive_labels <= 0).astype(float) 40 | cum_true_positives = np.cumsum(true_positive_labels) 41 | cum_false_positives = np.cumsum(false_positive_labels) 42 | precision = cum_true_positives.astype(float) / (cum_true_positives + cum_false_positives) 43 | recall = cum_true_positives.astype(float) / num_gt 44 | return precision, recall 45 | 46 | 47 | def compute_average_precision(precision, recall): 48 | """Compute Average Precision according to the definition in VOCdevkit. 49 | Precision is modified to ensure that it does not decrease as recall 50 | decrease. 51 | Args: 52 | precision: A float [N, 1] numpy array of precisions 53 | recall: A float [N, 1] numpy array of recalls 54 | Raises: 55 | ValueError: if the input is not of the correct format 56 | Returns: 57 | average_precison: The area under the precision recall curve. NaN if 58 | precision and recall are None. 59 | """ 60 | if precision is None: 61 | if recall is not None: 62 | raise ValueError("If precision is None, recall must also be None") 63 | return np.NAN 64 | 65 | if not isinstance(precision, np.ndarray) or not isinstance(recall, np.ndarray): 66 | raise ValueError("precision and recall must be numpy array") 67 | if precision.dtype != np.float_ or recall.dtype != np.float_: 68 | raise ValueError("input must be float numpy array.") 69 | if len(precision) != len(recall): 70 | raise ValueError("precision and recall must be of the same size.") 71 | if not precision.size: 72 | return 0.0 73 | if np.amin(precision) < 0 or np.amax(precision) > 1: 74 | raise ValueError("Precision must be in the range of [0, 1].") 75 | if np.amin(recall) < 0 or np.amax(recall) > 1: 76 | raise ValueError("recall must be in the range of [0, 1].") 77 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 78 | raise ValueError("recall must be a non-decreasing array") 79 | 80 | recall = np.concatenate([[0], recall, [1]]) 81 | precision = np.concatenate([[0], precision, [0]]) 82 | 83 | # Preprocess precision to be a non-decreasing array 84 | for i in range(len(precision) - 2, -1, -1): 85 | precision[i] = np.maximum(precision[i], precision[i + 1]) 86 | 87 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 88 | average_precision = np.sum((recall[indices] - recall[indices - 1]) * precision[indices]) 89 | return average_precision 90 | 91 | 92 | def compute_cor_loc(num_gt_imgs_per_class, num_images_correctly_detected_per_class): 93 | """Compute CorLoc according to the definition in the following paper. 94 | https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf 95 | Returns nans if there are no ground truth images for a class. 
96 | Args: 97 | num_gt_imgs_per_class: 1D array, representing number of images containing 98 | at least one object instance of a particular class 99 | num_images_correctly_detected_per_class: 1D array, representing number of 100 | images that are correctly detected at least one object instance of a particular class 101 | Returns: 102 | corloc_per_class: A float numpy array represents the corloc score of each class 103 | """ 104 | return np.where( 105 | num_gt_imgs_per_class == 0, np.nan, 106 | num_images_correctly_detected_per_class / num_gt_imgs_per_class) 107 | 108 | 109 | def compute_median_rank_at_k(tp_fp_list, k): 110 | """Computes MedianRank@k, where k is the top-scoring labels. 111 | Args: 112 | tp_fp_list: a list of numpy arrays; each numpy array corresponds to the all 113 | detection on a single image, where the detections are sorted by score in 114 | descending order. Further, each numpy array element can have boolean or 115 | float values. True positive elements have either value >0.0 or True; 116 | any other value is considered false positive. 117 | k: number of top-scoring proposals to take. 118 | Returns: 119 | median_rank: median rank of all true positive proposals among top k by score. 120 | """ 121 | ranks = [] 122 | for i in range(len(tp_fp_list)): 123 | ranks.append(np.where(tp_fp_list[i][0:min(k, tp_fp_list[i].shape[0])] > 0)[0]) 124 | concatenated_ranks = np.concatenate(ranks) 125 | return np.median(concatenated_ranks) 126 | 127 | 128 | def compute_recall_at_k(tp_fp_list, num_gt, k): 129 | """Computes Recall@k, MedianRank@k, where k is the top-scoring labels. 130 | Args: 131 | tp_fp_list: a list of numpy arrays; each numpy array corresponds to the all 132 | detection on a single image, where the detections are sorted by score in 133 | descending order. Further, each numpy array element can have boolean or 134 | float values. True positive elements have either value >0.0 or True; 135 | any other value is considered false positive. 136 | num_gt: number of groundtruth anotations. 137 | k: number of top-scoring proposals to take. 138 | Returns: 139 | recall: recall evaluated on the top k by score detections. 
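    Illustrative example (added for clarity; values are hypothetical):
        with tp_fp_list = [np.array([1., 0., 1.]), np.array([0., 1.])], num_gt=4 and k=2,
        only the top-2 detections of each image are kept ([1., 0.] and [0., 1.]),
        so recall = (1 + 1) / 4 = 0.5.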
140 | """ 141 | 142 | tp_fp_eval = [] 143 | for i in range(len(tp_fp_list)): 144 | tp_fp_eval.append(tp_fp_list[i][0:min(k, tp_fp_list[i].shape[0])]) 145 | 146 | tp_fp_eval = np.concatenate(tp_fp_eval) 147 | 148 | return np.sum(tp_fp_eval) / num_gt 149 | -------------------------------------------------------------------------------- /effdet/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import abc 3 | import json 4 | import logging 5 | import time 6 | from tempfile import NamedTemporaryFile 7 | 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | 12 | from pycocotools.cocoeval import COCOeval 13 | from .distributed import synchronize, is_main_process, all_gather_container 14 | 15 | # FIXME experimenting with speedups for OpenImages eval, it's slow 16 | #import pyximport; py_importer, pyx_importer = pyximport.install(pyimport=True) 17 | import effdet.evaluation.detection_evaluator as tfm_eval 18 | #pyximport.uninstall(py_importer, pyx_importer) 19 | 20 | _logger = logging.getLogger(__name__) 21 | 22 | 23 | __all__ = ['CocoEvaluator', 'PascalEvaluator', 'OpenImagesEvaluator', 'create_evaluator'] 24 | 25 | 26 | class Evaluator: 27 | 28 | def __init__(self, distributed=False, pred_yxyx=False): 29 | self.distributed = distributed 30 | self.distributed_device = None 31 | self.pred_yxyx = pred_yxyx 32 | self.img_indices = [] 33 | self.predictions = [] 34 | 35 | def add_predictions(self, detections, target): 36 | if self.distributed: 37 | if self.distributed_device is None: 38 | # cache for use later to broadcast end metric 39 | self.distributed_device = detections.device 40 | synchronize() 41 | detections = all_gather_container(detections) 42 | img_indices = all_gather_container(target['img_idx']) 43 | if not is_main_process(): 44 | return 45 | else: 46 | img_indices = target['img_idx'] 47 | 48 | detections = detections.cpu().numpy() 49 | img_indices = img_indices.cpu().numpy() 50 | for img_idx, img_dets in zip(img_indices, detections): 51 | self.img_indices.append(img_idx) 52 | self.predictions.append(img_dets) 53 | 54 | def _coco_predictions(self): 55 | # generate coco-style predictions 56 | coco_predictions = [] 57 | coco_ids = [] 58 | for img_idx, img_dets in zip(self.img_indices, self.predictions): 59 | img_id = self._dataset.img_ids[img_idx] 60 | coco_ids.append(img_id) 61 | if self.pred_yxyx: 62 | # to xyxy 63 | img_dets[:, 0:4] = img_dets[:, [1, 0, 3, 2]] 64 | # to xywh 65 | img_dets[:, 2] -= img_dets[:, 0] 66 | img_dets[:, 3] -= img_dets[:, 1] 67 | for det in img_dets: 68 | score = float(det[4]) 69 | if score < .001: # stop when below this threshold, scores in descending order 70 | break 71 | coco_det = dict( 72 | image_id=int(img_id), 73 | bbox=det[0:4].tolist(), 74 | score=score, 75 | category_id=int(det[5])) 76 | coco_predictions.append(coco_det) 77 | return coco_predictions, coco_ids 78 | 79 | @abc.abstractmethod 80 | def evaluate(self, output_result_file=''): 81 | pass 82 | 83 | def save(self, result_file): 84 | # save results in coco style, override to save in a alternate form 85 | if not self.distributed or dist.get_rank() == 0: 86 | assert len(self.predictions) 87 | coco_predictions, coco_ids = self._coco_predictions() 88 | json.dump(coco_predictions, open(result_file, 'w'), indent=4) 89 | 90 | 91 | class CocoEvaluator(Evaluator): 92 | 93 | def __init__(self, dataset, distributed=False, pred_yxyx=False): 94 | super().__init__(distributed=distributed, pred_yxyx=pred_yxyx) 95 | 
self._dataset = dataset.parser 96 | self.coco_api = dataset.parser.coco 97 | 98 | def reset(self): 99 | self.img_indices = [] 100 | self.predictions = [] 101 | 102 | def evaluate(self, output_result_file=''): 103 | if not self.distributed or dist.get_rank() == 0: 104 | assert len(self.predictions) 105 | coco_predictions, coco_ids = self._coco_predictions() 106 | if output_result_file: 107 | json.dump(coco_predictions, open(output_result_file, 'w'), indent=4) 108 | results = self.coco_api.loadRes(output_result_file) 109 | else: 110 | with NamedTemporaryFile(prefix='coco_', suffix='.json', delete=False, mode='w') as tmpfile: 111 | json.dump(coco_predictions, tmpfile, indent=4) 112 | results = self.coco_api.loadRes(tmpfile.name) 113 | try: 114 | os.unlink(tmpfile.name) 115 | except OSError: 116 | pass 117 | coco_eval = COCOeval(self.coco_api, results, 'bbox') 118 | coco_eval.params.imgIds = coco_ids # score only ids we've used 119 | coco_eval.evaluate() 120 | coco_eval.accumulate() 121 | coco_eval.summarize() 122 | metric = coco_eval.stats[0] # mAP 0.5-0.95 123 | if self.distributed: 124 | dist.broadcast(torch.tensor(metric, device=self.distributed_device), 0) 125 | else: 126 | metric = torch.tensor(0, device=self.distributed_device) 127 | dist.broadcast(metric, 0) 128 | metric = metric.item() 129 | self.reset() 130 | return metric 131 | 132 | 133 | class TfmEvaluator(Evaluator): 134 | """ Tensorflow Models Evaluator Wrapper """ 135 | def __init__( 136 | self, dataset, distributed=False, pred_yxyx=False, evaluator_cls=tfm_eval.ObjectDetectionEvaluator): 137 | super().__init__(distributed=distributed, pred_yxyx=pred_yxyx) 138 | self._evaluator = evaluator_cls(categories=dataset.parser.cat_dicts) 139 | self._eval_metric_name = self._evaluator._metric_names[0] 140 | self._dataset = dataset.parser 141 | 142 | def reset(self): 143 | self._evaluator.clear() 144 | self.img_indices = [] 145 | self.predictions = [] 146 | 147 | def evaluate(self, output_result_file=''): 148 | if not self.distributed or dist.get_rank() == 0: 149 | for img_idx, img_dets in zip(self.img_indices, self.predictions): 150 | gt = self._dataset.get_ann_info(img_idx) 151 | self._evaluator.add_single_ground_truth_image_info(img_idx, gt) 152 | 153 | bbox = img_dets[:, 0:4] if self.pred_yxyx else img_dets[:, [1, 0, 3, 2]] 154 | det = dict(bbox=bbox, score=img_dets[:, 4], cls=img_dets[:, 5]) 155 | self._evaluator.add_single_detected_image_info(img_idx, det) 156 | 157 | metrics = self._evaluator.evaluate() 158 | _logger.info('Metrics:') 159 | for k, v in metrics.items(): 160 | _logger.info(f'{k}: {v}') 161 | map_metric = metrics[self._eval_metric_name] 162 | if self.distributed: 163 | dist.broadcast(torch.tensor(map_metric, device=self.distributed_device), 0) 164 | else: 165 | map_metric = torch.tensor(0, device=self.distributed_device) 166 | wait = dist.broadcast(map_metric, 0, async_op=True) 167 | while not wait.is_completed(): 168 | # wait without spinning the cpu @ 100%, no need for low latency here 169 | time.sleep(0.5) 170 | map_metric = map_metric.item() 171 | if output_result_file: 172 | self.save(output_result_file) 173 | self.reset() 174 | return map_metric 175 | 176 | 177 | class PascalEvaluator(TfmEvaluator): 178 | 179 | def __init__(self, dataset, distributed=False, pred_yxyx=False): 180 | super().__init__( 181 | dataset, distributed=distributed, pred_yxyx=pred_yxyx, evaluator_cls=tfm_eval.PascalDetectionEvaluator) 182 | 183 | 184 | class OpenImagesEvaluator(TfmEvaluator): 185 | 186 | def __init__(self, dataset, 
distributed=False, pred_yxyx=False): 187 | super().__init__( 188 | dataset, distributed=distributed, pred_yxyx=pred_yxyx, evaluator_cls=tfm_eval.OpenImagesDetectionEvaluator) 189 | 190 | 191 | def create_evaluator(name, dataset, distributed=False, pred_yxyx=False): 192 | # FIXME support OpenImages Challenge2019 metric w/ image level label consideration 193 | if 'coco' in name: 194 | return CocoEvaluator(dataset, distributed=distributed, pred_yxyx=pred_yxyx) 195 | elif 'openimages' in name: 196 | return OpenImagesEvaluator(dataset, distributed=distributed, pred_yxyx=pred_yxyx) 197 | else: 198 | return PascalEvaluator(dataset, distributed=distributed, pred_yxyx=pred_yxyx) 199 | -------------------------------------------------------------------------------- /effdet/factory.py: -------------------------------------------------------------------------------- 1 | from .efficientdet import EfficientDet, HeadNet 2 | from .bench import DetBenchTrain, DetBenchPredict 3 | from .config import get_efficientdet_config 4 | from .helpers import load_pretrained, load_checkpoint 5 | 6 | 7 | def create_model( 8 | model_name, bench_task='', num_classes=None, pretrained=False, 9 | checkpoint_path='', checkpoint_ema=False, **kwargs): 10 | 11 | config = get_efficientdet_config(model_name) 12 | return create_model_from_config( 13 | config, bench_task=bench_task, num_classes=num_classes, pretrained=pretrained, 14 | checkpoint_path=checkpoint_path, checkpoint_ema=checkpoint_ema, **kwargs) 15 | 16 | 17 | def create_model_from_config( 18 | config, bench_task='', num_classes=None, pretrained=False, 19 | checkpoint_path='', checkpoint_ema=False, **kwargs): 20 | 21 | pretrained_backbone = kwargs.pop('pretrained_backbone', True) 22 | if pretrained or checkpoint_path: 23 | pretrained_backbone = False # no point in loading backbone weights 24 | 25 | # Config overrides, override some config values via kwargs. 
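    # Illustrative usage (example only, not part of the original file; argument values are hypothetical):
    #   model = create_model('tf_efficientdet_d0', bench_task='train',
    #                        num_classes=20, pretrained=True, bench_labeler=True)
    # Any kwarg whose name appears in the `overrides` tuple below (e.g. max_det_per_image)
    # replaces the corresponding attribute on the config before the model is built.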
26 | overrides = ( 27 | 'redundant_bias', 'label_smoothing', 'legacy_focal', 'jit_loss', 'soft_nms', 'max_det_per_image', 'image_size') 28 | for ov in overrides: 29 | value = kwargs.pop(ov, None) 30 | if value is not None: 31 | setattr(config, ov, value) 32 | 33 | labeler = kwargs.pop('bench_labeler', False) 34 | 35 | # create the base model 36 | model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs) 37 | 38 | # pretrained weights are always spec'd for original config, load them before we change the model 39 | if pretrained: 40 | load_pretrained(model, config.url) 41 | 42 | # reset model head if num_classes doesn't match configs 43 | if num_classes is not None and num_classes != config.num_classes: 44 | model.reset_head(num_classes=num_classes) 45 | 46 | # load an argument specified training checkpoint 47 | if checkpoint_path: 48 | load_checkpoint(model, checkpoint_path, use_ema=checkpoint_ema) 49 | 50 | # wrap model in task specific training/prediction bench if set 51 | if bench_task == 'train': 52 | model = DetBenchTrain(model, create_labeler=labeler) 53 | elif bench_task == 'predict': 54 | model = DetBenchPredict(model) 55 | return model 56 | -------------------------------------------------------------------------------- /effdet/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | from collections import OrderedDict 5 | 6 | from timm.models import load_checkpoint 7 | 8 | try: 9 | from torch.hub import load_state_dict_from_url 10 | except ImportError: 11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | 13 | 14 | def load_pretrained(model, url, filter_fn=None, strict=True): 15 | if not url: 16 | logging.warning("Pretrained model URL is empty, using random initialization. " 17 | "Did you intend to use a `tf_` variant of the model?") 18 | return 19 | state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') 20 | if filter_fn is not None: 21 | state_dict = filter_fn(state_dict) 22 | model.load_state_dict(state_dict, strict=strict) 23 | -------------------------------------------------------------------------------- /effdet/loss.py: -------------------------------------------------------------------------------- 1 | """ EfficientDet Focal, Huber/Smooth L1 loss fns w/ jit support 2 | 3 | Based on loss fn in Google's automl EfficientDet repository (Apache 2.0 license). 4 | https://github.com/google/automl/tree/master/efficientdet 5 | 6 | Copyright 2020 Ross Wightman 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from typing import Optional, List, Tuple 13 | 14 | 15 | def focal_loss_legacy(logits, targets, alpha: float, gamma: float, normalizer): 16 | """Compute the focal loss between `logits` and the golden `target` values. 17 | 18 | 'Legacy focal loss matches the loss used in the official Tensorflow impl for initial 19 | model releases and some time after that. It eventually transitioned to the 'New' loss 20 | defined below. 21 | 22 | Focal loss = -(1-pt)^gamma * log(pt) 23 | where pt is the probability of being classified to the true class. 24 | 25 | Args: 26 | logits: A float32 tensor of size [batch, height_in, width_in, num_predictions]. 27 | 28 | targets: A float32 tensor of size [batch, height_in, width_in, num_predictions]. 29 | 30 | alpha: A float32 scalar multiplying alpha to the loss from positive examples 31 | and (1-alpha) to the loss from negative examples. 
32 | 33 | gamma: A float32 scalar modulating loss from hard and easy examples. 34 | 35 | normalizer: A float32 scalar normalizes the total loss from all examples. 36 | 37 | Returns: 38 | loss: A float32 scalar representing normalized total loss. 39 | """ 40 | positive_label_mask = targets == 1.0 41 | cross_entropy = F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype), reduction='none') 42 | neg_logits = -1.0 * logits 43 | modulator = torch.exp(gamma * targets * neg_logits - gamma * torch.log1p(torch.exp(neg_logits))) 44 | 45 | loss = modulator * cross_entropy 46 | weighted_loss = torch.where(positive_label_mask, alpha * loss, (1.0 - alpha) * loss) 47 | return weighted_loss / normalizer 48 | 49 | 50 | def new_focal_loss(logits, targets, alpha: float, gamma: float, normalizer, label_smoothing: float = 0.01): 51 | """Compute the focal loss between `logits` and the golden `target` values. 52 | 53 | 'New' is not the best descriptor, but this focal loss impl matches recent versions of 54 | the official Tensorflow impl of EfficientDet. It has support for label smoothing, however 55 | it is a bit slower, doesn't jit optimize well, and uses more memory. 56 | 57 | Focal loss = -(1-pt)^gamma * log(pt) 58 | where pt is the probability of being classified to the true class. 59 | Args: 60 | logits: A float32 tensor of size [batch, height_in, width_in, num_predictions]. 61 | targets: A float32 tensor of size [batch, height_in, width_in, num_predictions]. 62 | alpha: A float32 scalar multiplying alpha to the loss from positive examples 63 | and (1-alpha) to the loss from negative examples. 64 | gamma: A float32 scalar modulating loss from hard and easy examples. 65 | normalizer: Divide loss by this value. 66 | label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. 67 | Returns: 68 | loss: A float32 scalar representing normalized total loss. 69 | """ 70 | # compute focal loss multipliers before label smoothing, such that it will not blow up the loss. 71 | pred_prob = logits.sigmoid() 72 | targets = targets.to(logits.dtype) 73 | onem_targets = 1. - targets 74 | p_t = (targets * pred_prob) + (onem_targets * (1. - pred_prob)) 75 | alpha_factor = targets * alpha + onem_targets * (1. - alpha) 76 | modulating_factor = (1. - p_t) ** gamma 77 | 78 | # apply label smoothing for cross_entropy for each entry. 79 | if label_smoothing > 0.: 80 | targets = targets * (1. - label_smoothing) + .5 * label_smoothing 81 | ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none') 82 | 83 | # compute the final loss and return 84 | return (1 / normalizer) * alpha_factor * modulating_factor * ce 85 | 86 | 87 | def huber_loss( 88 | input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True): 89 | """ 90 | """ 91 | err = input - target 92 | abs_err = err.abs() 93 | quadratic = torch.clamp(abs_err, max=delta) 94 | linear = abs_err - quadratic 95 | loss = 0.5 * quadratic.pow(2) + delta * linear 96 | if weights is not None: 97 | loss = loss.mul(weights) 98 | if size_average: 99 | return loss.mean() 100 | else: 101 | return loss.sum() 102 | 103 | 104 | def smooth_l1_loss( 105 | input, target, beta: float = 1. 
/ 9, weights: Optional[torch.Tensor] = None, size_average: bool = True): 106 | """ 107 | very similar to the smooth_l1_loss from pytorch, but with the extra beta parameter 108 | """ 109 | if beta < 1e-5: 110 | # if beta == 0, then torch.where will result in nan gradients when 111 | # the chain rule is applied due to pytorch implementation details 112 | # (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of 113 | # zeros, rather than "no gradient"). To avoid this issue, we define 114 | # small values of beta to be exactly l1 loss. 115 | loss = torch.abs(input - target) 116 | else: 117 | err = torch.abs(input - target) 118 | loss = torch.where(err < beta, 0.5 * err.pow(2) / beta, err - 0.5 * beta) 119 | if weights is not None: 120 | loss *= weights 121 | if size_average: 122 | return loss.mean() 123 | else: 124 | return loss.sum() 125 | 126 | 127 | def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1): 128 | """Computes box regression loss.""" 129 | # delta is typically around the mean value of regression target. 130 | # for instances, the regression targets of 512x512 input with 6 anchors on 131 | # P3-P7 pyramid is about [0.1, 0.1, 0.2, 0.2]. 132 | normalizer = num_positives * 4.0 133 | mask = box_targets != 0.0 134 | box_loss = huber_loss(box_outputs, box_targets, weights=mask, delta=delta, size_average=False) 135 | return box_loss / normalizer 136 | 137 | 138 | def one_hot(x, num_classes: int): 139 | # NOTE: PyTorch one-hot does not handle -ve entries (no hot) like Tensorflow, so mask them out 140 | x_non_neg = (x >= 0).unsqueeze(-1) 141 | onehot = torch.zeros(x.shape + (num_classes,), device=x.device, dtype=torch.float32) 142 | return onehot.scatter(-1, x.unsqueeze(-1) * x_non_neg, 1) * x_non_neg 143 | 144 | 145 | def loss_fn( 146 | cls_outputs: List[torch.Tensor], 147 | box_outputs: List[torch.Tensor], 148 | cls_targets: List[torch.Tensor], 149 | box_targets: List[torch.Tensor], 150 | num_positives: torch.Tensor, 151 | num_classes: int, 152 | alpha: float, 153 | gamma: float, 154 | delta: float, 155 | box_loss_weight: float, 156 | label_smoothing: float = 0., 157 | legacy_focal: bool = False, 158 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 159 | """Computes total detection loss. 160 | Computes total detection loss including box and class loss from all levels. 161 | Args: 162 | cls_outputs: a List with values representing logits in [batch_size, height, width, num_anchors]. 163 | at each feature level (index) 164 | 165 | box_outputs: a List with values representing box regression targets in 166 | [batch_size, height, width, num_anchors * 4] at each feature level (index) 167 | 168 | cls_targets: groundtruth class targets. 169 | 170 | box_targets: groundtrusth box targets. 171 | 172 | num_positives: num positive grountruth anchors 173 | 174 | Returns: 175 | total_loss: an integer tensor representing total loss reducing from class and box losses from all levels. 176 | 177 | cls_loss: an integer tensor representing total class loss. 178 | 179 | box_loss: an integer tensor representing total box regression loss. 
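    Shape sketch (illustrative, inferred from the permute/view calls in the implementation
    below rather than stated in the original docs): with e.g. 9 anchors per cell and 90
    classes, cls_outputs[l] is [batch, 9 * 90, H_l, W_l], box_outputs[l] is
    [batch, 9 * 4, H_l, W_l], cls_targets[l] is [batch, H_l, W_l, 9] (integer class ids,
    negative for anchors that are not positive matches) and box_targets[l] is
    [batch, H_l, W_l, 9 * 4].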
180 | """ 181 | # Sum all positives in a batch for normalization and avoid zero 182 | # num_positives_sum, which would lead to inf loss during training 183 | num_positives_sum = (num_positives.sum() + 1.0).float() 184 | levels = len(cls_outputs) 185 | 186 | cls_losses = [] 187 | box_losses = [] 188 | for l in range(levels): 189 | cls_targets_at_level = cls_targets[l] 190 | box_targets_at_level = box_targets[l] 191 | 192 | # Onehot encoding for classification labels. 193 | cls_targets_at_level_oh = one_hot(cls_targets_at_level, num_classes) 194 | 195 | bs, height, width, _, _ = cls_targets_at_level_oh.shape 196 | cls_targets_at_level_oh = cls_targets_at_level_oh.view(bs, height, width, -1) 197 | cls_outputs_at_level = cls_outputs[l].permute(0, 2, 3, 1).float() 198 | if legacy_focal: 199 | cls_loss = focal_loss_legacy( 200 | cls_outputs_at_level, cls_targets_at_level_oh, 201 | alpha=alpha, gamma=gamma, normalizer=num_positives_sum) 202 | else: 203 | cls_loss = new_focal_loss( 204 | cls_outputs_at_level, cls_targets_at_level_oh, 205 | alpha=alpha, gamma=gamma, normalizer=num_positives_sum, label_smoothing=label_smoothing) 206 | cls_loss = cls_loss.view(bs, height, width, -1, num_classes) 207 | cls_loss = cls_loss * (cls_targets_at_level != -2).unsqueeze(-1) 208 | cls_losses.append(cls_loss.sum()) # FIXME reference code added a clamp here at some point ...clamp(0, 2)) 209 | 210 | box_losses.append(_box_loss( 211 | box_outputs[l].permute(0, 2, 3, 1).float(), 212 | box_targets_at_level, 213 | num_positives_sum, 214 | delta=delta)) 215 | 216 | # Sum per level losses to total loss. 217 | cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1) 218 | box_loss = torch.sum(torch.stack(box_losses, dim=-1), dim=-1) 219 | total_loss = cls_loss + box_loss_weight * box_loss 220 | return total_loss, cls_loss, box_loss 221 | 222 | 223 | loss_jit = torch.jit.script(loss_fn) 224 | 225 | 226 | class DetectionLoss(nn.Module): 227 | 228 | __constants__ = ['num_classes'] 229 | 230 | def __init__(self, config): 231 | super(DetectionLoss, self).__init__() 232 | self.config = config 233 | self.num_classes = config.num_classes 234 | self.alpha = config.alpha 235 | self.gamma = config.gamma 236 | self.delta = config.delta 237 | self.box_loss_weight = config.box_loss_weight 238 | self.label_smoothing = config.label_smoothing 239 | self.legacy_focal = config.legacy_focal 240 | self.use_jit = config.jit_loss 241 | 242 | def forward( 243 | self, 244 | cls_outputs: List[torch.Tensor], 245 | box_outputs: List[torch.Tensor], 246 | cls_targets: List[torch.Tensor], 247 | box_targets: List[torch.Tensor], 248 | num_positives: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 249 | 250 | l_fn = loss_fn 251 | if not torch.jit.is_scripting() and self.use_jit: 252 | # This branch only active if parent / bench itself isn't being scripted 253 | # NOTE: I haven't figured out what to do here wrt to tracing, is it an issue? 
254 | l_fn = loss_jit 255 | 256 | return l_fn( 257 | cls_outputs, box_outputs, cls_targets, box_targets, num_positives, 258 | num_classes=self.num_classes, alpha=self.alpha, gamma=self.gamma, delta=self.delta, 259 | box_loss_weight=self.box_loss_weight, label_smoothing=self.label_smoothing, legacy_focal=self.legacy_focal) 260 | -------------------------------------------------------------------------------- /effdet/object_detection/README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Object Detection 2 | 3 | All of this code is adapted/ported/copied from https://github.com/google/automl/tree/552d0facd14f4fe9205a67fb13ecb5690a4d1c94/efficientdet/object_detection -------------------------------------------------------------------------------- /effdet/object_detection/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # Object detection data loaders and libraries are mostly based on RetinaNet: 16 | # https://github.com/tensorflow/tpu/tree/master/models/official/retinanet 17 | from .argmax_matcher import ArgMaxMatcher 18 | from .box_coder import FasterRcnnBoxCoder 19 | from .box_list import BoxList 20 | from .matcher import Match 21 | from .region_similarity_calculator import IouSimilarity 22 | from .target_assigner import TargetAssigner 23 | -------------------------------------------------------------------------------- /effdet/object_detection/argmax_matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Argmax matcher implementation. 16 | 17 | This class takes a similarity matrix and matches columns to rows based on the 18 | maximum value per column. One can specify matched_thresholds and 19 | to prevent columns from matching to rows (generally resulting in a negative 20 | training example) and unmatched_theshold to ignore the match (generally 21 | resulting in neither a positive or negative training example). 22 | 23 | This matcher is used in Fast(er)-RCNN. 
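Illustrative example (hypothetical thresholds, added for clarity): with
matched_threshold=0.5, unmatched_threshold=0.4 and negatives_lower_than_unmatched=True,
a column whose best similarity is 0.7 is matched to its argmax row, one at 0.45 is
ignored (match value -2), and one at 0.3 becomes a negative (match value -1).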
24 | 25 | Note: matchers are used in TargetAssigners. There is a create_target_assigner 26 | factory function for popular implementations. 27 | """ 28 | import torch 29 | from .matcher import Match 30 | from typing import Optional 31 | 32 | 33 | def one_hot_bool(x, num_classes: int): 34 | # for improved perf over PyTorch builtin one_hot, scatter to bool 35 | onehot = torch.zeros(x.size(0), num_classes, device=x.device, dtype=torch.bool) 36 | return onehot.scatter_(1, x.unsqueeze(1), 1) 37 | 38 | 39 | @torch.jit.script 40 | class ArgMaxMatcher(object): # cannot inherit with torchscript 41 | """Matcher based on highest value. 42 | 43 | This class computes matches from a similarity matrix. Each column is matched 44 | to a single row. 45 | 46 | To support object detection target assignment this class enables setting both 47 | matched_threshold (upper threshold) and unmatched_threshold (lower thresholds) 48 | defining three categories of similarity which define whether examples are 49 | positive, negative, or ignored: 50 | (1) similarity >= matched_threshold: Highest similarity. Matched/Positive! 51 | (2) matched_threshold > similarity >= unmatched_threshold: Medium similarity. 52 | Depending on negatives_lower_than_unmatched, this is either 53 | Unmatched/Negative OR Ignore. 54 | (3) unmatched_threshold > similarity: Lowest similarity. Depending on flag 55 | negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore. 56 | For ignored matches this class sets the values in the Match object to -2. 57 | """ 58 | 59 | def __init__(self, 60 | matched_threshold: float, 61 | unmatched_threshold: Optional[float] = None, 62 | negatives_lower_than_unmatched: bool = True, 63 | force_match_for_each_row: bool = False): 64 | """Construct ArgMaxMatcher. 65 | 66 | Args: 67 | matched_threshold: Threshold for positive matches. Positive if 68 | sim >= matched_threshold, where sim is the maximum value of the 69 | similarity matrix for a given column. Set to None for no threshold. 70 | unmatched_threshold: Threshold for negative matches. Negative if 71 | sim < unmatched_threshold. Defaults to matched_threshold 72 | when set to None. 73 | negatives_lower_than_unmatched: Boolean which defaults to True. If True 74 | then negative matches are the ones below the unmatched_threshold, 75 | whereas ignored matches are in between the matched and unmatched 76 | threshold. If False, then negative matches are in between the matched 77 | and unmatched threshold, and everything lower than unmatched is ignored. 78 | force_match_for_each_row: If True, ensures that each row is matched to 79 | at least one column (which is not guaranteed otherwise if the 80 | matched_threshold is high). Defaults to False. See 81 | argmax_matcher_test.testMatcherForceMatch() for an example. 82 | 83 | Raises: 84 | ValueError: if unmatched_threshold is set but matched_threshold is not set 85 | or if unmatched_threshold > matched_threshold. 86 | """ 87 | if (matched_threshold is None) and (unmatched_threshold is not None): 88 | raise ValueError('Need to also define matched_threshold when unmatched_threshold is defined') 89 | self._matched_threshold = matched_threshold 90 | self._unmatched_threshold: float = 0. 
91 | if unmatched_threshold is None: 92 | self._unmatched_threshold = matched_threshold 93 | else: 94 | if unmatched_threshold > matched_threshold: 95 | raise ValueError('unmatched_threshold needs to be smaller or equal to matched_threshold') 96 | self._unmatched_threshold = unmatched_threshold 97 | if not negatives_lower_than_unmatched: 98 | if self._unmatched_threshold == self._matched_threshold: 99 | raise ValueError('When negatives are in between matched and unmatched thresholds, these ' 100 | 'cannot be of equal value. matched: %s, unmatched: %s', 101 | self._matched_threshold, self._unmatched_threshold) 102 | self._force_match_for_each_row = force_match_for_each_row 103 | self._negatives_lower_than_unmatched = negatives_lower_than_unmatched 104 | 105 | def _match_when_rows_are_empty(self, similarity_matrix): 106 | """Performs matching when the rows of similarity matrix are empty. 107 | 108 | When the rows are empty, all detections are false positives. So we return 109 | a tensor of -1's to indicate that the columns do not match to any rows. 110 | 111 | Returns: 112 | matches: int32 tensor indicating the row each column matches to. 113 | """ 114 | return -1 * torch.ones(similarity_matrix.shape[1], dtype=torch.long, device=similarity_matrix.device) 115 | 116 | def _match_when_rows_are_non_empty(self, similarity_matrix): 117 | """Performs matching when the rows of similarity matrix are non empty. 118 | 119 | Returns: 120 | matches: int32 tensor indicating the row each column matches to. 121 | """ 122 | # Matches for each column 123 | matched_vals, matches = torch.max(similarity_matrix, 0) 124 | 125 | # Deal with matched and unmatched threshold 126 | if self._matched_threshold is not None: 127 | # Get logical indices of ignored and unmatched columns as tf.int64 128 | below_unmatched_threshold = self._unmatched_threshold > matched_vals 129 | between_thresholds = (matched_vals >= self._unmatched_threshold) & \ 130 | (self._matched_threshold > matched_vals) 131 | 132 | if self._negatives_lower_than_unmatched: 133 | matches = self._set_values_using_indicator(matches, below_unmatched_threshold, -1) 134 | matches = self._set_values_using_indicator(matches, between_thresholds, -2) 135 | else: 136 | matches = self._set_values_using_indicator(matches, below_unmatched_threshold, -2) 137 | matches = self._set_values_using_indicator(matches, between_thresholds, -1) 138 | 139 | if self._force_match_for_each_row: 140 | force_match_column_ids = torch.argmax(similarity_matrix, 1) 141 | force_match_column_indicators = one_hot_bool(force_match_column_ids, similarity_matrix.shape[1]) 142 | force_match_column_mask, force_match_row_ids = torch.max(force_match_column_indicators, 0) 143 | final_matches = torch.where(force_match_column_mask, force_match_row_ids, matches) 144 | return final_matches 145 | else: 146 | return matches 147 | 148 | def match(self, similarity_matrix): 149 | """Tries to match each column of the similarity matrix to a row. 150 | 151 | Args: 152 | similarity_matrix: tensor of shape [N, M] representing any similarity metric. 153 | 154 | Returns: 155 | Match object with corresponding matches for each of M columns. 156 | """ 157 | if similarity_matrix.shape[0] == 0: 158 | return Match(self._match_when_rows_are_empty(similarity_matrix)) 159 | else: 160 | return Match(self._match_when_rows_are_non_empty(similarity_matrix)) 161 | 162 | def _set_values_using_indicator(self, x, indicator, val: int): 163 | """Set the indicated fields of x to val. 164 | 165 | Args: 166 | x: tensor. 
167 | indicator: boolean with same shape as x. 168 | val: scalar with value to set. 169 | 170 | Returns: 171 | modified tensor. 172 | """ 173 | indicator = indicator.to(dtype=x.dtype) 174 | return x * (1 - indicator) + val * indicator 175 | -------------------------------------------------------------------------------- /effdet/object_detection/box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Base box coder. 16 | 17 | Box coders convert between coordinate frames, namely image-centric 18 | (with (0,0) on the top left of image) and anchor-centric (with (0,0) being 19 | defined by a specific anchor). 20 | 21 | Users of a BoxCoder can call two methods: 22 | encode: which encodes a box with respect to a given anchor 23 | (or rather, a tensor of boxes wrt a corresponding tensor of anchors) and 24 | decode: which inverts this encoding with a decode operation. 25 | In both cases, the arguments are assumed to be in 1-1 correspondence already; 26 | it is not the job of a BoxCoder to perform matching. 27 | """ 28 | import torch 29 | from typing import List, Optional 30 | from .box_list import BoxList 31 | 32 | # Box coder types. 33 | FASTER_RCNN = 'faster_rcnn' 34 | KEYPOINT = 'keypoint' 35 | MEAN_STDDEV = 'mean_stddev' 36 | SQUARE = 'square' 37 | 38 | 39 | """Faster RCNN box coder. 40 | 41 | Faster RCNN box coder follows the coding schema described below: 42 | ty = (y - ya) / ha 43 | tx = (x - xa) / wa 44 | th = log(h / ha) 45 | tw = log(w / wa) 46 | where x, y, w, h denote the box's center coordinates, width and height 47 | respectively. Similarly, xa, ya, wa, ha denote the anchor's center 48 | coordinates, width and height. tx, ty, tw and th denote the anchor-encoded 49 | center, width and height respectively. 50 | 51 | See http://arxiv.org/abs/1506.01497 for details. 52 | """ 53 | 54 | 55 | EPS = 1e-7 56 | 57 | 58 | #@torch.jit.script 59 | class FasterRcnnBoxCoder(object): 60 | """Faster RCNN box coder.""" 61 | 62 | def __init__(self, scale_factors: Optional[List[float]] = None, eps: float = EPS): 63 | """Constructor for FasterRcnnBoxCoder. 64 | 65 | Args: 66 | scale_factors: List of 4 positive scalars to scale ty, tx, th and tw. 67 | If set to None, does not perform scaling. For Faster RCNN, 68 | the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0]. 69 | """ 70 | self._scale_factors = scale_factors 71 | if scale_factors is not None: 72 | assert len(scale_factors) == 4 73 | for scalar in scale_factors: 74 | assert scalar > 0 75 | self.eps = eps 76 | 77 | #@property 78 | def code_size(self): 79 | return 4 80 | 81 | def encode(self, boxes: BoxList, anchors: BoxList): 82 | """Encode a box collection with respect to anchor collection. 83 | 84 | Args: 85 | boxes: BoxList holding N boxes to be encoded. 
86 | anchors: BoxList of anchors. 87 | 88 | Returns: 89 | a tensor representing N anchor-encoded boxes of the format [ty, tx, th, tw]. 90 | """ 91 | # Convert anchors to the center coordinate representation. 92 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes() 93 | ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes() 94 | # Avoid NaN in division and log below. 95 | ha += self.eps 96 | wa += self.eps 97 | h += self.eps 98 | w += self.eps 99 | 100 | tx = (xcenter - xcenter_a) / wa 101 | ty = (ycenter - ycenter_a) / ha 102 | tw = torch.log(w / wa) 103 | th = torch.log(h / ha) 104 | # Scales location targets as used in paper for joint training. 105 | if self._scale_factors is not None: 106 | ty *= self._scale_factors[0] 107 | tx *= self._scale_factors[1] 108 | th *= self._scale_factors[2] 109 | tw *= self._scale_factors[3] 110 | return torch.stack([ty, tx, th, tw]).t() 111 | 112 | def decode(self, rel_codes, anchors: BoxList): 113 | """Decode relative codes to boxes. 114 | 115 | Args: 116 | rel_codes: a tensor representing N anchor-encoded boxes. 117 | anchors: BoxList of anchors. 118 | 119 | Returns: 120 | boxes: BoxList holding N bounding boxes. 121 | """ 122 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes() 123 | 124 | ty, tx, th, tw = rel_codes.t().unbind() 125 | if self._scale_factors is not None: 126 | ty /= self._scale_factors[0] 127 | tx /= self._scale_factors[1] 128 | th /= self._scale_factors[2] 129 | tw /= self._scale_factors[3] 130 | w = torch.exp(tw) * wa 131 | h = torch.exp(th) * ha 132 | ycenter = ty * ha + ycenter_a 133 | xcenter = tx * wa + xcenter_a 134 | ymin = ycenter - h / 2. 135 | xmin = xcenter - w / 2. 136 | ymax = ycenter + h / 2. 137 | xmax = xcenter + w / 2. 138 | return BoxList(torch.stack([ymin, xmin, ymax, xmax]).t()) 139 | 140 | 141 | def batch_decode(encoded_boxes, box_coder: FasterRcnnBoxCoder, anchors: BoxList): 142 | """Decode a batch of encoded boxes. 143 | 144 | This op takes a batch of encoded bounding boxes and transforms 145 | them to a batch of bounding boxes specified by their corners in 146 | the order of [y_min, x_min, y_max, x_max]. 147 | 148 | Args: 149 | encoded_boxes: a float32 tensor of shape [batch_size, num_anchors, 150 | code_size] representing the location of the objects. 151 | box_coder: a BoxCoder object. 152 | anchors: a BoxList of anchors used to encode `encoded_boxes`. 153 | 154 | Returns: 155 | decoded_boxes: a float32 tensor of shape [batch_size, num_anchors, coder_size] 156 | representing the corners of the objects in the order of [y_min, x_min, y_max, x_max]. 157 | 158 | Raises: 159 | ValueError: if batch sizes of the inputs are inconsistent, or if 160 | the number of anchors inferred from encoded_boxes and anchors are inconsistent. 161 | """ 162 | assert len(encoded_boxes.shape) == 3 163 | if encoded_boxes.shape[1] != anchors.num_boxes(): 164 | raise ValueError('The number of anchors inferred from encoded_boxes' 165 | ' and anchors are inconsistent: shape[1] of encoded_boxes' 166 | ' %s should be equal to the number of anchors: %s.' 
% 167 | (encoded_boxes.shape[1], anchors.num_boxes())) 168 | 169 | decoded_boxes = torch.stack([ 170 | box_coder.decode(boxes, anchors).boxes for boxes in encoded_boxes.unbind() 171 | ]) 172 | return decoded_boxes 173 | -------------------------------------------------------------------------------- /effdet/object_detection/box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Bounding Box List definition. 16 | 17 | BoxList represents a list of bounding boxes as tensorflow 18 | tensors, where each bounding box is represented as a row of 4 numbers, 19 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes 20 | within a given list correspond to a single image. See also 21 | box_list.py for common box related operations (such as area, iou, etc). 22 | 23 | Optionally, users can add additional related fields (such as weights). 24 | We assume the following things to be true about fields: 25 | * they correspond to boxes in the box_list along the 0th dimension 26 | * they have inferable rank at graph construction time 27 | * all dimensions except for possibly the 0th can be inferred 28 | (i.e., not None) at graph construction time. 29 | 30 | Some other notes: 31 | * Following tensorflow conventions, we use height, width ordering, 32 | and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering 33 | * Tensors are always provided as (flat) [N, 4] tensors. 34 | """ 35 | import torch 36 | from typing import Optional, List, Dict 37 | 38 | 39 | @torch.jit.script 40 | class BoxList(object): 41 | """Box collection.""" 42 | data: Dict[str, torch.Tensor] 43 | 44 | def __init__(self, boxes): 45 | """Constructs box collection. 46 | 47 | Args: 48 | boxes: a tensor of shape [N, 4] representing box corners 49 | 50 | Raises: 51 | ValueError: if invalid dimensions for bbox data or if bbox data is not in float32 format. 52 | """ 53 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4: 54 | raise ValueError('Invalid dimensions for box data.') 55 | if boxes.dtype != torch.float32: 56 | raise ValueError('Invalid tensor type: should be tf.float32') 57 | self.data = {'boxes': boxes} 58 | 59 | def num_boxes(self): 60 | """Returns number of boxes held in collection. 61 | 62 | Returns: 63 | a tensor representing the number of boxes held in the collection. 
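        Illustrative usage (example only, added for clarity):
            bl = BoxList(torch.tensor([[0., 0., 10., 10.]], dtype=torch.float32))
            bl.add_field('scores', torch.tensor([0.9]))
            bl.num_boxes()  # -> 1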
64 | """ 65 | return self.data['boxes'].shape[0] 66 | 67 | def get_all_fields(self): 68 | """Returns all fields.""" 69 | return self.data.keys() 70 | 71 | def get_extra_fields(self): 72 | """Returns all non-box fields (i.e., everything not named 'boxes').""" 73 | # return [k for k in self.data.keys() if k != 'boxes'] # FIXME torscript doesn't support comprehensions yet 74 | extra: List[str] = [] 75 | for k in self.data.keys(): 76 | if k != 'boxes': 77 | extra.append(k) 78 | return extra 79 | 80 | def add_field(self, field: str, field_data: torch.Tensor): 81 | """Add field to box list. 82 | 83 | This method can be used to add related box data such as weights/labels, etc. 84 | 85 | Args: 86 | field: a string key to access the data via `get` 87 | field_data: a tensor containing the data to store in the BoxList 88 | """ 89 | self.data[field] = field_data 90 | 91 | def has_field(self, field: str): 92 | return field in self.data 93 | 94 | #@property # FIXME for torchscript compat 95 | def boxes(self): 96 | """Convenience function for accessing box coordinates. 97 | 98 | Returns: 99 | a tensor with shape [N, 4] representing box coordinates. 100 | """ 101 | return self.get_field('boxes') 102 | 103 | #@boxes.setter # FIXME for torchscript compat 104 | def set_boxes(self, boxes): 105 | """Convenience function for setting box coordinates. 106 | 107 | Args: 108 | boxes: a tensor of shape [N, 4] representing box corners 109 | 110 | Raises: 111 | ValueError: if invalid dimensions for bbox data 112 | """ 113 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4: 114 | raise ValueError('Invalid dimensions for box data.') 115 | self.data['boxes'] = boxes 116 | 117 | def get_field(self, field: str): 118 | """Accesses a box collection and associated fields. 119 | 120 | This function returns specified field with object; if no field is specified, 121 | it returns the box coordinates. 122 | 123 | Args: 124 | field: this optional string parameter can be used to specify a related field to be accessed. 125 | 126 | Returns: 127 | a tensor representing the box collection or an associated field. 128 | 129 | Raises: 130 | ValueError: if invalid field 131 | """ 132 | if not self.has_field(field): 133 | raise ValueError(f'field {field} does not exist') 134 | return self.data[field] 135 | 136 | def set_field(self, field: str, value: torch.Tensor): 137 | """Sets the value of a field. 138 | 139 | Updates the field of a box_list with a given value. 140 | 141 | Args: 142 | field: (string) name of the field to set value. 143 | value: the value to assign to the field. 144 | 145 | Raises: 146 | ValueError: if the box_list does not have specified field. 147 | """ 148 | if not self.has_field(field): 149 | raise ValueError(f'field {field} does not exist') 150 | self.data[field] = value 151 | 152 | def get_center_coordinates_and_sizes(self): 153 | """Computes the center coordinates, height and width of the boxes. 154 | 155 | Returns: 156 | a list of 4 1-D tensors [ycenter, xcenter, height, width]. 157 | """ 158 | box_corners = self.boxes() 159 | ymin, xmin, ymax, xmax = box_corners.t().unbind() 160 | width = xmax - xmin 161 | height = ymax - ymin 162 | ycenter = ymin + height / 2. 163 | xcenter = xmin + width / 2. 164 | return [ycenter, xcenter, height, width] 165 | 166 | def transpose_coordinates(self): 167 | """Transpose the coordinate representation in a boxlist. 
168 | 169 | """ 170 | y_min, x_min, y_max, x_max = self.boxes().chunk(4, dim=1) 171 | self.set_boxes(torch.cat([x_min, y_min, x_max, y_max], 1)) 172 | 173 | def as_tensor_dict(self, fields: Optional[List[str]] = None): 174 | """Retrieves specified fields as a dictionary of tensors. 175 | 176 | Args: 177 | fields: (optional) list of fields to return in the dictionary. 178 | If None (default), all fields are returned. 179 | 180 | Returns: 181 | tensor_dict: A dictionary of tensors specified by fields. 182 | 183 | Raises: 184 | ValueError: if specified field is not contained in boxlist. 185 | """ 186 | tensor_dict = {} 187 | if fields is None: 188 | fields = self.get_all_fields() 189 | for field in fields: 190 | if not self.has_field(field): 191 | raise ValueError('boxlist must contain all specified fields') 192 | tensor_dict[field] = self.get_field(field) 193 | return tensor_dict 194 | 195 | #@property 196 | def device(self): 197 | return self.data['boxes'].device 198 | -------------------------------------------------------------------------------- /effdet/object_detection/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Matcher interface and Match class. 16 | 17 | This module defines the Matcher interface and the Match object. The job of the 18 | matcher is to match row and column indices based on the similarity matrix and 19 | other optional parameters. Each column is matched to at most one row. There 20 | are three possibilities for the matching: 21 | 22 | 1) match: A column matches a row. 23 | 2) no_match: A column does not match any row. 24 | 3) ignore: A column that is neither 'match' nor no_match. 25 | 26 | The ignore case is regularly encountered in object detection: when an anchor has 27 | a relatively small overlap with a ground-truth box, one neither wants to 28 | consider this box a positive example (match) nor a negative example (no match). 29 | 30 | The Match class is used to store the match results and it provides simple apis 31 | to query the results. 32 | """ 33 | import torch 34 | 35 | 36 | @torch.jit.script 37 | class Match(object): 38 | """Class to store results from the matcher. 39 | 40 | This class is used to store the results from the matcher. It provides 41 | convenient methods to query the matching results. 42 | """ 43 | 44 | def __init__(self, match_results: torch.Tensor): 45 | """Constructs a Match object. 46 | 47 | Args: 48 | match_results: Integer tensor of shape [N] with (1) match_results[i]>=0, 49 | meaning that column i is matched with row match_results[i]. 50 | (2) match_results[i]=-1, meaning that column i is not matched. 51 | (3) match_results[i]=-2, meaning that column i is ignored. 
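        Illustrative example (hypothetical values): match_results = torch.tensor([3, -1, -2, 0])
        means column 0 matched row 3, column 1 is unmatched, column 2 is ignored and
        column 3 matched row 0.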
52 | 
53 |         Raises:
54 |             ValueError: if match_results does not have rank 1 or is not an int32/int64 tensor
55 |         """
56 |         if len(match_results.shape) != 1:
57 |             raise ValueError('match_results should have rank 1')
58 |         if match_results.dtype not in (torch.int32, torch.int64):
59 |             raise ValueError('match_results should be an int32 or int64 tensor')
60 |         self.match_results = match_results
61 | 
62 |     def matched_column_indices(self):
63 |         """Returns column indices that match to some row.
64 | 
65 |         The indices returned by this op are always sorted in increasing order.
66 | 
67 |         Returns:
68 |             column_indices: int32 tensor of shape [K] with column indices.
69 |         """
70 |         return torch.nonzero(self.match_results > -1).flatten().long()
71 | 
72 |     def matched_column_indicator(self):
73 |         """Returns a boolean column indicator where True means the column is matched.
74 | 
75 |         Returns:
76 |             column_indicator: boolean vector which is True for all matched column indices.
77 |         """
78 |         return self.match_results >= 0
79 | 
80 |     def num_matched_columns(self):
81 |         """Returns number (int32 scalar tensor) of matched columns."""
82 |         return self.matched_column_indices().numel()
83 | 
84 |     def unmatched_column_indices(self):
85 |         """Returns column indices that do not match any row.
86 | 
87 |         The indices returned by this op are always sorted in increasing order.
88 | 
89 |         Returns:
90 |             column_indices: int32 tensor of shape [K] with column indices.
91 |         """
92 |         return torch.nonzero(self.match_results == -1).flatten().long()
93 | 
94 |     def unmatched_column_indicator(self):
95 |         """Returns a boolean column indicator where True means the column is unmatched.
96 | 
97 |         Returns:
98 |             column_indicator: boolean vector which is True for all unmatched column indices.
99 |         """
100 |         return self.match_results == -1
101 | 
102 |     def num_unmatched_columns(self):
103 |         """Returns number (int32 scalar tensor) of unmatched columns."""
104 |         return self.unmatched_column_indices().numel()
105 | 
106 |     def ignored_column_indices(self):
107 |         """Returns column indices that are ignored (neither matched nor unmatched).
108 | 
109 |         The indices returned by this op are always sorted in increasing order.
110 | 
111 |         Returns:
112 |             column_indices: int32 tensor of shape [K] with column indices.
113 |         """
114 |         return torch.nonzero(self.ignored_column_indicator()).flatten().long()
115 | 
116 |     def ignored_column_indicator(self):
117 |         """Returns boolean column indicator where True means the column is ignored.
118 | 
119 |         Returns:
120 |             column_indicator: boolean vector which is True for all ignored column indices.
121 |         """
122 |         return self.match_results == -2
123 | 
124 |     def num_ignored_columns(self):
125 |         """Returns number (int32 scalar tensor) of ignored columns."""
126 |         return self.ignored_column_indices().numel()
127 | 
128 |     def unmatched_or_ignored_column_indices(self):
129 |         """Returns column indices that are unmatched or ignored.
130 | 
131 |         The indices returned by this op are always sorted in increasing order.
132 | 
133 |         Returns:
134 |             column_indices: int32 tensor of shape [K] with column indices.
135 |         """
136 |         return torch.nonzero(0 > self.match_results).flatten().long()
137 | 
138 |     def matched_row_indices(self):
139 |         """Returns row indices that match some column.
140 | 
141 |         The indices returned by this op are ordered so as to be in correspondence with the output of
142 |         matched_column_indices(). For example if self.matched_column_indices() is [0, 2],
143 |         and self.matched_row_indices() is [7, 3], then we know that column 0 was matched to row 7 and
144 |         column 2 was matched to row 3.
145 | 146 | Returns: 147 | row_indices: int32 tensor of shape [K] with row indices. 148 | """ 149 | return torch.gather(self.match_results, 0, self.matched_column_indices()).flatten().long() 150 | 151 | def gather_based_on_match(self, input_tensor, unmatched_value, ignored_value): 152 | """Gathers elements from `input_tensor` based on match results. 153 | 154 | For columns that are matched to a row, gathered_tensor[col] is set to input_tensor[match_results[col]]. 155 | For columns that are unmatched, gathered_tensor[col] is set to unmatched_value. Finally, for columns that 156 | are ignored gathered_tensor[col] is set to ignored_value. 157 | 158 | Note that the input_tensor.shape[1:] must match with unmatched_value.shape 159 | and ignored_value.shape 160 | 161 | Args: 162 | input_tensor: Tensor to gather values from. 163 | unmatched_value: Constant tensor or python scalar value for unmatched columns. 164 | ignored_value: Constant tensor or python scalar for ignored columns. 165 | 166 | Returns: 167 | gathered_tensor: A tensor containing values gathered from input_tensor. 168 | The shape of the gathered tensor is [match_results.shape[0]] + input_tensor.shape[1:]. 169 | """ 170 | if isinstance(ignored_value, torch.Tensor): 171 | input_tensor = torch.cat([ignored_value, unmatched_value, input_tensor], dim=0) 172 | else: 173 | # scalars 174 | input_tensor = torch.cat([ 175 | torch.tensor([ignored_value, unmatched_value], dtype=input_tensor.dtype, device=input_tensor.device), 176 | input_tensor], dim=0) 177 | gather_indices = torch.clamp(self.match_results + 2, min=0) 178 | gathered_tensor = torch.index_select(input_tensor, 0, gather_indices) 179 | return gathered_tensor 180 | -------------------------------------------------------------------------------- /effdet/object_detection/region_similarity_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Region Similarity Calculators for BoxLists. 16 | 17 | Region Similarity Calculators compare a pairwise measure of similarity 18 | between the boxes in two BoxLists. 19 | """ 20 | import torch 21 | from .box_list import BoxList 22 | 23 | 24 | def area(boxlist: BoxList): 25 | """Computes area of boxes. 26 | 27 | Args: 28 | boxlist: BoxList holding N boxes 29 | 30 | Returns: 31 | a tensor with shape [N] representing box areas. 32 | """ 33 | y_min, x_min, y_max, x_max = boxlist.boxes().chunk(4, dim=1) 34 | out = (y_max - y_min).squeeze(1) * (x_max - x_min).squeeze(1) 35 | return out 36 | 37 | 38 | def intersection(boxlist1: BoxList, boxlist2: BoxList): 39 | """Compute pairwise intersection areas between boxes. 
40 | 41 | Args: 42 | boxlist1: BoxList holding N boxes 43 | boxlist2: BoxList holding M boxes 44 | 45 | Returns: 46 | a tensor with shape [N, M] representing pairwise intersections 47 | """ 48 | y_min1, x_min1, y_max1, x_max1 = boxlist1.boxes().chunk(4, dim=1) 49 | y_min2, x_min2, y_max2, x_max2 = boxlist2.boxes().chunk(4, dim=1) 50 | all_pairs_min_ymax = torch.min(y_max1, y_max2.t()) 51 | all_pairs_max_ymin = torch.max(y_min1, y_min2.t()) 52 | intersect_heights = torch.clamp(all_pairs_min_ymax - all_pairs_max_ymin, min=0) 53 | all_pairs_min_xmax = torch.min(x_max1, x_max2.t()) 54 | all_pairs_max_xmin = torch.max(x_min1, x_min2.t()) 55 | intersect_widths = torch.clamp(all_pairs_min_xmax - all_pairs_max_xmin, min=0) 56 | return intersect_heights * intersect_widths 57 | 58 | 59 | def iou(boxlist1: BoxList, boxlist2: BoxList): 60 | """Computes pairwise intersection-over-union between box collections. 61 | 62 | Args: 63 | boxlist1: BoxList holding N boxes 64 | boxlist2: BoxList holding M boxes 65 | 66 | Returns: 67 | a tensor with shape [N, M] representing pairwise iou scores. 68 | """ 69 | intersections = intersection(boxlist1, boxlist2) 70 | areas1 = area(boxlist1) 71 | areas2 = area(boxlist2) 72 | unions = areas1.unsqueeze(1) + areas2.unsqueeze(0) - intersections 73 | return torch.where(intersections == 0.0, torch.zeros_like(intersections), intersections / unions) 74 | 75 | 76 | @torch.jit.script 77 | class IouSimilarity(object): 78 | """Class to compute similarity based on Intersection over Union (IOU) metric. 79 | 80 | This class computes pairwise similarity between two BoxLists based on IOU. 81 | """ 82 | def __init__(self): 83 | pass 84 | 85 | def compare(self, boxlist1: BoxList, boxlist2: BoxList): 86 | """Computes matrix of pairwise similarity between BoxLists. 87 | 88 | This op (to be overridden) computes a measure of pairwise similarity between 89 | the boxes in the given BoxLists. Higher values indicate more similarity. 90 | 91 | Note that this method simply measures similarity and does not explicitly 92 | perform a matching. 93 | 94 | Args: 95 | boxlist1: BoxList holding N boxes. 96 | boxlist2: BoxList holding M boxes. 97 | 98 | Returns: 99 | a (float32) tensor of shape [N, M] with pairwise similarity score. 100 | """ 101 | return iou(boxlist1, boxlist2) 102 | -------------------------------------------------------------------------------- /effdet/soft_nms.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Soft-NMS 2 | 3 | This code was adapted from a PR for detectron2 submitted by https://github.com/alekseynp 4 | https://github.com/facebookresearch/detectron2/pull/1183/files 5 | 6 | Detectron2 is licensed Apache 2.0, Copyright Facebook Inc. 7 | """ 8 | import torch 9 | from typing import List 10 | 11 | 12 | def pairwise_iou(boxes1, boxes2) -> torch.Tensor: 13 | """ 14 | Given two lists of boxes of size N and M, 15 | compute the IoU (intersection over union) 16 | between __all__ N x M pairs of boxes. 17 | The box order must be (xmin, ymin, xmax, ymax). 18 | Args: 19 | boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. 20 | Returns: 21 | Tensor: IoU, sized [N,M]. 
22 |     """
23 |     area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # [N,]
24 |     area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # [M,]
25 | 
26 |     width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max(
27 |         boxes1[:, None, :2], boxes2[:, :2]
28 |     )  # [N,M,2]
29 | 
30 |     width_height.clamp_(min=0)  # [N,M,2]
31 |     inter = width_height.prod(dim=2)  # [N,M]
32 | 
33 |     # handle empty boxes
34 |     iou = torch.where(
35 |         inter > 0,
36 |         inter / (area1[:, None] + area2 - inter),
37 |         torch.zeros(1, dtype=inter.dtype, device=inter.device),
38 |     )
39 |     return iou
40 | 
41 | 
42 | def soft_nms(
43 |         boxes,
44 |         scores,
45 |         method_gaussian: bool = True,
46 |         sigma: float = 0.5,
47 |         iou_threshold: float = .5,
48 |         score_threshold: float = 0.005
49 | ):
50 |     """
51 |     Soft non-max suppression algorithm.
52 | 
53 |     Implementation of [Soft-NMS -- Improving Object Detection With One Line of Code]
54 |     (https://arxiv.org/abs/1704.04503)
55 | 
56 |     Args:
57 |         boxes (Tensor[N, 4]):
58 |             boxes where NMS will be performed,
59 |             expected to be in (x1, y1, x2, y2) format
60 |             (axis-aligned; rotated boxes are not supported by this implementation)
61 |         scores (Tensor[N]):
62 |             scores for each one of the boxes
63 |         method_gaussian (bool): use gaussian method if True, otherwise linear
64 |         sigma (float):
65 |             parameter for Gaussian penalty function
66 |         iou_threshold (float):
67 |             iou threshold for applying linear decay. Nt from the paper
68 |             re-used as threshold for standard "hard" nms
69 |         score_threshold (float):
70 |             boxes with scores below this threshold are pruned at each iteration.
71 |             Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
72 | 
73 |     Returns:
74 |         tuple(Tensor, Tensor):
75 |             [0]: int64 tensor with the indices of the elements that have been kept
76 |                 by Soft NMS, sorted in decreasing order of scores
77 |             [1]: float tensor with the re-scored scores of the elements that were kept
78 |     """
79 |     device = boxes.device
80 |     boxes_remain = boxes.clone()
81 |     scores_remain = scores.clone()
82 |     num_elem = scores_remain.size()[0]
83 |     idxs = torch.arange(num_elem)
84 |     idxs_out = torch.zeros(num_elem, dtype=torch.int64, device=device)
85 |     scores_out = torch.zeros(num_elem, dtype=torch.float32, device=device)
86 |     count: int = 0
87 | 
88 |     while scores_remain.numel() > 0:
89 |         top_idx = torch.argmax(scores_remain)  # greedily select the highest-scoring remaining box
90 |         idxs_out[count] = idxs[top_idx]
91 |         scores_out[count] = scores_remain[top_idx]
92 |         count += 1
93 | 
94 |         top_box = boxes_remain[top_idx]
95 |         ious = pairwise_iou(top_box.unsqueeze(0), boxes_remain)[0]
96 | 
97 |         if method_gaussian:
98 |             decay = torch.exp(-torch.pow(ious, 2) / sigma)
99 |         else:
100 |             decay = torch.ones_like(ious)
101 |             decay_mask = ious > iou_threshold
102 |             decay[decay_mask] = 1 - ious[decay_mask]
103 | 
104 |         scores_remain *= decay  # decay overlapping boxes' scores instead of discarding them outright
105 |         keep = scores_remain > score_threshold
106 |         keep[top_idx] = torch.tensor(False, device=device)
107 | 
108 |         boxes_remain = boxes_remain[keep]
109 |         scores_remain = scores_remain[keep]
110 |         idxs = idxs[keep]
111 | 
112 |     return idxs_out[:count], scores_out[:count]
113 | 
114 | 
115 | def batched_soft_nms(
116 |         boxes, scores, idxs,
117 |         method_gaussian: bool = True,
118 |         sigma: float = 0.5,
119 |         iou_threshold: float = .5,
120 |         score_threshold: float = 0.001):
121 | 
122 |     """
123 |     Performs soft non-maximum suppression in a batched fashion.
124 | 
125 |     Each index value corresponds to a category, and NMS
126 |     will not be applied between elements of different categories.
127 | 
128 |     Args:
129 |         boxes (Tensor[N, 4]):
130 |             boxes where NMS will be performed. They
131 |             are expected to be in (x1, y1, x2, y2) format
132 |         scores (Tensor[N]):
133 |             scores for each one of the boxes
134 |         idxs (Tensor[N]):
135 |             indices of the categories for each one of the boxes.
136 |         method_gaussian (bool):
137 |             use the Gaussian penalty function if True, otherwise linear decay
138 |             (see the paper for details; for standard "hard" NMS use a conventional
139 |             NMS implementation instead)
140 |         sigma (float):
141 |             parameter for Gaussian penalty function
142 |         iou_threshold (float):
143 |             iou threshold for applying linear decay. Nt from the paper
144 |             re-used as threshold for standard "hard" nms
145 |         score_threshold (float):
146 |             boxes with scores below this threshold are pruned at each iteration.
147 |             Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
148 |     Returns:
149 |         tuple(Tensor, Tensor):
150 |             [0]: int64 tensor with the indices of the elements that have been kept
151 |                 by Soft NMS, sorted in decreasing order of scores
152 |             [1]: float tensor with the re-scored scores of the elements that were kept
153 |     """
154 |     if boxes.numel() == 0:
155 |         return (
156 |             torch.empty((0,), dtype=torch.int64, device=boxes.device),
157 |             torch.empty((0,), dtype=torch.float32, device=scores.device),
158 |         )
159 |     # strategy: in order to perform NMS independently per class,
160 |     # we add an offset to all the boxes. The offset is dependent
161 |     # only on the class idx, and is large enough so that boxes
162 |     # from different classes do not overlap
163 |     max_coordinate = boxes.max()
164 |     offsets = idxs.to(boxes) * (max_coordinate + 1)
165 |     boxes_for_nms = boxes + offsets[:, None]
166 |     return soft_nms(
167 |         boxes_for_nms, scores, method_gaussian=method_gaussian, sigma=sigma,
168 |         iou_threshold=iou_threshold, score_threshold=score_threshold
169 |     )
170 | 
171 | 
--------------------------------------------------------------------------------
/effdet/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.4.1'
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.12.1
2 | timm>=0.9.2
3 | torchvision
4 | pyyaml
5 | numpy
6 | 
7 | # an updated version that fixes some bugs, incl. issues with numpy >= 1.18
8 | pycocotools>=2.0.2
9 | 
10 | omegaconf>=2.0
11 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """ Setup
2 | """
3 | from setuptools import setup, find_packages
4 | from codecs import open
5 | from os import path
6 | 
7 | here = path.abspath(path.dirname(__file__))
8 | 
9 | # Get the long description from the README file
10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
11 |     long_description = f.read()
12 | 
13 | exec(open('effdet/version.py').read())
14 | setup(
15 |     name='effdet',
16 |     version=__version__,
17 |     description='EfficientDet for PyTorch',
18 |     long_description=long_description,
19 |     long_description_content_type='text/markdown',
20 |     url='https://github.com/rwightman/efficientdet-pytorch',
21 |     author='Ross Wightman',
22 |     author_email='hello@rwightman.com',
23 |     classifiers=[
24 |         # How mature is this project? Common values are
25 |         #   3 - Alpha
26 |         #   4 - Beta
27 |         #   5 - Production/Stable
28 |         'Development Status :: 3 - Alpha',
29 |         'Intended Audience :: Education',
30 |         'Intended Audience :: Science/Research',
31 |         'License :: OSI Approved :: Apache Software License',
32 |         'Programming Language :: Python :: 3.7',
33 |         'Programming Language :: Python :: 3.8',
34 |         'Programming Language :: Python :: 3.9',
35 |         'Programming Language :: Python :: 3.10',
36 |         'Topic :: Scientific/Engineering',
37 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
38 |         'Topic :: Software Development',
39 |         'Topic :: Software Development :: Libraries',
40 |         'Topic :: Software Development :: Libraries :: Python Modules',
41 |     ],
42 | 
43 |     # Note that this is a string of words separated by whitespace, not a list.
44 |     keywords='pytorch pretrained efficientdet efficientnet bifpn object detection',
45 |     packages=find_packages(exclude=['data']),
46 |     install_requires=['torch >= 1.12.1', 'torchvision', 'timm >= 0.9.2', 'pycocotools>=2.0.2', 'omegaconf>=2.0'],
47 |     python_requires='>=3.7',
48 | )
49 | 
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """ COCO validation script
3 | 
4 | Hacked together by Ross Wightman (https://github.com/rwightman)
5 | """
6 | import argparse
7 | import time
8 | import torch
9 | import torch.nn.parallel
10 | from contextlib import suppress
11 | 
12 | from effdet import create_model, create_evaluator, create_dataset, create_loader
13 | from effdet.data import resolve_input_config
14 | from timm.utils import AverageMeter, setup_default_logging
15 | try:
16 |     from timm.layers import set_layer_config
17 | except ImportError:
18 |     from timm.models.layers import set_layer_config
19 | 
20 | has_apex = False
21 | try:
22 |     from apex import amp
23 |     has_apex = True
24 | except ImportError:
25 |     pass
26 | 
27 | has_native_amp = False
28 | try:
29 |     if getattr(torch.cuda.amp, 'autocast') is not None:
30 |         has_native_amp = True
31 | except AttributeError:
32 |     pass
33 | 
34 | torch.backends.cudnn.benchmark = True
35 | 
36 | 
37 | def add_bool_arg(parser, name, default=False, help=''):  # FIXME move to utils
38 |     dest_name = name.replace('-', '_')
39 |     group = parser.add_mutually_exclusive_group(required=False)
40 |     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
41 |     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
42 |     parser.set_defaults(**{dest_name: default})
43 | 
44 | 
45 | parser = argparse.ArgumentParser(description='PyTorch COCO Validation')
46 | parser.add_argument('root', metavar='DIR',
47 |                     help='path to dataset root')
48 | parser.add_argument('--dataset', default='coco', type=str, metavar='DATASET',
49 |                     help='Name of dataset (default: "coco")')
50 | parser.add_argument('--split', default='val',
51 |                     help='validation split')
52 | parser.add_argument('--model', '-m', metavar='MODEL', default='tf_efficientdet_d1',
53 |                     help='model architecture (default: tf_efficientdet_d1)')
54 | add_bool_arg(parser, 'redundant-bias', default=None,
55 |              help='override model config for redundant bias layers')
56 | add_bool_arg(parser, 'soft-nms', default=None, help='override model config for soft-nms')
57 | parser.add_argument('--num-classes', type=int, default=None, metavar='N',
58 |                     help='Override num_classes in model config if set. For fine-tuning from pretrained.')
59 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
60 |                     help='number of data loading workers (default: 4)')
61 | parser.add_argument('-b', '--batch-size', default=128, type=int,
62 |                     metavar='N', help='mini-batch size (default: 128)')
63 | parser.add_argument('--img-size', default=None, type=int,
64 |                     metavar='N', help='Input image dimension, uses model default if empty')
65 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
66 |                     help='Override mean pixel value of dataset')
67 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
68 |                     help='Override std deviation of dataset')
69 | parser.add_argument('--interpolation', default='bilinear', type=str, metavar='NAME',
70 |                     help='Image resize interpolation type (overrides model)')
71 | parser.add_argument('--fill-color', default=None, type=str, metavar='NAME',
72 |                     help='Image augmentation fill (background) color ("mean" or int)')
73 | parser.add_argument('--log-freq', default=10, type=int,
74 |                     metavar='N', help='batch logging frequency (default: 10)')
75 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
76 |                     help='path to latest checkpoint (default: none)')
77 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
78 |                     help='use pre-trained model')
79 | parser.add_argument('--num-gpu', type=int, default=1,
80 |                     help='Number of GPUs to use')
81 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
82 |                     help='disable fast prefetcher')
83 | parser.add_argument('--pin-mem', action='store_true', default=False,
84 |                     help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
85 | parser.add_argument('--use-ema', dest='use_ema', action='store_true',
86 |                     help='use ema version of weights if present')
87 | parser.add_argument('--amp', action='store_true', default=False,
88 |                     help='Use AMP mixed precision. Prefers native Torch AMP, falls back to Apex.')
89 | parser.add_argument('--apex-amp', action='store_true', default=False,
90 |                     help='Use NVIDIA Apex AMP mixed precision')
91 | parser.add_argument('--native-amp', action='store_true', default=False,
92 |                     help='Use Native Torch AMP mixed precision')
93 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
94 |                     help='convert model to torchscript for inference')
95 | parser.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
96 |                     help="Enable compilation w/ specified backend (default: inductor).")
97 | parser.add_argument('--results', default='', type=str, metavar='FILENAME',
98 |                     help='JSON filename for evaluation results')
99 | 
100 | 
101 | def validate(args):
102 |     setup_default_logging()
103 | 
104 |     if args.amp:
105 |         if has_native_amp:
106 |             args.native_amp = True
107 |         elif has_apex:
108 |             args.apex_amp = True
109 |     assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
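    # at this point at most one AMP mode is enabled; with --amp, native torch.cuda.amp is preferred
    # and NVIDIA Apex is only used as a fallback (see the block above)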
110 | args.pretrained = args.pretrained or not args.checkpoint # might as well try to validate something 111 | args.prefetcher = not args.no_prefetcher 112 | 113 | # create model 114 | with set_layer_config(scriptable=args.torchscript): 115 | extra_args = {} 116 | if args.img_size is not None: 117 | extra_args = dict(image_size=(args.img_size, args.img_size)) 118 | bench = create_model( 119 | args.model, 120 | bench_task='predict', 121 | num_classes=args.num_classes, 122 | pretrained=args.pretrained, 123 | redundant_bias=args.redundant_bias, 124 | soft_nms=args.soft_nms, 125 | checkpoint_path=args.checkpoint, 126 | checkpoint_ema=args.use_ema, 127 | **extra_args, 128 | ) 129 | model_config = bench.config 130 | 131 | param_count = sum([m.numel() for m in bench.parameters()]) 132 | print('Model %s created, param count: %d' % (args.model, param_count)) 133 | 134 | bench = bench.cuda() 135 | 136 | if args.torchscript: 137 | assert not args.apex_amp, \ 138 | 'Cannot use APEX AMP with torchscripted model, force native amp with `--native-amp` flag' 139 | bench = torch.jit.script(bench) 140 | elif args.torchcompile: 141 | bench = torch.compile(bench, backend=args.torchcompile) 142 | 143 | amp_autocast = suppress 144 | if args.apex_amp: 145 | bench = amp.initialize(bench, opt_level='O1') 146 | print('Using NVIDIA APEX AMP. Validating in mixed precision.') 147 | elif args.native_amp: 148 | amp_autocast = torch.cuda.amp.autocast 149 | print('Using native Torch AMP. Validating in mixed precision.') 150 | else: 151 | print('AMP not enabled. Validating in float32.') 152 | 153 | if args.num_gpu > 1: 154 | bench = torch.nn.DataParallel(bench, device_ids=list(range(args.num_gpu))) 155 | 156 | dataset = create_dataset(args.dataset, args.root, args.split) 157 | input_config = resolve_input_config(args, model_config) 158 | loader = create_loader( 159 | dataset, 160 | input_size=input_config['input_size'], 161 | batch_size=args.batch_size, 162 | use_prefetcher=args.prefetcher, 163 | interpolation=input_config['interpolation'], 164 | fill_color=input_config['fill_color'], 165 | mean=input_config['mean'], 166 | std=input_config['std'], 167 | num_workers=args.workers, 168 | pin_mem=args.pin_mem, 169 | ) 170 | 171 | evaluator = create_evaluator(args.dataset, dataset, pred_yxyx=False) 172 | bench.eval() 173 | batch_time = AverageMeter() 174 | end = time.time() 175 | last_idx = len(loader) - 1 176 | with torch.no_grad(): 177 | for i, (input, target) in enumerate(loader): 178 | with amp_autocast(): 179 | output = bench(input, img_info=target) 180 | evaluator.add_predictions(output, target) 181 | 182 | # measure elapsed time 183 | batch_time.update(time.time() - end) 184 | end = time.time() 185 | if i % args.log_freq == 0 or i == last_idx: 186 | print( 187 | f'Test: [{i:>4d}/{len(loader)}] ' 188 | f'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {input.size(0) / batch_time.avg:>7.2f}/s) ' 189 | ) 190 | 191 | mean_ap = 0. 192 | if dataset.parser.has_labels: 193 | mean_ap = evaluator.evaluate(output_result_file=args.results) 194 | else: 195 | evaluator.save(args.results) 196 | 197 | return mean_ap 198 | 199 | 200 | def main(): 201 | args = parser.parse_args() 202 | validate(args) 203 | 204 | 205 | if __name__ == '__main__': 206 | main() 207 | 208 | --------------------------------------------------------------------------------
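Usage sketch (illustrative, not a file in the repository): the standalone Soft-NMS op defined in effdet/soft_nms.py can be exercised directly on plain tensors. The boxes, scores, and category indices below are made-up values, and the keyword arguments simply restate the defaults of `batched_soft_nms`.

import torch
from effdet.soft_nms import batched_soft_nms

# three boxes in (x1, y1, x2, y2) format; the first two overlap heavily, the third is a different category
boxes = torch.tensor([
    [0., 0., 10., 10.],
    [1., 1., 11., 11.],
    [50., 50., 60., 60.],
])
scores = torch.tensor([0.9, 0.8, 0.7])
idxs = torch.tensor([0, 0, 1])  # per-box category index; suppression is never applied across categories

keep, rescored = batched_soft_nms(
    boxes, scores, idxs,
    method_gaussian=True, sigma=0.5, iou_threshold=0.5, score_threshold=0.001,
)
# `keep` holds original box indices in decreasing order of (re-scored) confidence,
# `rescored` holds the Gaussian-decayed scores of those kept boxes.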