├── .gitattributes ├── .github └── workflows │ ├── ci.yml │ ├── publish_pypi.yml │ └── ruff.yaml ├── .gitignore ├── .markdownlint-cli2.yaml ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README.md ├── demo ├── demo_data │ ├── prediction_visual.png │ ├── small-vehicles1.jpeg │ ├── terrain2.png │ └── terrain2_coco.json ├── inference_for_detectron2.ipynb ├── inference_for_huggingface.ipynb ├── inference_for_mmdetection.ipynb ├── inference_for_rtdetr.ipynb ├── inference_for_sparse_yolov5.ipynb ├── inference_for_torchvision.ipynb ├── inference_for_ultralytics.ipynb ├── inference_for_yolov5.ipynb ├── inference_for_yolov8_onnx.ipynb └── slicing.ipynb ├── docs ├── README.md ├── cli.md ├── coco.md ├── fiftyone.md ├── predict.md └── slicing.md ├── pyproject.toml ├── resources ├── hf_spaces_badge.svg └── sliced_inference.gif ├── sahi ├── __init__.py ├── annotation.py ├── auto_model.py ├── cli.py ├── models │ ├── __init__.py │ ├── base.py │ ├── detectron2.py │ ├── huggingface.py │ ├── mmdet.py │ ├── rtdetr.py │ ├── torchvision.py │ ├── ultralytics.py │ ├── yolov5.py │ ├── yolov5sparse.py │ └── yolov8onnx.py ├── postprocess │ ├── __init__.py │ ├── combine.py │ ├── legacy │ │ ├── __init__.py │ │ └── combine.py │ └── utils.py ├── predict.py ├── prediction.py ├── scripts │ ├── __init__.py │ ├── coco2fiftyone.py │ ├── coco2yolo.py │ ├── coco_error_analysis.py │ ├── coco_evaluation.py │ ├── predict.py │ ├── predict_fiftyone.py │ └── slice_coco.py ├── slicing.py └── utils │ ├── __init__.py │ ├── coco.py │ ├── compatibility.py │ ├── cv.py │ ├── detectron2.py │ ├── fiftyone.py │ ├── file.py │ ├── huggingface.py │ ├── import_utils.py │ ├── mmdet.py │ ├── rtdetr.py │ ├── shapely.py │ ├── sparseyolov5.py │ ├── torch.py │ ├── torchvision.py │ ├── ultralytics.py │ ├── yolov5.py │ └── yolov8onnx.py ├── scripts ├── __init__.py ├── run_code_style.py └── utils.py ├── tests ├── __init__.py ├── check_commandline.sh ├── check_dependencies.sh ├── data │ ├── coco_evaluate │ │ ├── dataset.json │ │ └── result.json │ ├── coco_utils │ │ ├── coco_class_names.yaml │ │ ├── combined_coco.json │ │ ├── modified_terrain1_coco.json │ │ ├── modified_terrain2_coco.json │ │ ├── terrain1.jpg │ │ ├── terrain1.json │ │ ├── terrain1_coco.json │ │ ├── terrain2.json │ │ ├── terrain2.png │ │ ├── terrain2_coco.json │ │ ├── terrain2_gray.png │ │ ├── terrain3.json │ │ ├── terrain3.png │ │ ├── terrain3_coco.json │ │ ├── terrain4.png │ │ ├── terrain_all_coco.json │ │ └── visdrone2019-det-train-first50image.json │ ├── models │ │ ├── mmdet │ │ │ ├── _base_ │ │ │ │ ├── datasets │ │ │ │ │ ├── cityscapes_detection.py │ │ │ │ │ ├── cityscapes_instance.py │ │ │ │ │ ├── coco_detection.py │ │ │ │ │ ├── coco_instance.py │ │ │ │ │ ├── coco_instance_semantic.py │ │ │ │ │ ├── coco_panoptic.py │ │ │ │ │ ├── deepfashion.py │ │ │ │ │ ├── lvis_v0.5_instance.py │ │ │ │ │ ├── lvis_v1_instance.py │ │ │ │ │ ├── objects365v1_detection.py │ │ │ │ │ ├── objects365v2_detection.py │ │ │ │ │ ├── openimages_detection.py │ │ │ │ │ ├── semi_coco_detection.py │ │ │ │ │ ├── voc0712.py │ │ │ │ │ └── wider_face.py │ │ │ │ ├── default_runtime.py │ │ │ │ ├── models │ │ │ │ │ ├── cascade-mask-rcnn_r50_fpn.py │ │ │ │ │ ├── cascade-rcnn_r50_fpn.py │ │ │ │ │ ├── fast-rcnn_r50_fpn.py │ │ │ │ │ ├── faster-rcnn_r50-caffe-c4.py │ │ │ │ │ ├── faster-rcnn_r50-caffe-dc5.py │ │ │ │ │ ├── faster-rcnn_r50_fpn.py │ │ │ │ │ ├── mask-rcnn_r50-caffe-c4.py │ │ │ │ │ ├── mask-rcnn_r50_fpn.py │ │ │ │ │ ├── retinanet_r50_fpn.py │ │ │ │ │ ├── rpn_r50-caffe-c4.py │ │ │ │ │ ├── 
rpn_r50_fpn.py │ │ │ │ │ └── ssd300.py │ │ │ │ └── schedules │ │ │ │ │ ├── schedule_1x.py │ │ │ │ │ ├── schedule_20e.py │ │ │ │ │ └── schedule_2x.py │ │ │ ├── cascade_mask_rcnn │ │ │ │ └── cascade-mask-rcnn_r50_fpn_1x_coco.py │ │ │ ├── retinanet │ │ │ │ ├── retinanet_r50_fpn_1x_coco.py │ │ │ │ └── retinanet_tta.py │ │ │ └── yolox │ │ │ │ ├── yolox_s_8xb8-300e_coco.py │ │ │ │ ├── yolox_tiny_8xb8-300e_coco.py │ │ │ │ └── yolox_tta.py │ │ ├── mmdet_cascade_mask_rcnn │ │ │ ├── cascade-mask-rcnn_r50_fpn.py │ │ │ ├── cascade-mask-rcnn_r50_fpn_1x_coco.py │ │ │ ├── cascade_mask_rcnn_r50_fpn.py │ │ │ ├── cascade_mask_rcnn_r50_fpn_1x_coco.py │ │ │ ├── cascade_mask_rcnn_r50_fpn_1x_coco_v280.py │ │ │ ├── cascade_mask_rcnn_r50_fpn_v280.py │ │ │ ├── coco_instance.py │ │ │ ├── default_runtime.py │ │ │ └── schedule_1x.py │ │ ├── mmdet_retinanet │ │ │ ├── coco_detection.py │ │ │ ├── default_runtime.py │ │ │ ├── retinanet_r50_fpn.py │ │ │ ├── retinanet_r50_fpn_1x_coco.py │ │ │ ├── retinanet_r50_fpn_1x_coco_v280.py │ │ │ ├── retinanet_r50_fpn_v280.py │ │ │ └── schedule_1x.py │ │ ├── mmdet_yolox │ │ │ └── yolox_tiny_8x8_300e_coco.py │ │ └── torchvision │ │ │ ├── fasterrcnn_resnet50_fpn.yaml │ │ │ └── ssd300_vgg16.yaml │ └── small-vehicles1.jpeg ├── test_annotation.py ├── test_autoslice.py ├── test_cocoutils.py ├── test_cvutils.py ├── test_detectron2.py ├── test_fileutils.py ├── test_highlevelapi.py ├── test_huggingfacemodel.py ├── test_mmdetectionmodel.py ├── test_postprocessutils.py ├── test_predict.py ├── test_prediction.py ├── test_rtdetr.py ├── test_shapelyutils.py ├── test_slicing.py ├── test_sparseyolov5model.py ├── test_torchutils.py ├── test_torchvision.py ├── test_ultralyticsmodel.py └── test_yolov8onnx.py └── uv.lock /.gitattributes: -------------------------------------------------------------------------------- 1 | # this drop notebooks from GitHub language stats 2 | *.ipynb linguist-vendored 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI with uv 2 | on: 3 | push: 4 | branches: [main] 5 | paths-ignore: 6 | - "**.md" 7 | - "**.ipynb" 8 | - "**.cff" 9 | 10 | pull_request: 11 | branches: [main] 12 | paths-ignore: 13 | - "**.md" 14 | - "**.ipynb" 15 | - "**.cff" 16 | 17 | schedule: 18 | - cron: "0 0 * * *" # Runs at 00:00 UTC every day 19 | 20 | workflow_dispatch: # allow running sync via github ui button 21 | 22 | jobs: 23 | ci: 24 | runs-on: ubuntu-latest 25 | strategy: 26 | matrix: 27 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 28 | steps: 29 | - uses: actions/checkout@v4 30 | - name: Setup uv python package manager 31 | uses: astral-sh/setup-uv@v5 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | enable-cache: true 35 | prune-cache: false 36 | # For python 3.8 and 3.9, it does no suffice to install opencv-python-headless 37 | # https://itsmycode.com/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directory/ 38 | - name: Cache apt packages 39 | uses: actions/cache@v4 40 | if: ${{ env.ACT }} 41 | with: 42 | path: /var/cache/apt 43 | key: apt-cache 44 | - name: Install libgl 45 | run: sudo apt-get update && sudo apt-get install -y libgl1 46 | - name: Debug uv sync 47 | run: | 48 | set -e 49 | uv sync -v 50 | - name: Verify .venv Activation 51 | run: | 52 | if [ ! 
-d ".venv" ]; then 53 | echo ".venv directory not found" 54 | exit 1 55 | fi 56 | source .venv/bin/activate 57 | - name: Install sahi from PyPI 58 | if: github.event_name == 'schedule' 59 | run: | 60 | rm -fr sahi 61 | uv pip install --force-reinstall sahi 62 | uv pip install "numpy<2" 63 | uv pip show sahi 64 | source .venv/bin/activate 65 | python -c "import sahi; print(sahi.__version__)" 66 | - name: Test with python ${{ matrix.python-version }} 67 | run: | 68 | source .venv/bin/activate 69 | pytest --capture=no 70 | - name: Test SAHI CLI 71 | run: | 72 | source .venv/bin/activate 73 | set -e 74 | sahi predict --no_sliced_prediction --model_type ultralytics --source tests/data/coco_utils/terrain1.jpg --novisual --model_path tests/data/models/ultralytics/yolo11n.pt --image_size 320 75 | sahi predict --model_type ultralytics --source tests/data/ --novisual --model_path tests/data/models/ultralytics/yolo11n.pt --image_size 320 76 | sahi predict --model_type ultralytics --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/ultralytics/yolo11n.pt --image_size 320 77 | sahi predict --model_type ultralytics --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/ultralytics/yolo11n.pt --image_size 320 78 | # coco yolov5 79 | sahi coco yolov5 --image_dir tests/data/coco_utils/ --dataset_json_path tests/data/coco_utils/combined_coco.json --train_split 0.9 80 | # coco evaluate 81 | sahi coco evaluate --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json 82 | # coco analyse 83 | sahi coco analyse --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json --out_dir tests/data/coco_evaluate/ 84 | - name: Test SAHI CLI with MMCV 85 | if: matrix.python-version != '3.12' 86 | run: | 87 | source .venv/bin/activate 88 | set -e 89 | sahi predict --model_type mmdet --source tests/data/ --novisual --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320 90 | sahi predict --model_type mmdet --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320 91 | sahi predict --model_type mmdet --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320 92 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [published, edited] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Setup uv python package manager 13 | uses: astral-sh/setup-uv@v5 14 | with: 15 | enable-cache: true 16 | prune-cache: false 17 | - name: Build and publish 18 | env: 19 | UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }} 20 | run: | 21 | uv build 22 | uv publish 23 | 
-------------------------------------------------------------------------------- /.github/workflows/ruff.yaml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [push] 3 | jobs: 4 | ruff-format: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - uses: astral-sh/ruff-action@v3 9 | with: 10 | args: "format --check" 11 | version: "latest" 12 | ruff: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: astral-sh/ruff-action@v3 17 | with: 18 | version: "latest" 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.pkl 4 | *.pth 5 | *.pt 6 | *.onnx 7 | weights* 8 | .vscode 9 | .idea 10 | runs 11 | 12 | # outputs 13 | outputs 14 | sliced_prediction_data 15 | 16 | # mmdetection 17 | mmdetection/build 18 | mmdetection/demo 19 | mmdetection/experiments 20 | 21 | # mac 22 | .DS_Store 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | pip-wheel-metadata/ 47 | share/python-wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | *.py,cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | cover/ 77 | 78 | # Translations 79 | *.mo 80 | *.pot 81 | 82 | # Django stuff: 83 | *.log 84 | local_settings.py 85 | db.sqlite3 86 | db.sqlite3-journal 87 | 88 | # Flask stuff: 89 | instance/ 90 | .webassets-cache 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | 98 | # PyBuilder 99 | .pybuilder/ 100 | target/ 101 | 102 | # Jupyter Notebook 103 | .ipynb_checkpoints 104 | 105 | # IPython 106 | profile_default/ 107 | ipython_config.py 108 | 109 | # pyenv 110 | # For a library or package, you might want to ignore these files since the code is 111 | # intended to run in multiple environments; otherwise, check them in: 112 | # .python-version 113 | 114 | # pipenv 115 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 116 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 117 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 118 | # install all needed dependencies. 119 | #Pipfile.lock 120 | 121 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # Elastic Beanstalk Files 165 | .elasticbeanstalk/* 166 | !.elasticbeanstalk/*.cfg.yml 167 | !.elasticbeanstalk/*.global.yml 168 | tests/data 169 | 170 | .archive 171 | .python-version 172 | -------------------------------------------------------------------------------- /.markdownlint-cli2.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | MD013: false 3 | MD033: false 4 | MD041: false 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.9.5 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi ] 9 | args: [--fix] 10 | # Run import organizer 11 | - id: ruff 12 | types_or: [ python, pyi ] 13 | args: [--select, I, --fix] 14 | # Run the formatter. 15 | - id: ruff-format 16 | types_or: [ python, pyi ] 17 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this package, please consider citing it." 
3 | authors: 4 | - family-names: "Akyon" 5 | given-names: "Fatih Cagatay" 6 | - family-names: "Cengiz" 7 | given-names: "Cemil" 8 | - family-names: "Altinuc" 9 | given-names: "Sinan Onur" 10 | - family-names: "Cavusoglu" 11 | given-names: "Devrim" 12 | - family-names: "Sahin" 13 | given-names: "Kadir" 14 | - family-names: "Eryuksel" 15 | given-names: "Ogulcan" 16 | title: "SAHI: A lightweight vision library for performing large scale object detection and instance segmentation" 17 | preferred-citation: 18 | type: article 19 | title: "Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection" 20 | doi: 10.1109/ICIP46576.2022.9897990 21 | url: https://ieeexplore.ieee.org/document/9897990 22 | journal: 2022 IEEE International Conference on Image Processing (ICIP) 23 | authors: 24 | - family-names: "Akyon" 25 | given-names: "Fatih Cagatay" 26 | - family-names: "Altinuc" 27 | given-names: "Sinan Onur" 28 | - family-names: "Temizel" 29 | given-names: "Alptekin" 30 | year: 2022 31 | start: 966 32 | end: 970 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 obss 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /demo/demo_data/prediction_visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/demo/demo_data/prediction_visual.png -------------------------------------------------------------------------------- /demo/demo_data/small-vehicles1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/demo/demo_data/small-vehicles1.jpeg -------------------------------------------------------------------------------- /demo/demo_data/terrain2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/demo/demo_data/terrain2.png -------------------------------------------------------------------------------- /demo/demo_data/terrain2_coco.json: -------------------------------------------------------------------------------- 1 | { 2 | "images": [ 3 | { 4 | "height": 682, 5 | "width": 1024, 6 | "id": 1, 7 | "file_name": "terrain2.png" 8 | } 9 | ], 10 | "categories": [ 11 | { 12 | "supercategory": "car", 13 | "id": 1, 14 | "name": "car" 15 | } 16 | ], 17 | "annotations": [ 18 | { 19 | "iscrowd": 0, 20 | "image_id": 1, 21 | "bbox": [ 22 | 218.0, 23 | 448.0, 24 | 222.0, 25 | 161.0 26 | ], 27 | "segmentation": [ 28 | [ 29 | 218.0, 30 | 448.0, 31 | 440.0, 32 | 448.0, 33 | 440.0, 34 | 609.0, 35 | 218.0, 36 | 609.0 37 | ] 38 | ], 39 | "category_id": 1, 40 | "id": 1, 41 | "area": 698368 42 | }, 43 | { 44 | "iscrowd": 0, 45 | "image_id": 1, 46 | "bbox": [ 47 | 501.0, 48 | 451.0, 49 | 121.0, 50 | 92.0 51 | ], 52 | "segmentation": [ 53 | [ 54 | 501.0, 55 | 451.0, 56 | 622.0, 57 | 451.0, 58 | 622.0, 59 | 543.0, 60 | 501.0, 61 | 543.0 62 | ] 63 | ], 64 | "category_id": 1, 65 | "id": 2, 66 | "area": 698368 67 | }, 68 | { 69 | "iscrowd": 0, 70 | "image_id": 1, 71 | "bbox": [ 72 | 634.0, 73 | 437.0, 74 | 81.0, 75 | 56.0 76 | ], 77 | "segmentation": [ 78 | [ 79 | 634.0, 80 | 437.0, 81 | 715.0, 82 | 437.0, 83 | 715.0, 84 | 493.0, 85 | 634.0, 86 | 493.0 87 | ] 88 | ], 89 | "category_id": 1, 90 | "id": 3, 91 | "area": 698368 92 | }, 93 | { 94 | "iscrowd": 0, 95 | "image_id": 1, 96 | "bbox": [ 97 | 725.0, 98 | 423.0, 99 | 70.0, 100 | 51.0 101 | ], 102 | "segmentation": [ 103 | [ 104 | 725.0, 105 | 423.0, 106 | 795.0, 107 | 423.0, 108 | 795.0, 109 | 474.0, 110 | 725.0, 111 | 474.0 112 | ] 113 | ], 114 | "category_id": 1, 115 | "id": 4, 116 | "area": 698368 117 | }, 118 | { 119 | "iscrowd": 0, 120 | "image_id": 1, 121 | "bbox": [ 122 | 791.0, 123 | 404.0, 124 | 40.0, 125 | 47.0 126 | ], 127 | "segmentation": [ 128 | [ 129 | 791.0, 130 | 404.0, 131 | 831.0, 132 | 404.0, 133 | 831.0, 134 | 451.0, 135 | 791.0, 136 | 451.0 137 | ] 138 | ], 139 | "category_id": 1, 140 | "id": 5, 141 | "area": 698368 142 | } 143 | ] 144 | } 145 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # SAHI Documentation 2 | 3 | Welcome to the SAHI documentation! 
This directory contains detailed guides and tutorials for using SAHI's various features. Below is an overview of each documentation file and what you'll find in it. 4 | 5 | ## Core Documentation Files 6 | 7 | ### [Prediction Utilities](predict.md) 8 | - Detailed guide for performing object detection inference 9 | - Standard and sliced inference examples 10 | - Batch prediction usage 11 | - Class exclusion during inference 12 | - Visualization parameters and export formats 13 | - Interactive examples with various model integrations (YOLOv8, MMDetection, etc.) 14 | 15 | ### [Slicing Utilities](slicing.md) 16 | - Guide for slicing large images and datasets 17 | - Image slicing examples 18 | - COCO dataset slicing examples 19 | - Interactive demo notebook reference 20 | 21 | ### [COCO Utilities](coco.md) 22 | - Comprehensive guide for working with COCO format datasets 23 | - Dataset creation and manipulation 24 | - Slicing COCO datasets 25 | - Dataset splitting (train/val) 26 | - Category filtering and updates 27 | - Area-based filtering 28 | - Dataset merging 29 | - Format conversion (COCO ↔ YOLO) 30 | - Dataset sampling utilities 31 | - Statistics calculation 32 | - Result validation 33 | 34 | ### [CLI Commands](cli.md) 35 | - Complete reference for SAHI command-line interface 36 | - Prediction commands 37 | - FiftyOne integration 38 | - COCO dataset operations 39 | - Environment information 40 | - Version checking 41 | - Custom script usage 42 | 43 | ### [FiftyOne Integration](fiftyone.md) 44 | - Guide for visualizing and analyzing predictions with FiftyOne 45 | - Dataset visualization 46 | - Result exploration 47 | - Interactive analysis 48 | 49 | ## Interactive Examples 50 | 51 | All documentation files are complemented by interactive Jupyter notebooks in the [demo directory](../demo/): 52 | - `slicing.ipynb` - Slicing operations demonstration 53 | - `inference_for_ultralytics.ipynb` - YOLOv8/YOLO11/YOLO12 integration 54 | - `inference_for_yolov5.ipynb` - YOLOv5 integration 55 | - `inference_for_mmdetection.ipynb` - MMDetection integration 56 | - `inference_for_huggingface.ipynb` - HuggingFace models integration 57 | - `inference_for_torchvision.ipynb` - TorchVision models integration 58 | - `inference_for_rtdetr.ipynb` - RT-DETR integration 59 | - `inference_for_sparse_yolov5.ipynb` - DeepSparse optimized inference 60 | 61 | ## Getting Started 62 | 63 | If you're new to SAHI: 64 | 65 | 1. Start with the [prediction utilities](predict.md) to understand basic inference 66 | 2. Explore the [slicing utilities](slicing.md) to learn about processing large images 67 | 3. Check out the [CLI commands](cli.md) for command-line usage 68 | 4. Dive into [COCO utilities](coco.md) for dataset operations 69 | 5. 
Try the interactive notebooks in the [demo directory](../demo/) for hands-on experience 70 | -------------------------------------------------------------------------------- /docs/fiftyone.md: -------------------------------------------------------------------------------- 1 | # Fiftyone Utilities 2 | 3 | - Explore COCO dataset via FiftyOne app: 4 | 5 | Supported version: `pip install fiftyone>=0.14.2<0.15.0` 6 | 7 | ```python 8 | from sahi.utils.fiftyone import launch_fiftyone_app 9 | 10 | # launch fiftyone app: 11 | session = launch_fiftyone_app(coco_image_dir, coco_json_path) 12 | 13 | # close fiftyone app: 14 | session.close() 15 | ``` 16 | 17 | - Convert predictions to FiftyOne detection: 18 | 19 | ```python 20 | from sahi import get_sliced_prediction 21 | 22 | # perform sliced prediction 23 | result = get_sliced_prediction( 24 | image, 25 | detection_model, 26 | slice_height = 256, 27 | slice_width = 256, 28 | overlap_height_ratio = 0.2, 29 | overlap_width_ratio = 0.2 30 | ) 31 | 32 | # convert detections into fiftyone detection format 33 | fiftyone_detections = result.to_fiftyone_detections() 34 | ``` 35 | 36 | - Explore detection results in Fiftyone UI: 37 | 38 | ```bash 39 | sahi coco fiftyone --image_dir dir/to/images --dataset_json_path dataset.json cocoresult1.json cocoresult2.json 40 | ``` 41 | 42 | will open a FiftyOne app that visualizes the given dataset and 2 detection results. 43 | 44 | Specify IOU threshold for FP/TP by `--iou_threshold 0.5` argument 45 | -------------------------------------------------------------------------------- /docs/slicing.md: -------------------------------------------------------------------------------- 1 | # Slicing Utilities 2 | 3 | - Slice an image: 4 | 5 | ```python 6 | from sahi.slicing import slice_image 7 | 8 | slice_image_result = slice_image( 9 | image=image_path, 10 | output_file_name=output_file_name, 11 | output_dir=output_dir, 12 | slice_height=256, 13 | slice_width=256, 14 | overlap_height_ratio=0.2, 15 | overlap_width_ratio=0.2, 16 | ) 17 | ``` 18 | 19 | - Slice a COCO formatted dataset: 20 | 21 | ```python 22 | from sahi.slicing import slice_coco 23 | 24 | coco_dict, coco_path = slice_coco( 25 | coco_annotation_file_path=coco_annotation_file_path, 26 | image_dir=image_dir, 27 | slice_height=256, 28 | slice_width=256, 29 | overlap_height_ratio=0.2, 30 | overlap_width_ratio=0.2, 31 | ) 32 | ``` 33 | 34 | # Interactive Demo 35 | 36 | Want to experiment with different slicing parameters and see their effects? Check out our [interactive Jupyter notebook](../demo/slicing.ipynb) that demonstrates these slicing operations in action. 
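To make the overlap parameters above concrete, the arithmetic below estimates how many 256x256 slices cover the 1024x682 `terrain2.png` demo image. This is only a sketch of the geometry; the exact rounding inside `sahi.slicing` may differ by a pixel.

```python
import math

# parameters used in the examples above
slice_height = slice_width = 256
overlap_height_ratio = overlap_width_ratio = 0.2
image_width, image_height = 1024, 682  # terrain2.png from demo_data

# neighbouring slices overlap by overlap_ratio * slice_size pixels,
# so consecutive slices start roughly this many pixels apart
stride_x = int(slice_width * (1 - overlap_width_ratio))    # 204
stride_y = int(slice_height * (1 - overlap_height_ratio))  # 204

n_cols = math.ceil((image_width - slice_width) / stride_x) + 1    # 5
n_rows = math.ceil((image_height - slice_height) / stride_y) + 1  # 4
print(f"~{n_rows * n_cols} slices")  # ~20 slices for this image
```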
37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sahi" 3 | version = "0.11.23" 4 | readme = "README.md" 5 | description = "A vision library for performing sliced inference on large images/small objects" 6 | requires-python = ">=3.8" 7 | license = "MIT" 8 | license-files = ["./LICENSE"] 9 | dependencies = [ 10 | "opencv-python<=4.10.0.84", 11 | "shapely>=2.0.0", 12 | "tqdm>=4.48.2", 13 | "pillow>=8.2.0", 14 | "pybboxes==0.1.6", 15 | "pyyaml", 16 | "fire", 17 | "terminaltables", 18 | "requests", 19 | "click", 20 | ] 21 | classifiers = [ 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | "Intended Audience :: Developers", 25 | "Intended Audience :: Science/Research", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3.11", 31 | "Programming Language :: Python :: 3.12", 32 | "Topic :: Software Development :: Libraries", 33 | "Topic :: Software Development :: Libraries :: Python Modules", 34 | ] 35 | 36 | [project.urls] 37 | homepage = "https://github.com/obss/sahi" 38 | 39 | [project.scripts] 40 | sahi = "sahi.cli:app" 41 | 42 | [tool.uv] 43 | find-links = [ 44 | "https://download.openmmlab.com/mmcv/dist/cpu/torch2.1.0/index.html", 45 | ] 46 | default-groups = ["dev", "ci"] 47 | 48 | [dependency-groups] 49 | dev = [ 50 | "pytest", 51 | "ruff", 52 | "pre-commit>=2.0", 53 | "jupyterlab>=3.0.14", 54 | "matplotlib-stubs>=0.2.0", 55 | ] 56 | ci = [ 57 | # pytorch should be present for all python versions 58 | "torch==2.6.0+cpu;python_version>='3.12'", 59 | "torchvision==0.21.0+cpu;python_version>='3.12'", 60 | "torch==2.1.2+cpu;python_version<'3.12'", 61 | "torchvision==0.16.2+cpu;python_version<'3.12'", 62 | # mmdet is supported for python<3.12 63 | "mmengine;python_version<'3.12'", 64 | "mmcv==2.1.0;python_version<'3.12'", 65 | "mmdet==3.3.0;python_version<'3.12'", 66 | # deepsparse is only available for python<3.12 67 | "deepsparse;python_version<'3.12'", 68 | "onnxruntime<1.20;python_version<'3.12'", 69 | "onnx>1.16;python_version>='3.12'", 70 | "onnxruntime;python_version>='3.12'", 71 | # transformers is supported for python>=3.9 72 | "transformers>=4.49.0;python_version>='3.9'", 73 | # These are available for all python versions 74 | "pycocotools>=2.0.7", 75 | "ultralytics>=8.3.86", 76 | "scikit-image", 77 | "fiftyone", 78 | ] 79 | 80 | [[tool.uv.index]] 81 | name = "pytorch-cpu" 82 | url = "https://download.pytorch.org/whl/cpu" 83 | explicit = true 84 | 85 | [tool.uv.sources] 86 | torch = [{ index = "pytorch-cpu" }] 87 | torchvision = [{ index = "pytorch-cpu" }] 88 | 89 | [build-system] 90 | requires = ["hatchling"] 91 | build-backend = "hatchling.build" 92 | 93 | [tool.ruff] 94 | line-length = 120 95 | exclude = ["**/__init__.py", ".git", "__pycache__", "*.ipynb"] 96 | 97 | [tool.pytest.ini_options] 98 | minversion = "6.0" 99 | addopts = ["--import-mode=importlib", "--no-header"] 100 | pythonpath = ["."] 101 | 102 | [tool.typos.default] 103 | extend-ignore-identifiers-re = ["fo"] 104 | -------------------------------------------------------------------------------- /resources/hf_spaces_badge.svg: -------------------------------------------------------------------------------- 1 | HF SpacesHF Spaces 
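As an aside on the `ci` dependency group in `pyproject.toml` above: the environment markers decide which pins apply for a given interpreter. A small sketch with `packaging` (already used elsewhere in this repo, e.g. `sahi/utils/torchvision.py`) shows how such a marker evaluates; the marker string is copied from the torch pins above.

```python
# Sketch: PEP 508 marker evaluation, as used by the `ci` group pins above.
from packaging.markers import Marker

marker = Marker("python_version < '3.12'")
print(marker.evaluate({"python_version": "3.11"}))  # True  -> torch==2.1.2+cpu applies
print(marker.evaluate({"python_version": "3.12"}))  # False -> torch==2.6.0+cpu applies
```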
-------------------------------------------------------------------------------- /resources/sliced_inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/resources/sliced_inference.gif -------------------------------------------------------------------------------- /sahi/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.11.23" 2 | 3 | from sahi.annotation import BoundingBox, Category, Mask 4 | from sahi.auto_model import AutoDetectionModel 5 | from sahi.models.base import DetectionModel 6 | from sahi.prediction import ObjectPrediction 7 | -------------------------------------------------------------------------------- /sahi/auto_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from sahi.models.base import DetectionModel 4 | from sahi.utils.file import import_model_class 5 | 6 | MODEL_TYPE_TO_MODEL_CLASS_NAME = { 7 | "ultralytics": "UltralyticsDetectionModel", 8 | "rtdetr": "RTDetrDetectionModel", 9 | "mmdet": "MmdetDetectionModel", 10 | "yolov5": "Yolov5DetectionModel", 11 | "detectron2": "Detectron2DetectionModel", 12 | "huggingface": "HuggingfaceDetectionModel", 13 | "torchvision": "TorchVisionDetectionModel", 14 | "yolov5sparse": "Yolov5SparseDetectionModel", 15 | "yolov8onnx": "Yolov8OnnxDetectionModel", 16 | } 17 | 18 | ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"] 19 | 20 | 21 | class AutoDetectionModel: 22 | @staticmethod 23 | def from_pretrained( 24 | model_type: str, 25 | model_path: Optional[str] = None, 26 | model: Optional[Any] = None, 27 | config_path: Optional[str] = None, 28 | device: Optional[str] = None, 29 | mask_threshold: float = 0.5, 30 | confidence_threshold: float = 0.3, 31 | category_mapping: Optional[Dict] = None, 32 | category_remapping: Optional[Dict] = None, 33 | load_at_init: bool = True, 34 | image_size: Optional[int] = None, 35 | **kwargs, 36 | ) -> DetectionModel: 37 | """ 38 | Loads a DetectionModel from given path. 39 | 40 | Args: 41 | model_type: str 42 | Name of the detection framework (example: "ultralytics", "huggingface", "torchvision") 43 | model_path: str 44 | Path of the detection model (ex. 'model.pt') 45 | config_path: str 46 | Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py') 47 | device: str 48 | Device, "cpu" or "cuda:0" 49 | mask_threshold: float 50 | Value to threshold mask pixels, should be between 0 and 1 51 | confidence_threshold: float 52 | All predictions with score < confidence_threshold will be discarded 53 | category_mapping: dict: str to str 54 | Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} 55 | category_remapping: dict: str to int 56 | Remap category ids based on category names, after performing inference e.g. {"car": 3} 57 | load_at_init: bool 58 | If True, automatically loads the model at initialization 59 | image_size: int 60 | Inference input size. 
61 | 62 | Returns: 63 | Returns an instance of a DetectionModel 64 | 65 | Raises: 66 | ImportError: If given {model_type} framework is not installed 67 | """ 68 | if model_type in ULTRALYTICS_MODEL_NAMES: 69 | model_type = "ultralytics" 70 | model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type] 71 | DetectionModel = import_model_class(model_type, model_class_name) 72 | 73 | return DetectionModel( 74 | model_path=model_path, 75 | model=model, 76 | config_path=config_path, 77 | device=device, 78 | mask_threshold=mask_threshold, 79 | confidence_threshold=confidence_threshold, 80 | category_mapping=category_mapping, 81 | category_remapping=category_remapping, 82 | load_at_init=load_at_init, 83 | image_size=image_size, 84 | **kwargs, 85 | ) 86 | -------------------------------------------------------------------------------- /sahi/cli.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | from sahi import __version__ as sahi_version 4 | from sahi.predict import predict, predict_fiftyone 5 | from sahi.scripts.coco2fiftyone import main as coco2fiftyone 6 | from sahi.scripts.coco2yolo import main as coco2yolo 7 | from sahi.scripts.coco_error_analysis import analyse 8 | from sahi.scripts.coco_evaluation import evaluate 9 | from sahi.scripts.slice_coco import slice 10 | from sahi.utils.import_utils import print_environment_info 11 | 12 | coco_app = { 13 | "evaluate": evaluate, 14 | "analyse": analyse, 15 | "fiftyone": coco2fiftyone, 16 | "slice": slice, 17 | "yolo": coco2yolo, 18 | "yolov5": coco2yolo, 19 | } 20 | 21 | sahi_app = { 22 | "predict": predict, 23 | "predict-fiftyone": predict_fiftyone, 24 | "coco": coco_app, 25 | "version": sahi_version, 26 | "env": print_environment_info, 27 | } 28 | 29 | 30 | def app() -> None: 31 | """Cli app.""" 32 | fire.Fire(sahi_app) 33 | 34 | 35 | if __name__ == "__main__": 36 | app() 37 | -------------------------------------------------------------------------------- /sahi/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base, detectron2, huggingface, mmdet, torchvision, ultralytics, yolov5, yolov8onnx 2 | -------------------------------------------------------------------------------- /sahi/models/rtdetr.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by AnNT, 2023. 3 | 4 | import logging 5 | 6 | from sahi.models.ultralytics import UltralyticsDetectionModel 7 | from sahi.utils.import_utils import check_requirements 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class RTDetrDetectionModel(UltralyticsDetectionModel): 13 | def check_dependencies(self) -> None: 14 | check_requirements(["ultralytics"]) 15 | 16 | def load_model(self): 17 | """ 18 | Detection model is initialized and set to self.model. 
19 | """ 20 | 21 | from ultralytics import RTDETR 22 | 23 | try: 24 | model = RTDETR(self.model_path) 25 | model.to(self.device) 26 | 27 | self.set_model(model) 28 | except Exception as e: 29 | raise TypeError("model_path is not a valid rtdet model path: ", e) 30 | -------------------------------------------------------------------------------- /sahi/postprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/sahi/postprocess/__init__.py -------------------------------------------------------------------------------- /sahi/postprocess/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /sahi/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/sahi/scripts/__init__.py -------------------------------------------------------------------------------- /sahi/scripts/coco2fiftyone.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import fire 5 | 6 | from sahi.utils.file import load_json 7 | 8 | 9 | def main( 10 | image_dir: str, 11 | dataset_json_path: str, 12 | *result_json_paths, 13 | iou_thresh: float = 0.5, 14 | ): 15 | """ 16 | Args: 17 | image_dir (str): directory for coco images 18 | dataset_json_path (str): file path for the coco dataset json file 19 | result_json_paths (str): one or more paths for the coco result json file 20 | iou_thresh (float): iou threshold for coco evaluation 21 | """ 22 | 23 | from fiftyone.utils.coco import add_coco_labels 24 | 25 | from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file, fo 26 | 27 | coco_result_list = [] 28 | result_name_list = [] 29 | if result_json_paths: 30 | for result_json_path in result_json_paths: 31 | coco_result = load_json(result_json_path) 32 | coco_result_list.append(coco_result) 33 | 34 | # use file names as fiftyone name, create unique names if duplicate 35 | result_name_temp = Path(result_json_path).stem 36 | result_name = result_name_temp 37 | name_increment = 2 38 | while result_name in result_name_list: 39 | result_name = result_name_temp + "_" + str(name_increment) 40 | name_increment += 1 41 | result_name_list.append(result_name) 42 | 43 | dataset = create_fiftyone_dataset_from_coco_file(image_dir, dataset_json_path) 44 | 45 | # submit detections if coco result is given 46 | if result_json_paths: 47 | for result_name, coco_result in zip(result_name_list, coco_result_list): 48 | add_coco_labels(dataset, result_name, coco_result, coco_id_field="gt_coco_id") 49 | 50 | # visualize results 51 | session = fo.launch_app() # pyright: ignore[reportArgumentType] 52 | session.dataset = dataset 53 | 54 | # order by false positives if any coco result is given 55 | if result_json_paths: 56 | # Evaluate the predictions 57 | first_coco_result_name = result_name_list[0] 58 | _ = dataset.evaluate_detections( 59 | first_coco_result_name, 60 | gt_field="gt_detections", 61 | eval_key=f"{first_coco_result_name}_eval", 62 | iou=iou_thresh, 63 | compute_mAP=False, 64 | ) 65 | # Get the 10 most common classes in the dataset 66 | # counts = dataset.count_values("gt_detections.detections.label") 67 | # classes_top10 = sorted(counts, key=counts.get, 
reverse=True)[:10] 68 | # Print a classification report for the top-10 classes 69 | # results.print_report(classes=classes_top10) 70 | # Load the view on which we ran the `eval` evaluation 71 | eval_view = dataset.load_evaluation_view(f"{first_coco_result_name}_eval") 72 | # Show samples with most false positives 73 | session.view = eval_view.sort_by(f"{first_coco_result_name}_eval_fp", reverse=True) 74 | 75 | print(f"SAHI has successfully launched a Fiftyone app at http://localhost:{fo.config.default_app_port}") 76 | while 1: 77 | time.sleep(3) 78 | 79 | 80 | if __name__ == "__main__": 81 | fire.Fire(main) 82 | -------------------------------------------------------------------------------- /sahi/scripts/coco2yolo.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import fire 4 | 5 | from sahi.utils.coco import Coco 6 | from sahi.utils.file import Path, increment_path 7 | 8 | 9 | def main( 10 | image_dir: str, 11 | dataset_json_path: str, 12 | train_split: Union[int, float] = 0.9, 13 | project: str = "runs/coco2yolo", 14 | name: str = "exp", 15 | seed: int = 1, 16 | disable_symlink=False, 17 | ): 18 | """ 19 | Args: 20 | images_dir (str): directory for coco images 21 | dataset_json_path (str): file path for the coco json file to be converted 22 | train_split (float or int): set the training split ratio 23 | project (str): save results to project/name 24 | name (str): save results to project/name" 25 | seed (int): fix the seed for reproducibility 26 | disable_symlink (bool): required in google colab env 27 | """ 28 | 29 | # increment run 30 | save_dir = Path(increment_path(Path(project) / name, exist_ok=False)) 31 | # load coco dict 32 | coco = Coco.from_coco_dict_or_path( 33 | coco_dict_or_path=dataset_json_path, 34 | image_dir=image_dir, 35 | ) 36 | # export as YOLO 37 | coco.export_as_yolo( 38 | output_dir=str(save_dir), 39 | train_split_rate=train_split, 40 | numpy_seed=seed, 41 | disable_symlink=disable_symlink, 42 | ) 43 | 44 | print(f"COCO to YOLO conversion results are successfully exported to {save_dir}") 45 | 46 | 47 | if __name__ == "__main__": 48 | fire.Fire(main) 49 | -------------------------------------------------------------------------------- /sahi/scripts/predict.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | from sahi.predict import predict 4 | 5 | 6 | def main(): 7 | fire.Fire(predict) 8 | 9 | 10 | if __name__ == "__main__": 11 | main() 12 | -------------------------------------------------------------------------------- /sahi/scripts/predict_fiftyone.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | from sahi.predict import predict_fiftyone 4 | 5 | 6 | def main(): 7 | fire.Fire(predict_fiftyone) 8 | 9 | 10 | if __name__ == "__main__": 11 | main() 12 | -------------------------------------------------------------------------------- /sahi/scripts/slice_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fire 4 | 5 | from sahi.slicing import slice_coco 6 | from sahi.utils.file import Path, save_json 7 | 8 | 9 | def slice( 10 | image_dir: str, 11 | dataset_json_path: str, 12 | slice_size: int = 512, 13 | overlap_ratio: float = 0.2, 14 | ignore_negative_samples: bool = False, 15 | output_dir: str = "runs/slice_coco", 16 | min_area_ratio: float = 0.1, 17 | ): 18 | """ 19 | Args: 20 | image_dir (str): directory for coco 
images 21 | dataset_json_path (str): file path for the coco dataset json file 22 | slice_size (int) 23 | overlap_ratio (float): slice overlap ratio 24 | ignore_negative_samples (bool): ignore images without annotation 25 | output_dir (str): output export dir 26 | min_area_ratio (float): If the cropped annotation area to original 27 | annotation ratio is smaller than this value, the annotation 28 | is filtered out. Default 0.1. 29 | """ 30 | 31 | # assure slice_size is list 32 | slice_size_list = slice_size 33 | if isinstance(slice_size_list, (int, float)): 34 | slice_size_list = [slice_size_list] 35 | 36 | # slice coco dataset images and annotations 37 | print("Slicing step is starting...") 38 | for slice_size in slice_size_list: 39 | # in format: train_images_512_01 40 | output_images_folder_name = ( 41 | Path(dataset_json_path).stem + f"_images_{str(slice_size)}_{str(overlap_ratio).replace('.', '')}" 42 | ) 43 | output_images_dir = str(Path(output_dir) / output_images_folder_name) 44 | sliced_coco_name = Path(dataset_json_path).name.replace( 45 | ".json", f"_{str(slice_size)}_{str(overlap_ratio).replace('.', '')}" 46 | ) 47 | coco_dict, coco_path = slice_coco( 48 | coco_annotation_file_path=dataset_json_path, 49 | image_dir=image_dir, 50 | output_coco_annotation_file_name="", 51 | output_dir=output_images_dir, 52 | ignore_negative_samples=ignore_negative_samples, 53 | slice_height=slice_size, 54 | slice_width=slice_size, 55 | min_area_ratio=min_area_ratio, 56 | overlap_height_ratio=overlap_ratio, 57 | overlap_width_ratio=overlap_ratio, 58 | out_ext=".jpg", 59 | verbose=False, 60 | ) 61 | output_coco_annotation_file_path = os.path.join(output_dir, sliced_coco_name + ".json") 62 | save_json(coco_dict, output_coco_annotation_file_path) 63 | print(f"Sliced dataset for 'slice_size: {slice_size}' is exported to {output_dir}") 64 | 65 | 66 | if __name__ == "__main__": 67 | fire.Fire(slice) 68 | -------------------------------------------------------------------------------- /sahi/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/sahi/utils/__init__.py -------------------------------------------------------------------------------- /sahi/utils/compatibility.py: -------------------------------------------------------------------------------- 1 | def fix_shift_amount_list(shift_amount_list): 2 | # compatilibty for sahi v0.8.15 3 | if isinstance(shift_amount_list[0], (int, float)): 4 | shift_amount_list = [shift_amount_list] 5 | return shift_amount_list 6 | 7 | 8 | def fix_full_shape_list(full_shape_list): 9 | # compatilibty for sahi v0.8.15 10 | if full_shape_list is not None and isinstance(full_shape_list[0], (int, float)): 11 | full_shape_list = [full_shape_list] 12 | return full_shape_list 13 | -------------------------------------------------------------------------------- /sahi/utils/detectron2.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | class Detectron2TestConstants: 5 | FASTERCNN_MODEL_ZOO_NAME = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" 6 | RETINANET_MODEL_ZOO_NAME = "COCO-Detection/retinanet_R_50_FPN_3x.yaml" 7 | MASKRCNN_MODEL_ZOO_NAME = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" 8 | 9 | 10 | def export_cfg_as_yaml(cfg, export_path: str = "config.yaml"): 11 | """ 12 | Exports Detectron2 config object in yaml format so that it can be used later. 
13 | Args: 14 | cfg (detectron2.config.CfgNode): Detectron2 config object. 15 | export_path (str): Path to export the Detectron2 config. 16 | Related Detectron2 doc: https://detectron2.readthedocs.io/en/stable/modules/config.html#detectron2.config.CfgNode.dump 17 | """ 18 | Path(export_path).parent.mkdir(exist_ok=True, parents=True) 19 | 20 | with open(export_path, "w") as f: 21 | f.write(cfg.dump()) 22 | -------------------------------------------------------------------------------- /sahi/utils/fiftyone.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | 5 | from sahi.utils.import_utils import is_available 6 | 7 | if is_available("fiftyone"): 8 | # to fix https://github.com/voxel51/fiftyone/issues/845 9 | if sys.platform == "win32": 10 | _ = subprocess.run("tskill mongod", stderr=subprocess.DEVNULL) 11 | else: 12 | _ = subprocess.run(["pkill", "mongod"], stderr=subprocess.DEVNULL) 13 | 14 | # import fo utilities 15 | import fiftyone as fo 16 | from fiftyone.utils.coco import COCODetectionDatasetImporter as BaseCOCODetectionDatasetImporter 17 | from fiftyone.utils.coco import _get_matching_image_ids, load_coco_detection_annotations 18 | 19 | class COCODetectionDatasetImporter(BaseCOCODetectionDatasetImporter): 20 | def setup(self): 21 | if self.labels_path is not None and os.path.isfile(self.labels_path): 22 | ( 23 | info, 24 | classes, 25 | supercategory_map, 26 | images, 27 | annotations, 28 | ) = load_coco_detection_annotations(self.labels_path, extra_attrs=self.extra_attrs) 29 | 30 | if classes is not None: 31 | info["classes"] = classes 32 | 33 | image_ids = _get_matching_image_ids( 34 | classes, 35 | images, 36 | annotations, 37 | image_ids=self.image_ids, 38 | classes=self.classes, 39 | shuffle=self.shuffle, 40 | seed=self.seed, 41 | max_samples=self.max_samples, 42 | ) 43 | 44 | filenames = [images[_id]["file_name"] for _id in image_ids] 45 | 46 | _image_ids = set(image_ids) 47 | image_dicts_map = {i["file_name"]: i for _id, i in images.items() if _id in _image_ids} 48 | else: 49 | info = {} 50 | classes = None 51 | supercategory_map = None 52 | image_dicts_map = {} 53 | annotations = None 54 | filenames = [] 55 | 56 | self._image_paths_map = { 57 | image["file_name"]: os.path.join(self.data_path, image["file_name"]) for image in images.values() 58 | } 59 | 60 | self._info = info 61 | self._classes = classes 62 | self._supercategory_map = supercategory_map 63 | self._image_dicts_map = image_dicts_map 64 | self._annotations = annotations 65 | self._filenames = filenames 66 | 67 | def create_fiftyone_dataset_from_coco_file(coco_image_dir: str, coco_json_path: str): 68 | coco_importer = COCODetectionDatasetImporter( 69 | data_path=coco_image_dir, labels_path=coco_json_path, include_id=True 70 | ) 71 | dataset = fo.Dataset.from_importer(coco_importer, label_field="gt") 72 | return dataset 73 | 74 | def launch_fiftyone_app(coco_image_dir: str, coco_json_path: str): 75 | dataset = create_fiftyone_dataset_from_coco_file(coco_image_dir, coco_json_path) 76 | session = fo.launch_app() 77 | session.dataset = dataset 78 | return session 79 | -------------------------------------------------------------------------------- /sahi/utils/huggingface.py: -------------------------------------------------------------------------------- 1 | class HuggingfaceTestConstants: 2 | YOLOS_TINY_MODEL_PATH = "hustvl/yolos-tiny" 3 | RTDETRV2_MODEL_PATH = "PekingU/rtdetr_v2_r18vd" 4 | 
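A usage sketch for the `export_cfg_as_yaml` helper in `sahi/utils/detectron2.py` above. It assumes `detectron2` is installed; the model-zoo name comes from `Detectron2TestConstants`, while the export path is just an illustrative choice.

```python
# Sketch (requires detectron2): build a config from the model zoo and dump it,
# so it can later be passed to SAHI as a Detectron2 config_path.
from detectron2 import model_zoo
from detectron2.config import get_cfg

from sahi.utils.detectron2 import Detectron2TestConstants, export_cfg_as_yaml

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(Detectron2TestConstants.FASTERCNN_MODEL_ZOO_NAME))

# export path is illustrative
export_cfg_as_yaml(cfg, export_path="tests/data/models/detectron2/faster_rcnn_R_50_FPN_3x.yaml")
```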
-------------------------------------------------------------------------------- /sahi/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import logging 3 | import os 4 | 5 | # adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py 6 | 7 | logger = logging.getLogger(__name__) 8 | logging.basicConfig( 9 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 10 | datefmt="%m/%d/%Y %H:%M:%S", 11 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 12 | ) 13 | 14 | 15 | def get_package_info(package_name: str, verbose: bool = True): 16 | """ 17 | Returns the package version as a string and the package name as a string. 18 | """ 19 | _is_available = is_available(package_name) 20 | 21 | if _is_available: 22 | try: 23 | import importlib.metadata as _importlib_metadata 24 | 25 | _version = _importlib_metadata.version(package_name) 26 | except (ModuleNotFoundError, AttributeError): 27 | try: 28 | _version = importlib.import_module(package_name).__version__ 29 | except AttributeError: 30 | _version = "unknown" 31 | if verbose: 32 | logger.info(f"{package_name} version {_version} is available.") 33 | else: 34 | _version = "N/A" 35 | 36 | return _is_available, _version 37 | 38 | 39 | def print_environment_info(): 40 | _torch_available, _torch_version = get_package_info("torch") 41 | _torchvision_available, _torchvision_version = get_package_info("torchvision") 42 | _tensorflow_available, _tensorflow_version = get_package_info("tensorflow") 43 | _tensorflow_hub_available, _tensorflow_hub_version = get_package_info("tensorflow-hub") 44 | _yolov5_available, _yolov5_version = get_package_info("yolov5") 45 | _mmdet_available, _mmdet_version = get_package_info("mmdet") 46 | _mmcv_available, _mmcv_version = get_package_info("mmcv") 47 | _detectron2_available, _detectron2_version = get_package_info("detectron2") 48 | _transformers_available, _transformers_version = get_package_info("transformers") 49 | _timm_available, _timm_version = get_package_info("timm") 50 | _fiftyone_available, _fiftyone_version = get_package_info("fiftyone") 51 | 52 | 53 | def is_available(module_name: str): 54 | return importlib.util.find_spec(module_name) is not None 55 | 56 | 57 | def check_requirements(package_names): 58 | """ 59 | Raise error if module is not installed. 60 | """ 61 | missing_packages = [] 62 | for package_name in package_names: 63 | if importlib.util.find_spec(package_name) is None: 64 | missing_packages.append(package_name) 65 | if missing_packages: 66 | raise ImportError(f"The following packages are required to use this module: {missing_packages}") 67 | yield 68 | 69 | 70 | def check_package_minimum_version(package_name: str, minimum_version: str, verbose=False): 71 | """ 72 | Raise error if module version is not compatible. 73 | """ 74 | from packaging import version 75 | 76 | _is_available, _version = get_package_info(package_name, verbose=verbose) 77 | if _is_available: 78 | if _version == "unknown": 79 | logger.warning( 80 | f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible." 81 | ) 82 | else: 83 | if version.parse(_version) < version.parse(minimum_version): 84 | return False 85 | return True 86 | 87 | 88 | def ensure_package_minimum_version(package_name: str, minimum_version: str, verbose=False): 89 | """ 90 | Raise error if module version is not compatible. 
91 | """ 92 | from packaging import version 93 | 94 | _is_available, _version = get_package_info(package_name, verbose=verbose) 95 | if _is_available: 96 | if _version == "unknown": 97 | logger.warning( 98 | f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible." 99 | ) 100 | else: 101 | if version.parse(_version) < version.parse(minimum_version): 102 | raise ImportError( 103 | f"Please upgrade {package_name} to version {minimum_version} or higher to use this module." 104 | ) 105 | yield 106 | -------------------------------------------------------------------------------- /sahi/utils/rtdetr.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | from os import path 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | 7 | class RTDETRTestConstants: 8 | RTDETRL_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/rtdetr-l.pt" 9 | RTDETRL_MODEL_PATH = "tests/data/models/rtdetr/rtdetr-l.pt" 10 | 11 | RTDETRX_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/rtdetr-x.pt" 12 | RTDETRX_MODEL_PATH = "tests/data/models/rtdetr/rtdetr-x.pt" 13 | 14 | 15 | def download_rtdetrl_model(destination_path: Optional[str] = None): 16 | if destination_path is None: 17 | destination_path = RTDETRTestConstants.RTDETRL_MODEL_PATH 18 | 19 | Path(destination_path).parent.mkdir(parents=True, exist_ok=True) 20 | 21 | if not path.exists(destination_path): 22 | urllib.request.urlretrieve( 23 | RTDETRTestConstants.RTDETRX_MODEL_URL, 24 | destination_path, 25 | ) 26 | 27 | 28 | def download_rtdetrx_model(destination_path: Optional[str] = None): 29 | if destination_path is None: 30 | destination_path = RTDETRTestConstants.RTDETRX_MODEL_PATH 31 | 32 | Path(destination_path).parent.mkdir(parents=True, exist_ok=True) 33 | 34 | if not path.exists(destination_path): 35 | urllib.request.urlretrieve( 36 | RTDETRTestConstants.RTDETRX_MODEL_URL, 37 | destination_path, 38 | ) 39 | -------------------------------------------------------------------------------- /sahi/utils/sparseyolov5.py: -------------------------------------------------------------------------------- 1 | class Yolov5TestConstants: 2 | YOLOV_MODEL_URL = "zoo:cv/detection/yolov5-s/pytorch/ultralytics/coco/pruned-aggressive_96" 3 | -------------------------------------------------------------------------------- /sahi/utils/torch.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Fatih C Akyon, 2020. 3 | 4 | 5 | import os 6 | from typing import Any, Optional, Union 7 | 8 | import numpy as np 9 | from PIL.Image import Image 10 | from torch import Tensor, device 11 | 12 | try: 13 | import torch 14 | 15 | has_torch_cuda = torch.cuda.is_available() 16 | try: 17 | has_torch_mps: bool = torch.backends.mps.is_available() # pyright: ignore[reportAttributeAccessIssue] 18 | except Exception: 19 | has_torch_mps = False 20 | has_torch = True 21 | except ImportError: 22 | has_torch_cuda = False 23 | has_torch_mps = False 24 | has_torch = False 25 | 26 | 27 | def empty_cuda_cache(): 28 | if has_torch_cuda: 29 | return torch.cuda.empty_cache() 30 | 31 | 32 | def to_float_tensor(img: Union[np.ndarray, Image]) -> Tensor: 33 | """ 34 | Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range 35 | [0, 255] to a torch.FloatTensor of shape (C x H x W). 
36 | Args: 37 | img: PIL.Image or numpy array 38 | Returns: 39 | torch.tensor 40 | """ 41 | nparray: np.ndarray 42 | if isinstance(img, np.ndarray): 43 | nparray = img 44 | else: 45 | nparray = np.array(img) 46 | nparray = nparray.transpose((2, 0, 1)) 47 | tens = torch.from_numpy(np.array(nparray)).float() 48 | if tens.max() > 1: 49 | tens /= 255 50 | return tens 51 | 52 | 53 | def torch_to_numpy(img: Any) -> np.ndarray: 54 | img = img.numpy() 55 | if img.max() > 1: 56 | img /= 255 57 | return img.transpose((1, 2, 0)) 58 | 59 | 60 | def select_device(device: Optional[str] = None) -> device: 61 | """ 62 | Selects torch device 63 | 64 | Args: 65 | device: "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc. 66 | When no device string is given, the order of preference 67 | to try is: cuda:0 > mps > cpu 68 | 69 | Returns: 70 | torch.device 71 | 72 | Inspired by https://github.com/ultralytics/yolov5/blob/6371de8879e7ad7ec5283e8b95cc6dd85d6a5e72/utils/torch_utils.py#L107 73 | """ 74 | if device == "cuda": 75 | device = "cuda:0" 76 | device = str(device).strip().lower().replace("cuda:", "").replace("none", "") # to string, 'cuda:0' to '0' 77 | cpu = device == "cpu" 78 | mps = device == "mps" # Apple Metal Performance Shaders (MPS) 79 | if cpu or mps: 80 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False 81 | elif device: # non-cpu device requested 82 | os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() 83 | 84 | if not cpu and not mps and has_torch_cuda: # prefer GPU if available 85 | arg = "cuda:0" 86 | elif mps and getattr(torch, "has_mps", False) and has_torch_mps: # prefer MPS if available 87 | arg = "mps" 88 | else: # revert to CPU 89 | arg = "cpu" 90 | 91 | return torch.device(arg) 92 | -------------------------------------------------------------------------------- /sahi/utils/torchvision.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Kadir Nar, 2022. 
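# Usage sketch (hedged: the lookup site is assumed, since sahi/models/torchvision.py is not shown here):
# MODEL_NAME_TO_CONSTRUCTOR below maps model names such as "ssd300_vgg16" to their torchvision constructors,
# so MODEL_NAME_TO_CONSTRUCTOR["ssd300_vgg16"]() builds a plain SSD300-VGG16 detector, and COCO_CLASSES
# provides the 91-entry COCO label list (including "N/A" placeholders) used to name predicted category ids.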
3 | 4 | 5 | from packaging import version 6 | 7 | from sahi.utils.import_utils import get_package_info 8 | 9 | 10 | class TorchVisionTestConstants: 11 | FASTERRCNN_CONFIG_PATH = "tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml" 12 | SSD300_CONFIG_PATH = "tests/data/models/torchvision/ssd300_vgg16.yaml" 13 | 14 | 15 | _torchvision_available, _torchvision_version = get_package_info("torchvision", verbose=False) 16 | 17 | if _torchvision_available: 18 | import torchvision 19 | 20 | MODEL_NAME_TO_CONSTRUCTOR = { 21 | "fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn, 22 | "fasterrcnn_mobilenet_v3_large_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn, 23 | "fasterrcnn_mobilenet_v3_large_320_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn, 24 | "retinanet_resnet50_fpn": torchvision.models.detection.retinanet_resnet50_fpn, 25 | "ssd300_vgg16": torchvision.models.detection.ssd300_vgg16, 26 | "ssdlite320_mobilenet_v3_large": torchvision.models.detection.ssdlite320_mobilenet_v3_large, 27 | } 28 | 29 | # fcos requires torchvision >= 0.12.0 30 | if version.parse(_torchvision_version) >= version.parse("0.12.0"): 31 | MODEL_NAME_TO_CONSTRUCTOR["fcos_resnet50_fpn"] = torchvision.models.detection.fcos_resnet50_fpn 32 | 33 | 34 | COCO_CLASSES = [ 35 | "__background__", 36 | "person", 37 | "bicycle", 38 | "car", 39 | "motorcycle", 40 | "airplane", 41 | "bus", 42 | "train", 43 | "truck", 44 | "boat", 45 | "traffic light", 46 | "fire hydrant", 47 | "N/A", 48 | "stop sign", 49 | "parking meter", 50 | "bench", 51 | "bird", 52 | "cat", 53 | "dog", 54 | "horse", 55 | "sheep", 56 | "cow", 57 | "elephant", 58 | "bear", 59 | "zebra", 60 | "giraffe", 61 | "N/A", 62 | "backpack", 63 | "umbrella", 64 | "N/A", 65 | "N/A", 66 | "handbag", 67 | "tie", 68 | "suitcase", 69 | "frisbee", 70 | "skis", 71 | "snowboard", 72 | "sports ball", 73 | "kite", 74 | "baseball bat", 75 | "baseball glove", 76 | "skateboard", 77 | "surfboard", 78 | "tennis racket", 79 | "bottle", 80 | "N/A", 81 | "wine glass", 82 | "cup", 83 | "fork", 84 | "knife", 85 | "spoon", 86 | "bowl", 87 | "banana", 88 | "apple", 89 | "sandwich", 90 | "orange", 91 | "broccoli", 92 | "carrot", 93 | "hot dog", 94 | "pizza", 95 | "donut", 96 | "cake", 97 | "chair", 98 | "couch", 99 | "potted plant", 100 | "bed", 101 | "N/A", 102 | "dining table", 103 | "N/A", 104 | "N/A", 105 | "toilet", 106 | "N/A", 107 | "tv", 108 | "laptop", 109 | "mouse", 110 | "remote", 111 | "keyboard", 112 | "cell phone", 113 | "microwave", 114 | "oven", 115 | "toaster", 116 | "sink", 117 | "refrigerator", 118 | "N/A", 119 | "book", 120 | "clock", 121 | "vase", 122 | "scissors", 123 | "teddy bear", 124 | "hair drier", 125 | "toothbrush", 126 | ] 127 | -------------------------------------------------------------------------------- /sahi/utils/ultralytics.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | YOLOV8N_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt" 9 | YOLOV8N_SEG_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n-seg.pt" 10 | YOLO11N_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt" 11 | YOLO11N_SEG_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-seg.pt" 12 | 
YOLO11N_OBB_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-obb.pt" 13 | 14 | 15 | class UltralyticsTestConstants: 16 | YOLOV8N_MODEL_PATH = "tests/data/models/yolov8n.pt" 17 | YOLOV8N_SEG_MODEL_PATH = "tests/data/models/yolov8n-seg.pt" 18 | YOLO11N_MODEL_PATH = "tests/data/models/yolo11n.pt" 19 | YOLO11N_SEG_MODEL_PATH = "tests/data/models/yolo11n-seg.pt" 20 | YOLO11N_OBB_MODEL_PATH = "tests/data/models/yolo11n-obb.pt" 21 | 22 | 23 | def download_file(url: str, save_path: str, chunk_size: int = 8192) -> None: 24 | """ 25 | Downloads a file from a given URL to the specified path. 26 | 27 | Args: 28 | url: URL to download the file from 29 | save_path: Path where the file will be saved 30 | chunk_size: Size of chunks for downloading 31 | """ 32 | response = requests.get(url, stream=True) 33 | total_size = int(response.headers.get("content-length", 0)) 34 | 35 | # Ensure directory exists 36 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 37 | 38 | with open(save_path, "wb") as f, tqdm( 39 | desc=os.path.basename(save_path), 40 | total=total_size, 41 | unit="B", 42 | unit_scale=True, 43 | unit_divisor=1024, 44 | ) as pbar: 45 | for data in response.iter_content(chunk_size=chunk_size): 46 | size = f.write(data) 47 | pbar.update(size) 48 | 49 | 50 | def download_yolov8n_model(destination_path: Optional[str] = None) -> str: 51 | """Downloads YOLOv8n model if not already downloaded.""" 52 | if destination_path is None: 53 | destination_path = UltralyticsTestConstants.YOLOV8N_MODEL_PATH 54 | 55 | if not os.path.exists(destination_path): 56 | download_file(YOLOV8N_WEIGHTS_URL, destination_path) 57 | return destination_path 58 | 59 | 60 | def download_yolov8n_seg_model(destination_path: Optional[str] = None) -> str: 61 | """Downloads YOLOv8n-seg model if not already downloaded.""" 62 | if destination_path is None: 63 | destination_path = UltralyticsTestConstants.YOLOV8N_SEG_MODEL_PATH 64 | 65 | if not os.path.exists(destination_path): 66 | download_file(YOLOV8N_SEG_WEIGHTS_URL, destination_path) 67 | return destination_path 68 | 69 | 70 | def download_yolo11n_model(destination_path: Optional[str] = None) -> str: 71 | """Downloads YOLO11n model if not already downloaded.""" 72 | if destination_path is None: 73 | destination_path = UltralyticsTestConstants.YOLO11N_MODEL_PATH 74 | 75 | if not os.path.exists(destination_path): 76 | download_file(YOLO11N_WEIGHTS_URL, destination_path) 77 | return destination_path 78 | 79 | 80 | def download_yolo11n_seg_model(destination_path: Optional[str] = None) -> str: 81 | """Downloads YOLO11n-seg model if not already downloaded.""" 82 | if destination_path is None: 83 | destination_path = UltralyticsTestConstants.YOLO11N_SEG_MODEL_PATH 84 | 85 | if not os.path.exists(destination_path): 86 | download_file(YOLO11N_SEG_WEIGHTS_URL, destination_path) 87 | return destination_path 88 | 89 | 90 | def download_yolo11n_obb_model(destination_path: Optional[str] = None) -> str: 91 | """Downloads YOLO11n-obb model if not already downloaded.""" 92 | if destination_path is None: 93 | destination_path = UltralyticsTestConstants.YOLO11N_OBB_MODEL_PATH 94 | 95 | if not os.path.exists(destination_path): 96 | download_file(YOLO11N_OBB_WEIGHTS_URL, destination_path) 97 | return destination_path 98 | 99 | 100 | def download_model_weights(model_path: str) -> str: 101 | """ 102 | Downloads model weights based on the model path. 
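A minimal usage sketch (the file name is illustrative): download_model_weights("yolo11n.pt") resolves the stem "yolo11n", downloads the matching checkpoint into tests/data/models/ if it is missing, and returns the local weight path.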
103 | 104 | Args: 105 | model_path: Path or name of the model 106 | Returns: 107 | Path to the downloaded weights file 108 | """ 109 | model_name = Path(model_path).stem 110 | if model_name == "yolov8n": 111 | return download_yolov8n_model() 112 | elif model_name == "yolov8n-seg": 113 | return download_yolov8n_seg_model() 114 | elif model_name == "yolo11n": 115 | return download_yolo11n_model() 116 | elif model_name == "yolo11n-seg": 117 | return download_yolo11n_seg_model() 118 | elif model_name == "yolo11n-obb": 119 | return download_yolo11n_obb_model() 120 | else: 121 | raise ValueError(f"Unknown model: {model_name}") 122 | -------------------------------------------------------------------------------- /sahi/utils/yolov5.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | from os import path 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | 7 | class Yolov5TestConstants: 8 | YOLOV5N_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n.pt" 9 | YOLOV5N_MODEL_PATH = "tests/data/models/yolov5/yolov5n.pt" 10 | 11 | YOLOV5S6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s6.pt" 12 | YOLOV5S6_MODEL_PATH = "tests/data/models/yolov5/yolov5s6.pt" 13 | 14 | YOLOV5M6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m6.pt" 15 | YOLOV5M6_MODEL_PATH = "tests/data/models/yolov5/yolov5m6.pt" 16 | 17 | 18 | def download_yolov5n_model(destination_path: Optional[str] = None): 19 | if destination_path is None: 20 | destination_path = Yolov5TestConstants.YOLOV5N_MODEL_PATH 21 | 22 | Path(destination_path).parent.mkdir(parents=True, exist_ok=True) 23 | 24 | if not path.exists(destination_path): 25 | urllib.request.urlretrieve( 26 | Yolov5TestConstants.YOLOV5N_MODEL_URL, 27 | destination_path, 28 | ) 29 | 30 | 31 | def download_yolov5s6_model(destination_path: Optional[str] = None): 32 | if destination_path is None: 33 | destination_path = Yolov5TestConstants.YOLOV5S6_MODEL_PATH 34 | 35 | Path(destination_path).parent.mkdir(parents=True, exist_ok=True) 36 | 37 | if not path.exists(destination_path): 38 | urllib.request.urlretrieve( 39 | Yolov5TestConstants.YOLOV5S6_MODEL_URL, 40 | destination_path, 41 | ) 42 | -------------------------------------------------------------------------------- /sahi/utils/yolov8onnx.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | 6 | from sahi.utils.ultralytics import download_yolov8n_model 7 | 8 | 9 | # TODO: This class has no purpose, replace by just the constant 10 | class Yolov8ONNXTestConstants: 11 | YOLOV8N_ONNX_MODEL_PATH = "tests/data/models/yolov8/yolov8n.onnx" 12 | 13 | 14 | def download_yolov8n_onnx_model( 15 | destination_path: Union[str, Path] = Yolov8ONNXTestConstants.YOLOV8N_ONNX_MODEL_PATH, 16 | image_size: Optional[int] = 640, 17 | ): 18 | destination_path = Path(destination_path) 19 | model_path = destination_path.parent / (destination_path.stem + ".pt") 20 | download_yolov8n_model(str(model_path)) 21 | 22 | from ultralytics import YOLO 23 | 24 | model = YOLO(model_path) 25 | model.export(format="onnx") # , imgsz=image_size) 26 | 27 | 28 | def non_max_suppression(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> List[int]: 29 | """Perform non-max suppression. 
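A worked example (illustrative numbers): with boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]]), scores = np.array([0.9, 0.8, 0.7]) and iou_threshold = 0.5, the heavily overlapping second box (IoU of roughly 0.68 with the first) is suppressed and the function returns [0, 2].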
30 | 31 | Args: 32 | boxes: np.ndarray 33 | Predicted bounding boxes, shape (num_of_boxes, 4) 34 | scores: np.ndarray 35 | Confidence for predicted bounding boxes, shape (num_of_boxes). 36 | iou_threshold: float 37 | Maximum allowed overlap between bounding boxes. 38 | 39 | Returns: 40 | list of box_ids of the kept bounding boxes 41 | """ 42 | # Sort by score 43 | sorted_indices = np.argsort(scores)[::-1] 44 | 45 | keep_boxes = [] 46 | while sorted_indices.size > 0: 47 | # Pick the box with the highest remaining score 48 | box_id = sorted_indices[0] 49 | keep_boxes.append(box_id) 50 | 51 | # Compute IoU of the picked box with the rest 52 | ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :]) 53 | 54 | # Remove boxes with IoU over the threshold 55 | keep_indices = np.where(ious < iou_threshold)[0] 56 | 57 | # print(keep_indices.shape, sorted_indices.shape) 58 | sorted_indices = sorted_indices[keep_indices + 1] 59 | 60 | return keep_boxes 61 | 62 | 63 | def compute_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray: 64 | """Compute the IOU between a selected box and other boxes. 65 | 66 | Args: 67 | box: np.ndarray 68 | Selected box, shape (4) 69 | boxes: np.ndarray 70 | Other boxes used for computing IOU, shape (num_of_boxes, 4). 71 | 72 | Returns: 73 | np.ndarray: IoU of the selected box with each of the other boxes, shape (num_of_boxes,) 74 | """ 75 | # Compute xmin, ymin, xmax, ymax for both boxes 76 | xmin = np.maximum(box[0], boxes[:, 0]) 77 | ymin = np.maximum(box[1], boxes[:, 1]) 78 | xmax = np.minimum(box[2], boxes[:, 2]) 79 | ymax = np.minimum(box[3], boxes[:, 3]) 80 | 81 | # Compute intersection area 82 | intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin) 83 | 84 | # Compute union area 85 | box_area = (box[2] - box[0]) * (box[3] - box[1]) 86 | boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 87 | union_area = box_area + boxes_area - intersection_area 88 | 89 | # Compute IoU 90 | iou = intersection_area / union_area 91 | 92 | return iou 93 | 94 | 95 | def xywh2xyxy(x: np.ndarray) -> np.ndarray: 96 | """Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2) 97 | 98 | Args: 99 | x: np.ndarray 100 | Input bboxes, shape (num_of_boxes, 4). 101 | 102 | Returns: 103 | np.ndarray: (num_of_boxes, 4) 104 | """ 105 | y = np.copy(x) 106 | y[..., 0] = x[..., 0] - x[..., 2] / 2 107 | y[..., 1] = x[..., 1] - x[..., 3] / 2 108 | y[..., 2] = x[..., 0] + x[..., 2] / 2 109 | y[..., 3] = x[..., 1] + x[..., 3] / 2 110 | return y 111 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/run_code_style.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | 4 | from scripts.utils import shell, validate_and_exit 5 | 6 | if __name__ == "__main__": 7 | arg = sys.argv[1] 8 | warnings.warn( 9 | "Please use 'ruff check' and 'ruff format' instead. Precede with 'uv run' to run in virtual environment. Remember to activate the pre-commit hook to do that automatically on every commit.", 10 | DeprecationWarning, 11 | ) 12 | 13 | if arg == "check": 14 | sts_flake = shell("flake8 . --config setup.cfg --select=E9,F63,F7,F82") 15 | sts_isort = shell("isort . --check --settings pyproject.toml") 16 | sts_black = shell("black . 
--check --config pyproject.toml") 17 | validate_and_exit(flake8=sts_flake, isort=sts_isort, black=sts_black) 18 | elif arg == "format": 19 | sts_isort = shell("isort . --settings pyproject.toml") 20 | sts_black = shell("black . --config pyproject.toml") 21 | validate_and_exit(isort=sts_isort, black=sts_black) 22 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | 5 | import click 6 | 7 | 8 | def shell(command, exit_status=0): 9 | """ 10 | Run command through shell and return exit status if exit status of command run match with given exit status. 11 | 12 | Args: 13 | command: (str) Command string which runs through system shell. 14 | exit_status: (int) Expected exit status of given command run. 15 | 16 | Returns: actual_exit_status 17 | 18 | """ 19 | actual_exit_status = os.system(command) 20 | if actual_exit_status == exit_status: 21 | return 0 22 | return actual_exit_status 23 | 24 | 25 | def validate_and_exit(expected_out_status=0, **kwargs): 26 | if all([arg == expected_out_status for arg in kwargs.values()]): 27 | # Expected status, OK 28 | sys.exit(0) 29 | else: 30 | # Failure 31 | print_console_centered("Summary Results") 32 | fail_count = 0 33 | for component, exit_status in kwargs.items(): 34 | if exit_status != expected_out_status: 35 | click.secho(f"{component} failed.", fg="red") 36 | fail_count += 1 37 | 38 | print_console_centered(f"{len(kwargs) - fail_count} success, {fail_count} failure") 39 | click.secho("\nTo fix formatting issues:", fg="yellow") 40 | click.secho("1. Install development dependencies:", fg="cyan") 41 | click.secho(' pip install -e ."[dev]"', fg="green") 42 | click.secho("\n2. 
Run code formatting:", fg="cyan") 43 | click.secho(" python -m scripts.run_code_style format", fg="green") 44 | sys.exit(1) 45 | 46 | 47 | def print_console_centered(text: str, fill_char="="): 48 | w, _ = shutil.get_terminal_size((80, 20)) 49 | print(f" {text} ".center(w, fill_char)) 50 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/__init__.py -------------------------------------------------------------------------------- /tests/check_commandline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################################################################## 4 | # Checks, if the CLI works 5 | ################################################################################## 6 | 7 | source .venv/bin/activate 8 | 9 | set -e 10 | 11 | # predict mmdet 12 | PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') 13 | 14 | echo "mmcv" 15 | if (($(echo "$PYTHON_VERSION < 3.11" | bc -l))); then 16 | echo "mmcv: 1" 17 | sahi predict --model_type mmdet --source tests/data/ --novisual --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320 18 | echo "mmcv: 2" 19 | sahi predict --model_type mmdet --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320 20 | echo "mmcv: 3" 21 | sahi predict --model_type mmdet --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/mmdet/yolox/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth --model_config_path tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py --image_size 320 22 | else 23 | echo "mmcv not available for Python version $PYTHON_VERSION" 24 | fi 25 | 26 | sahi predict --no_sliced_prediction --model_type yolov5 --source tests/data/coco_utils/terrain1.jpg --novisual --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320 27 | sahi predict --model_type yolov5 --source tests/data/ --novisual --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320 28 | sahi predict --model_type yolov5 --source tests/data/coco_utils/terrain1.jpg --export_pickle --export_crop --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320 29 | sahi predict --model_type yolov5 --source tests/data/coco_utils/ --novisual --dataset_json_path tests/data/coco_utils/combined_coco.json --model_path tests/data/models/yolov5/yolov5s6.pt --image_size 320 30 | # coco yolov5 31 | sahi coco yolov5 --image_dir tests/data/coco_utils/ --dataset_json_path tests/data/coco_utils/combined_coco.json --train_split 0.9 32 | # coco evaluate 33 | sahi coco evaluate --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json 34 | # coco analyse 35 | sahi coco analyse --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json --out_dir tests/data/coco_evaluate/ 36 | 
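# Illustrative local invocation (assumes the uv-managed virtual environment in .venv exists and the
# weights referenced above, e.g. tests/data/models/yolov5/yolov5s6.pt, were downloaded beforehand):
#   bash tests/check_commandline.sh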
-------------------------------------------------------------------------------- /tests/check_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################################################################## 4 | # Checks, if the imports of the supported models are working as expected. 5 | # Not all frameworks are available on all python versions. 6 | # 7 | # It checks it for all python versions 8 | # Also, it checks, if the command line tool works on different examples 9 | ################################################################################## 10 | 11 | # This script should abort on error 12 | set -e 13 | 14 | # Python versions to check 15 | PYTHON_VERSIONS=("3.8 " "3.9 " "3.10" "3.11" "3.12") 16 | 17 | # Commands to check (okay means: return code 0) 18 | COMMANDS=( 19 | "uv run python -c 'from mmdet.apis.det_inferencer import DetInferencer'" 20 | "uv run python -c 'import torch'" 21 | "uv run python -c 'import ultralytics'" 22 | "uv run python -c 'import deepsparse'" 23 | "tests/check_commandline.sh" 24 | "uv run pytest -x" 25 | ) 26 | 27 | # Corresponding to the commands, the expected behaviour 28 | CONTEXTS=( 29 | "mmdet/mmcv with Python < 3.11" 30 | "torch, should work for all python versions" 31 | "ultralytics, should work for all python versions" 32 | "deepsparse, depends on onnxruntime, for Python <3.12" 33 | "command line" 34 | "pytest" 35 | ) 36 | 37 | GREEN='\033[0;32m' 38 | RED='\033[0;31m' 39 | NC='\033[0m' 40 | 41 | # Initialize an array to store results 42 | declare -A RESULTS 43 | 44 | # Loop over each Python version 45 | for version in "${PYTHON_VERSIONS[@]}"; do 46 | echo "Checking Python $version..." 47 | uv python pin "$version" 48 | uv sync -U 49 | # uv sync 50 | 51 | # Loop over each command 52 | for cmd in "${COMMANDS[@]}"; do 53 | echo -n "Checking $cmd..." 
54 | # Check if the command runs without errors 55 | if eval "$cmd"; then 56 | RESULTS["$version $cmd"]="${GREEN}✅ okay${NC}" 57 | else 58 | RESULTS["$version $cmd"]="${RED}❌ not working${NC}" 59 | fi 60 | echo -e "${RESULTS["$version $cmd"]}" 61 | done 62 | done 63 | 64 | # Display the results 65 | for index in "${!COMMANDS[@]}"; do 66 | cmd="${COMMANDS[$index]}" 67 | context="${CONTEXTS[$index]}" 68 | 69 | echo -e "\n$context:" 70 | for version in "${PYTHON_VERSIONS[@]}"; do 71 | echo -e "$version : ${RESULTS["$version $cmd"]}" 72 | done 73 | done 74 | -------------------------------------------------------------------------------- /tests/data/coco_utils/coco_class_names.yaml: -------------------------------------------------------------------------------- 1 | - person 2 | - bicycle 3 | - car 4 | - motorcycle 5 | - airplane 6 | - bus 7 | - train 8 | - truck 9 | - boat 10 | - traffic light 11 | - fire hydrant 12 | - stop sign 13 | - parking meter 14 | - bench 15 | - bird 16 | - cat 17 | - dog 18 | - horse 19 | - sheep 20 | - cow 21 | - elephant 22 | - bear 23 | - zebra 24 | - giraffe 25 | - backpack 26 | - umbrella 27 | - handbag 28 | - tie 29 | - suitcase 30 | - frisbee 31 | - skis 32 | - snowboard 33 | - sports ball 34 | - kite 35 | - baseball bat 36 | - baseball glove 37 | - skateboard 38 | - surfboard 39 | - tennis racket 40 | - bottle 41 | - wine glass 42 | - cup 43 | - fork 44 | - knife 45 | - spoon 46 | - bowl 47 | - banana 48 | - apple 49 | - sandwich 50 | - orange 51 | - broccoli 52 | - carrot 53 | - hot dog 54 | - pizza 55 | - donut 56 | - cake 57 | - chair 58 | - couch 59 | - potted plant 60 | - bed 61 | - dining table 62 | - toilet 63 | - tv 64 | - laptop 65 | - mouse 66 | - remote 67 | - keyboard 68 | - cell phone 69 | - microwave 70 | - oven 71 | - toaster 72 | - sink 73 | - refrigerator 74 | - book 75 | - clock 76 | - vase 77 | - scissors 78 | - teddy bear 79 | - hair drier 80 | - toothbrush -------------------------------------------------------------------------------- /tests/data/coco_utils/modified_terrain1_coco.json: -------------------------------------------------------------------------------- 1 | { 2 | "images": [ 3 | { 4 | "height": 1365, 5 | "width": 2048, 6 | "id": 1, 7 | "file_name": "terrain1.jpg" 8 | } 9 | ], 10 | "annotations": [ 11 | { 12 | "iscrowd": 0, 13 | "image_id": 1, 14 | "bbox": [ 15 | 491.0, 16 | 1035.0, 17 | 153.0, 18 | 182.0 19 | ], 20 | "segmentation": [ 21 | [ 22 | 491.0, 23 | 1035.0, 24 | 644.0, 25 | 1035.0, 26 | 644.0, 27 | 1217.0, 28 | 491.0, 29 | 1217.0 30 | ] 31 | ], 32 | "category_id": 1, 33 | "id": 1, 34 | "area": 2795520 35 | }, 36 | { 37 | "iscrowd": 0, 38 | "image_id": 1, 39 | "bbox": [ 40 | 648.0, 41 | 1057.0, 42 | 73.0, 43 | 171.0 44 | ], 45 | "segmentation": [ 46 | [ 47 | 648.0, 48 | 1057.0, 49 | 721.0, 50 | 1057.0, 51 | 721.0, 52 | 1228.0, 53 | 648.0, 54 | 1228.0 55 | ] 56 | ], 57 | "category_id": 1, 58 | "id": 2, 59 | "area": 2795520 60 | }, 61 | { 62 | "iscrowd": 0, 63 | "image_id": 1, 64 | "bbox": [ 65 | 852.0, 66 | 1053.0, 67 | 69.0, 68 | 160.0 69 | ], 70 | "segmentation": [ 71 | [ 72 | 852.0, 73 | 1053.0, 74 | 921.0, 75 | 1053.0, 76 | 921.0, 77 | 1213.0, 78 | 852.0, 79 | 1213.0 80 | ] 81 | ], 82 | "category_id": 1, 83 | "id": 3, 84 | "area": 2795520 85 | }, 86 | { 87 | "iscrowd": 0, 88 | "image_id": 1, 89 | "bbox": [ 90 | 900.0, 91 | 1032.0, 92 | 57.0, 93 | 166.0 94 | ], 95 | "segmentation": [ 96 | [ 97 | 900.0, 98 | 1032.0, 99 | 957.0, 100 | 1032.0, 101 | 957.0, 102 | 1198.0, 103 | 900.0, 104 | 1198.0 105 | ] 106 | ], 107 | 
"category_id": 1, 108 | "id": 4, 109 | "area": 2795520 110 | }, 111 | { 112 | "iscrowd": 0, 113 | "image_id": 1, 114 | "bbox": [ 115 | 941.0, 116 | 1039.0, 117 | 48.0, 118 | 152.0 119 | ], 120 | "segmentation": [ 121 | [ 122 | 941.0, 123 | 1039.0, 124 | 989.0, 125 | 1039.0, 126 | 989.0, 127 | 1191.0, 128 | 941.0, 129 | 1191.0 130 | ] 131 | ], 132 | "category_id": 1, 133 | "id": 5, 134 | "area": 2795520 135 | }, 136 | { 137 | "iscrowd": 0, 138 | "image_id": 1, 139 | "bbox": [ 140 | 1190.0, 141 | 1048.0, 142 | 64.0, 143 | 129.0 144 | ], 145 | "segmentation": [ 146 | [ 147 | 1190.0, 148 | 1048.0, 149 | 1254.0, 150 | 1048.0, 151 | 1254.0, 152 | 1177.0, 153 | 1190.0, 154 | 1177.0 155 | ] 156 | ], 157 | "category_id": 1, 158 | "id": 6, 159 | "area": 2795520 160 | }, 161 | { 162 | "iscrowd": 0, 163 | "image_id": 1, 164 | "bbox": [ 165 | 1323.0, 166 | 1048.0, 167 | 95.0, 168 | 137.0 169 | ], 170 | "segmentation": [ 171 | [ 172 | 1323.0, 173 | 1048.0, 174 | 1418.0, 175 | 1048.0, 176 | 1418.0, 177 | 1185.0, 178 | 1323.0, 179 | 1185.0 180 | ] 181 | ], 182 | "category_id": 1, 183 | "id": 7, 184 | "area": 2795520 185 | } 186 | ], 187 | "categories": [ 188 | { 189 | "name": "human", 190 | "supercategory": "human", 191 | "id": 1 192 | }, 193 | { 194 | "name": "car", 195 | "supercategory": "car", 196 | "id": 2 197 | }, 198 | { 199 | "name": "big_vehicle", 200 | "supercategory": "big_vehicle", 201 | "id": 3 202 | } 203 | ] 204 | } -------------------------------------------------------------------------------- /tests/data/coco_utils/modified_terrain2_coco.json: -------------------------------------------------------------------------------- 1 | { 2 | "images": [ 3 | { 4 | "height": 682, 5 | "width": 1024, 6 | "id": 1, 7 | "file_name": "terrain2.png" 8 | } 9 | ], 10 | "annotations": [ 11 | { 12 | "iscrowd": 0, 13 | "image_id": 1, 14 | "bbox": [ 15 | 218.0, 16 | 448.0, 17 | 222.0, 18 | 161.0 19 | ], 20 | "segmentation": [ 21 | [ 22 | 218.0, 23 | 448.0, 24 | 440.0, 25 | 448.0, 26 | 440.0, 27 | 609.0, 28 | 218.0, 29 | 609.0 30 | ] 31 | ], 32 | "category_id": 2, 33 | "id": 1, 34 | "area": 698368 35 | }, 36 | { 37 | "iscrowd": 0, 38 | "image_id": 1, 39 | "bbox": [ 40 | 501.0, 41 | 451.0, 42 | 121.0, 43 | 92.0 44 | ], 45 | "segmentation": [ 46 | [ 47 | 501.0, 48 | 451.0, 49 | 622.0, 50 | 451.0, 51 | 622.0, 52 | 543.0, 53 | 501.0, 54 | 543.0 55 | ] 56 | ], 57 | "category_id": 2, 58 | "id": 2, 59 | "area": 698368 60 | }, 61 | { 62 | "iscrowd": 0, 63 | "image_id": 1, 64 | "bbox": [ 65 | 634.0, 66 | 437.0, 67 | 81.0, 68 | 56.0 69 | ], 70 | "segmentation": [ 71 | [ 72 | 634.0, 73 | 437.0, 74 | 715.0, 75 | 437.0, 76 | 715.0, 77 | 493.0, 78 | 634.0, 79 | 493.0 80 | ] 81 | ], 82 | "category_id": 2, 83 | "id": 3, 84 | "area": 698368 85 | }, 86 | { 87 | "iscrowd": 0, 88 | "image_id": 1, 89 | "bbox": [ 90 | 725.0, 91 | 423.0, 92 | 70.0, 93 | 51.0 94 | ], 95 | "segmentation": [ 96 | [ 97 | 725.0, 98 | 423.0, 99 | 795.0, 100 | 423.0, 101 | 795.0, 102 | 474.0, 103 | 725.0, 104 | 474.0 105 | ] 106 | ], 107 | "category_id": 2, 108 | "id": 4, 109 | "area": 698368 110 | }, 111 | { 112 | "iscrowd": 0, 113 | "image_id": 1, 114 | "bbox": [ 115 | 791.0, 116 | 404.0, 117 | 40.0, 118 | 47.0 119 | ], 120 | "segmentation": [ 121 | [ 122 | 791.0, 123 | 404.0, 124 | 831.0, 125 | 404.0, 126 | 831.0, 127 | 451.0, 128 | 791.0, 129 | 451.0 130 | ] 131 | ], 132 | "category_id": 2, 133 | "id": 5, 134 | "area": 698368 135 | } 136 | ], 137 | "categories": [ 138 | { 139 | "name": "human", 140 | "supercategory": "human", 141 | "id": 1 142 | }, 143 | 
{ 144 | "name": "car", 145 | "supercategory": "car", 146 | "id": 2 147 | }, 148 | { 149 | "name": "big_vehicle", 150 | "supercategory": "big_vehicle", 151 | "id": 3 152 | } 153 | ] 154 | } -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/data/coco_utils/terrain1.jpg -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain1_coco.json: -------------------------------------------------------------------------------- 1 | { 2 | "images": [ 3 | { 4 | "height": 1365, 5 | "width": 2048, 6 | "id": 1, 7 | "file_name": "terrain1.jpg" 8 | } 9 | ], 10 | "categories": [ 11 | { 12 | "supercategory": "human", 13 | "id": 1, 14 | "name": "human" 15 | } 16 | ], 17 | "annotations": [ 18 | { 19 | "iscrowd": 0, 20 | "image_id": 1, 21 | "bbox": [ 22 | 491.0, 23 | 1035.0, 24 | 153.0, 25 | 182.0 26 | ], 27 | "segmentation": [ 28 | [ 29 | 491.0, 30 | 1035.0, 31 | 644.0, 32 | 1035.0, 33 | 644.0, 34 | 1217.0, 35 | 491.0, 36 | 1217.0 37 | ] 38 | ], 39 | "category_id": 1, 40 | "id": 1, 41 | "area": 2795520 42 | }, 43 | { 44 | "iscrowd": 0, 45 | "image_id": 1, 46 | "bbox": [ 47 | 648.0, 48 | 1057.0, 49 | 73.0, 50 | 171.0 51 | ], 52 | "segmentation": [ 53 | [ 54 | 648.0, 55 | 1057.0, 56 | 721.0, 57 | 1057.0, 58 | 721.0, 59 | 1228.0, 60 | 648.0, 61 | 1228.0 62 | ] 63 | ], 64 | "category_id": 1, 65 | "id": 2, 66 | "area": 2795520 67 | }, 68 | { 69 | "iscrowd": 0, 70 | "image_id": 1, 71 | "bbox": [ 72 | 852.0, 73 | 1053.0, 74 | 69.0, 75 | 160.0 76 | ], 77 | "segmentation": [ 78 | [ 79 | 852.0, 80 | 1053.0, 81 | 921.0, 82 | 1053.0, 83 | 921.0, 84 | 1213.0, 85 | 852.0, 86 | 1213.0 87 | ] 88 | ], 89 | "category_id": 1, 90 | "id": 3, 91 | "area": 2795520 92 | }, 93 | { 94 | "iscrowd": 0, 95 | "image_id": 1, 96 | "bbox": [ 97 | 900.0, 98 | 1032.0, 99 | 57.0, 100 | 166.0 101 | ], 102 | "segmentation": [ 103 | [ 104 | 900.0, 105 | 1032.0, 106 | 957.0, 107 | 1032.0, 108 | 957.0, 109 | 1198.0, 110 | 900.0, 111 | 1198.0 112 | ] 113 | ], 114 | "category_id": 1, 115 | "id": 4, 116 | "area": 2795520 117 | }, 118 | { 119 | "iscrowd": 0, 120 | "image_id": 1, 121 | "bbox": [ 122 | 941.0, 123 | 1039.0, 124 | 48.0, 125 | 152.0 126 | ], 127 | "segmentation": [ 128 | [ 129 | 941.0, 130 | 1039.0, 131 | 989.0, 132 | 1039.0, 133 | 989.0, 134 | 1191.0, 135 | 941.0, 136 | 1191.0 137 | ] 138 | ], 139 | "category_id": 1, 140 | "id": 5, 141 | "area": 2795520 142 | }, 143 | { 144 | "iscrowd": 0, 145 | "image_id": 1, 146 | "bbox": [ 147 | 1190.0, 148 | 1048.0, 149 | 64.0, 150 | 129.0 151 | ], 152 | "segmentation": [ 153 | [ 154 | 1190.0, 155 | 1048.0, 156 | 1254.0, 157 | 1048.0, 158 | 1254.0, 159 | 1177.0, 160 | 1190.0, 161 | 1177.0 162 | ] 163 | ], 164 | "category_id": 1, 165 | "id": 6, 166 | "area": 2795520 167 | }, 168 | { 169 | "iscrowd": 0, 170 | "image_id": 1, 171 | "bbox": [ 172 | 1323.0, 173 | 1048.0, 174 | 95.0, 175 | 137.0 176 | ], 177 | "segmentation": [ 178 | [ 179 | 1323.0, 180 | 1048.0, 181 | 1418.0, 182 | 1048.0, 183 | 1418.0, 184 | 1185.0, 185 | 1323.0, 186 | 1185.0 187 | ] 188 | ], 189 | "category_id": 1, 190 | "id": 7, 191 | "area": 2795520 192 | } 193 | ] 194 | } -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/data/coco_utils/terrain2.png -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain2_coco.json: -------------------------------------------------------------------------------- 1 | { 2 | "images": [ 3 | { 4 | "height": 682, 5 | "width": 1024, 6 | "id": 1, 7 | "file_name": "terrain2.png" 8 | } 9 | ], 10 | "categories": [ 11 | { 12 | "supercategory": "car", 13 | "id": 1, 14 | "name": "car" 15 | } 16 | ], 17 | "annotations": [ 18 | { 19 | "iscrowd": 0, 20 | "image_id": 1, 21 | "bbox": [ 22 | 218.0, 23 | 448.0, 24 | 222.0, 25 | 161.0 26 | ], 27 | "segmentation": [ 28 | [ 29 | 218.0, 30 | 448.0, 31 | 440.0, 32 | 448.0, 33 | 440.0, 34 | 609.0, 35 | 218.0, 36 | 609.0 37 | ] 38 | ], 39 | "category_id": 1, 40 | "id": 1, 41 | "area": 698368 42 | }, 43 | { 44 | "iscrowd": 0, 45 | "image_id": 1, 46 | "bbox": [ 47 | 501.0, 48 | 451.0, 49 | 121.0, 50 | 92.0 51 | ], 52 | "segmentation": [ 53 | [ 54 | 501.0, 55 | 451.0, 56 | 622.0, 57 | 451.0, 58 | 622.0, 59 | 543.0, 60 | 501.0, 61 | 543.0 62 | ] 63 | ], 64 | "category_id": 1, 65 | "id": 2, 66 | "area": 698368 67 | }, 68 | { 69 | "iscrowd": 0, 70 | "image_id": 1, 71 | "bbox": [ 72 | 634.0, 73 | 437.0, 74 | 81.0, 75 | 56.0 76 | ], 77 | "segmentation": [ 78 | [ 79 | 634.0, 80 | 437.0, 81 | 715.0, 82 | 437.0, 83 | 715.0, 84 | 493.0, 85 | 634.0, 86 | 493.0 87 | ] 88 | ], 89 | "category_id": 1, 90 | "id": 3, 91 | "area": 698368 92 | }, 93 | { 94 | "iscrowd": 0, 95 | "image_id": 1, 96 | "bbox": [ 97 | 725.0, 98 | 423.0, 99 | 70.0, 100 | 51.0 101 | ], 102 | "segmentation": [ 103 | [ 104 | 725.0, 105 | 423.0, 106 | 795.0, 107 | 423.0, 108 | 795.0, 109 | 474.0, 110 | 725.0, 111 | 474.0 112 | ] 113 | ], 114 | "category_id": 1, 115 | "id": 4, 116 | "area": 698368 117 | }, 118 | { 119 | "iscrowd": 0, 120 | "image_id": 1, 121 | "bbox": [ 122 | 791.0, 123 | 404.0, 124 | 40.0, 125 | 47.0 126 | ], 127 | "segmentation": [ 128 | [ 129 | 791.0, 130 | 404.0, 131 | 831.0, 132 | 404.0, 133 | 831.0, 134 | 451.0, 135 | 791.0, 136 | 451.0 137 | ] 138 | ], 139 | "category_id": 1, 140 | "id": 5, 141 | "area": 698368 142 | } 143 | ] 144 | } 145 | -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain2_gray.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/data/coco_utils/terrain2_gray.png -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/data/coco_utils/terrain3.png -------------------------------------------------------------------------------- /tests/data/coco_utils/terrain4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/data/coco_utils/terrain4.png -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/cityscapes_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type 
= "CityscapesDataset" 3 | data_root = "data/cityscapes/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/segmentation/cityscapes/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/segmentation/', 16 | # 'data/': 's3://openmmlab/datasets/segmentation/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True), 23 | dict(type="RandomResize", scale=[(2048, 800), (2048, 1024)], keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | 28 | test_pipeline = [ 29 | dict(type="LoadImageFromFile", backend_args=backend_args), 30 | dict(type="Resize", scale=(2048, 1024), keep_ratio=True), 31 | # If you don't have a gt annotation, delete the pipeline 32 | dict(type="LoadAnnotations", with_bbox=True), 33 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 34 | ] 35 | 36 | train_dataloader = dict( 37 | batch_size=1, 38 | num_workers=2, 39 | persistent_workers=True, 40 | sampler=dict(type="DefaultSampler", shuffle=True), 41 | batch_sampler=dict(type="AspectRatioBatchSampler"), 42 | dataset=dict( 43 | type="RepeatDataset", 44 | times=8, 45 | dataset=dict( 46 | type=dataset_type, 47 | data_root=data_root, 48 | ann_file="annotations/instancesonly_filtered_gtFine_train.json", 49 | data_prefix=dict(img="leftImg8bit/train/"), 50 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 51 | pipeline=train_pipeline, 52 | backend_args=backend_args, 53 | ), 54 | ), 55 | ) 56 | 57 | val_dataloader = dict( 58 | batch_size=1, 59 | num_workers=2, 60 | persistent_workers=True, 61 | drop_last=False, 62 | sampler=dict(type="DefaultSampler", shuffle=False), 63 | dataset=dict( 64 | type=dataset_type, 65 | data_root=data_root, 66 | ann_file="annotations/instancesonly_filtered_gtFine_val.json", 67 | data_prefix=dict(img="leftImg8bit/val/"), 68 | test_mode=True, 69 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 70 | pipeline=test_pipeline, 71 | backend_args=backend_args, 72 | ), 73 | ) 74 | 75 | test_dataloader = val_dataloader 76 | 77 | val_evaluator = dict( 78 | type="CocoMetric", 79 | ann_file=data_root + "annotations/instancesonly_filtered_gtFine_val.json", 80 | metric="bbox", 81 | backend_args=backend_args, 82 | ) 83 | 84 | test_evaluator = val_evaluator 85 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/cityscapes_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CityscapesDataset" 3 | data_root = "data/cityscapes/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/segmentation/cityscapes/' 10 | 11 | # Method 2: Use backend_args, file_client_args in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/segmentation/', 
16 | # 'data/': 's3://openmmlab/datasets/segmentation/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 23 | dict(type="RandomResize", scale=[(2048, 800), (2048, 1024)], keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | 28 | test_pipeline = [ 29 | dict(type="LoadImageFromFile", backend_args=backend_args), 30 | dict(type="Resize", scale=(2048, 1024), keep_ratio=True), 31 | # If you don't have a gt annotation, delete the pipeline 32 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 33 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 34 | ] 35 | 36 | train_dataloader = dict( 37 | batch_size=1, 38 | num_workers=2, 39 | persistent_workers=True, 40 | sampler=dict(type="DefaultSampler", shuffle=True), 41 | batch_sampler=dict(type="AspectRatioBatchSampler"), 42 | dataset=dict( 43 | type="RepeatDataset", 44 | times=8, 45 | dataset=dict( 46 | type=dataset_type, 47 | data_root=data_root, 48 | ann_file="annotations/instancesonly_filtered_gtFine_train.json", 49 | data_prefix=dict(img="leftImg8bit/train/"), 50 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 51 | pipeline=train_pipeline, 52 | backend_args=backend_args, 53 | ), 54 | ), 55 | ) 56 | 57 | val_dataloader = dict( 58 | batch_size=1, 59 | num_workers=2, 60 | persistent_workers=True, 61 | drop_last=False, 62 | sampler=dict(type="DefaultSampler", shuffle=False), 63 | dataset=dict( 64 | type=dataset_type, 65 | data_root=data_root, 66 | ann_file="annotations/instancesonly_filtered_gtFine_val.json", 67 | data_prefix=dict(img="leftImg8bit/val/"), 68 | test_mode=True, 69 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 70 | pipeline=test_pipeline, 71 | backend_args=backend_args, 72 | ), 73 | ) 74 | 75 | test_dataloader = val_dataloader 76 | 77 | val_evaluator = [ 78 | dict( 79 | type="CocoMetric", 80 | ann_file=data_root + "annotations/instancesonly_filtered_gtFine_val.json", 81 | metric=["bbox", "segm"], 82 | backend_args=backend_args, 83 | ), 84 | dict( 85 | type="CityScapesMetric", 86 | seg_prefix=data_root + "gtFine/val", 87 | outfile_prefix="./work_dirs/cityscapes_metric/instance", 88 | backend_args=backend_args, 89 | ), 90 | ] 91 | 92 | test_evaluator = val_evaluator 93 | 94 | # inference on test dataset and 95 | # format the output results for submission. 
96 | # test_dataloader = dict( 97 | # batch_size=1, 98 | # num_workers=2, 99 | # persistent_workers=True, 100 | # drop_last=False, 101 | # sampler=dict(type='DefaultSampler', shuffle=False), 102 | # dataset=dict( 103 | # type=dataset_type, 104 | # data_root=data_root, 105 | # ann_file='annotations/instancesonly_filtered_gtFine_test.json', 106 | # data_prefix=dict(img='leftImg8bit/test/'), 107 | # test_mode=True, 108 | # filter_cfg=dict(filter_empty_gt=True, min_size=32), 109 | # pipeline=test_pipeline)) 110 | # test_evaluator = dict( 111 | # type='CityScapesMetric', 112 | # format_only=True, 113 | # outfile_prefix='./work_dirs/cityscapes_metric/test') 114 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/coco_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CocoDataset" 3 | data_root = "data/coco/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True), 23 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 30 | # If you don't have a gt annotation, delete the pipeline 31 | dict(type="LoadAnnotations", with_bbox=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type="DefaultSampler", shuffle=True), 39 | batch_sampler=dict(type="AspectRatioBatchSampler"), 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | ann_file="annotations/instances_train2017.json", 44 | data_prefix=dict(img="train2017/"), 45 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 46 | pipeline=train_pipeline, 47 | backend_args=backend_args, 48 | ), 49 | ) 50 | val_dataloader = dict( 51 | batch_size=1, 52 | num_workers=2, 53 | persistent_workers=True, 54 | drop_last=False, 55 | sampler=dict(type="DefaultSampler", shuffle=False), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file="annotations/instances_val2017.json", 60 | data_prefix=dict(img="val2017/"), 61 | test_mode=True, 62 | pipeline=test_pipeline, 63 | backend_args=backend_args, 64 | ), 65 | ) 66 | test_dataloader = val_dataloader 67 | 68 | val_evaluator = dict( 69 | type="CocoMetric", 70 | ann_file=data_root + "annotations/instances_val2017.json", 71 | metric="bbox", 72 | format_only=False, 73 | backend_args=backend_args, 74 | ) 75 | test_evaluator = val_evaluator 76 | 77 | # inference on test dataset and 78 | # format the output results for submission. 
79 | # test_dataloader = dict( 80 | # batch_size=1, 81 | # num_workers=2, 82 | # persistent_workers=True, 83 | # drop_last=False, 84 | # sampler=dict(type='DefaultSampler', shuffle=False), 85 | # dataset=dict( 86 | # type=dataset_type, 87 | # data_root=data_root, 88 | # ann_file=data_root + 'annotations/image_info_test-dev2017.json', 89 | # data_prefix=dict(img='test2017/'), 90 | # test_mode=True, 91 | # pipeline=test_pipeline)) 92 | # test_evaluator = dict( 93 | # type='CocoMetric', 94 | # metric='bbox', 95 | # format_only=True, 96 | # ann_file=data_root + 'annotations/image_info_test-dev2017.json', 97 | # outfile_prefix='./work_dirs/coco_detection/test') 98 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/coco_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CocoDataset" 3 | data_root = "data/coco/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 23 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 30 | # If you don't have a gt annotation, delete the pipeline 31 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type="DefaultSampler", shuffle=True), 39 | batch_sampler=dict(type="AspectRatioBatchSampler"), 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | ann_file="annotations/instances_train2017.json", 44 | data_prefix=dict(img="train2017/"), 45 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 46 | pipeline=train_pipeline, 47 | backend_args=backend_args, 48 | ), 49 | ) 50 | val_dataloader = dict( 51 | batch_size=1, 52 | num_workers=2, 53 | persistent_workers=True, 54 | drop_last=False, 55 | sampler=dict(type="DefaultSampler", shuffle=False), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file="annotations/instances_val2017.json", 60 | data_prefix=dict(img="val2017/"), 61 | test_mode=True, 62 | pipeline=test_pipeline, 63 | backend_args=backend_args, 64 | ), 65 | ) 66 | test_dataloader = val_dataloader 67 | 68 | val_evaluator = dict( 69 | type="CocoMetric", 70 | ann_file=data_root + "annotations/instances_val2017.json", 71 | metric=["bbox", "segm"], 72 | format_only=False, 73 | backend_args=backend_args, 74 | ) 75 | test_evaluator = val_evaluator 76 | 77 | # inference on test dataset and 78 | # format the output results for 
submission. 79 | # test_dataloader = dict( 80 | # batch_size=1, 81 | # num_workers=2, 82 | # persistent_workers=True, 83 | # drop_last=False, 84 | # sampler=dict(type='DefaultSampler', shuffle=False), 85 | # dataset=dict( 86 | # type=dataset_type, 87 | # data_root=data_root, 88 | # ann_file=data_root + 'annotations/image_info_test-dev2017.json', 89 | # data_prefix=dict(img='test2017/'), 90 | # test_mode=True, 91 | # pipeline=test_pipeline)) 92 | # test_evaluator = dict( 93 | # type='CocoMetric', 94 | # metric=['bbox', 'segm'], 95 | # format_only=True, 96 | # ann_file=data_root + 'annotations/image_info_test-dev2017.json', 97 | # outfile_prefix='./work_dirs/coco_instance/test') 98 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/coco_instance_semantic.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CocoDataset" 3 | data_root = "data/coco/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True, with_seg=True), 23 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 30 | # If you don't have a gt annotation, delete the pipeline 31 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True, with_seg=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | 35 | train_dataloader = dict( 36 | batch_size=2, 37 | num_workers=2, 38 | persistent_workers=True, 39 | sampler=dict(type="DefaultSampler", shuffle=True), 40 | batch_sampler=dict(type="AspectRatioBatchSampler"), 41 | dataset=dict( 42 | type=dataset_type, 43 | data_root=data_root, 44 | ann_file="annotations/instances_train2017.json", 45 | data_prefix=dict(img="train2017/", seg="stuffthingmaps/train2017/"), 46 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 47 | pipeline=train_pipeline, 48 | backend_args=backend_args, 49 | ), 50 | ) 51 | 52 | val_dataloader = dict( 53 | batch_size=1, 54 | num_workers=2, 55 | persistent_workers=True, 56 | drop_last=False, 57 | sampler=dict(type="DefaultSampler", shuffle=False), 58 | dataset=dict( 59 | type=dataset_type, 60 | data_root=data_root, 61 | ann_file="annotations/instances_val2017.json", 62 | data_prefix=dict(img="val2017/"), 63 | test_mode=True, 64 | pipeline=test_pipeline, 65 | backend_args=backend_args, 66 | ), 67 | ) 68 | 69 | test_dataloader = val_dataloader 70 | 71 | val_evaluator = dict( 72 | type="CocoMetric", 73 | ann_file=data_root + "annotations/instances_val2017.json", 74 | metric=["bbox", "segm"], 75 | format_only=False, 76 | backend_args=backend_args, 77 | ) 78 | 
test_evaluator = val_evaluator 79 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CocoPanopticDataset" 3 | # data_root = 'data/coco/' 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | data_root = "s3://openmmlab/datasets/detection/coco/" 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadPanopticAnnotations", backend_args=backend_args), 23 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 30 | dict(type="LoadPanopticAnnotations", backend_args=backend_args), 31 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 32 | ] 33 | 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type="DefaultSampler", shuffle=True), 39 | batch_sampler=dict(type="AspectRatioBatchSampler"), 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | ann_file="annotations/panoptic_train2017.json", 44 | data_prefix=dict(img="train2017/", seg="annotations/panoptic_train2017/"), 45 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 46 | pipeline=train_pipeline, 47 | backend_args=backend_args, 48 | ), 49 | ) 50 | val_dataloader = dict( 51 | batch_size=1, 52 | num_workers=2, 53 | persistent_workers=True, 54 | drop_last=False, 55 | sampler=dict(type="DefaultSampler", shuffle=False), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file="annotations/panoptic_val2017.json", 60 | data_prefix=dict(img="val2017/", seg="annotations/panoptic_val2017/"), 61 | test_mode=True, 62 | pipeline=test_pipeline, 63 | backend_args=backend_args, 64 | ), 65 | ) 66 | test_dataloader = val_dataloader 67 | 68 | val_evaluator = dict( 69 | type="CocoPanopticMetric", 70 | ann_file=data_root + "annotations/panoptic_val2017.json", 71 | seg_prefix=data_root + "annotations/panoptic_val2017/", 72 | backend_args=backend_args, 73 | ) 74 | test_evaluator = val_evaluator 75 | 76 | # inference on test dataset and 77 | # format the output results for submission. 
78 | # test_dataloader = dict( 79 | # batch_size=1, 80 | # num_workers=1, 81 | # persistent_workers=True, 82 | # drop_last=False, 83 | # sampler=dict(type='DefaultSampler', shuffle=False), 84 | # dataset=dict( 85 | # type=dataset_type, 86 | # data_root=data_root, 87 | # ann_file='annotations/panoptic_image_info_test-dev2017.json', 88 | # data_prefix=dict(img='test2017/'), 89 | # test_mode=True, 90 | # pipeline=test_pipeline)) 91 | # test_evaluator = dict( 92 | # type='CocoPanopticMetric', 93 | # format_only=True, 94 | # ann_file=data_root + 'annotations/panoptic_image_info_test-dev2017.json', 95 | # outfile_prefix='./work_dirs/coco_panoptic/test') 96 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/deepfashion.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "DeepFashionDataset" 3 | data_root = "data/DeepFashion/In-shop/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 23 | dict(type="Resize", scale=(750, 1101), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(750, 1101), keep_ratio=True), 30 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 31 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 32 | ] 33 | train_dataloader = dict( 34 | batch_size=2, 35 | num_workers=2, 36 | persistent_workers=True, 37 | sampler=dict(type="DefaultSampler", shuffle=True), 38 | batch_sampler=dict(type="AspectRatioBatchSampler"), 39 | dataset=dict( 40 | type="RepeatDataset", 41 | times=2, 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | ann_file="Anno/segmentation/DeepFashion_segmentation_train.json", 46 | data_prefix=dict(img="Img/"), 47 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 48 | pipeline=train_pipeline, 49 | backend_args=backend_args, 50 | ), 51 | ), 52 | ) 53 | val_dataloader = dict( 54 | batch_size=1, 55 | num_workers=2, 56 | persistent_workers=True, 57 | drop_last=False, 58 | sampler=dict(type="DefaultSampler", shuffle=False), 59 | dataset=dict( 60 | type=dataset_type, 61 | data_root=data_root, 62 | ann_file="Anno/segmentation/DeepFashion_segmentation_query.json", 63 | data_prefix=dict(img="Img/"), 64 | test_mode=True, 65 | pipeline=test_pipeline, 66 | backend_args=backend_args, 67 | ), 68 | ) 69 | test_dataloader = dict( 70 | batch_size=1, 71 | num_workers=2, 72 | persistent_workers=True, 73 | drop_last=False, 74 | sampler=dict(type="DefaultSampler", shuffle=False), 75 | dataset=dict( 76 | type=dataset_type, 77 | data_root=data_root, 78 | ann_file="Anno/segmentation/DeepFashion_segmentation_gallery.json", 79 | 
data_prefix=dict(img="Img/"), 80 | test_mode=True, 81 | pipeline=test_pipeline, 82 | backend_args=backend_args, 83 | ), 84 | ) 85 | 86 | val_evaluator = dict( 87 | type="CocoMetric", 88 | ann_file=data_root + "Anno/segmentation/DeepFashion_segmentation_query.json", 89 | metric=["bbox", "segm"], 90 | format_only=False, 91 | backend_args=backend_args, 92 | ) 93 | test_evaluator = dict( 94 | type="CocoMetric", 95 | ann_file=data_root + "Anno/segmentation/DeepFashion_segmentation_gallery.json", 96 | metric=["bbox", "segm"], 97 | format_only=False, 98 | backend_args=backend_args, 99 | ) 100 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/lvis_v0.5_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "LVISV05Dataset" 3 | data_root = "data/lvis_v0.5/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/lvis_v0.5/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 23 | dict( 24 | type="RandomChoiceResize", 25 | scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800)], 26 | keep_ratio=True, 27 | ), 28 | dict(type="RandomFlip", prob=0.5), 29 | dict(type="PackDetInputs"), 30 | ] 31 | test_pipeline = [ 32 | dict(type="LoadImageFromFile", backend_args=backend_args), 33 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 34 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 35 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 36 | ] 37 | 38 | train_dataloader = dict( 39 | batch_size=2, 40 | num_workers=2, 41 | persistent_workers=True, 42 | sampler=dict(type="DefaultSampler", shuffle=True), 43 | batch_sampler=dict(type="AspectRatioBatchSampler"), 44 | dataset=dict( 45 | type="ClassBalancedDataset", 46 | oversample_thr=1e-3, 47 | dataset=dict( 48 | type=dataset_type, 49 | data_root=data_root, 50 | ann_file="annotations/lvis_v0.5_train.json", 51 | data_prefix=dict(img="train2017/"), 52 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 53 | pipeline=train_pipeline, 54 | backend_args=backend_args, 55 | ), 56 | ), 57 | ) 58 | val_dataloader = dict( 59 | batch_size=1, 60 | num_workers=2, 61 | persistent_workers=True, 62 | drop_last=False, 63 | sampler=dict(type="DefaultSampler", shuffle=False), 64 | dataset=dict( 65 | type=dataset_type, 66 | data_root=data_root, 67 | ann_file="annotations/lvis_v0.5_val.json", 68 | data_prefix=dict(img="val2017/"), 69 | test_mode=True, 70 | pipeline=test_pipeline, 71 | backend_args=backend_args, 72 | ), 73 | ) 74 | test_dataloader = val_dataloader 75 | 76 | val_evaluator = dict( 77 | type="LVISMetric", 78 | ann_file=data_root + "annotations/lvis_v0.5_val.json", 79 | metric=["bbox", "segm"], 80 | backend_args=backend_args, 81 | ) 82 | test_evaluator = val_evaluator 83 | 
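The file that follows, lvis_v1_instance.py, does not repeat this configuration; it pulls it in through `_base_` inheritance and only overrides the dataset type, data root, and annotation paths. A minimal sketch of inspecting that composition with mmengine's Config, assuming mmengine is installed and the snippet is run from the repository root:

from mmengine.config import Config

# `_base_ = "lvis_v0.5_instance.py"` is resolved relative to the derived file's own
# directory, then the derived fields are merged over the inherited ones.
cfg = Config.fromfile("tests/data/models/mmdet/_base_/datasets/lvis_v1_instance.py")

print(cfg.dataset_type)                 # "LVISV1Dataset" (overridden)
print(cfg.train_dataloader.batch_size)  # 2 (inherited from the v0.5 base)
print(cfg.val_evaluator.ann_file)       # "data/lvis_v1/annotations/lvis_v1_val.json" (overridden)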
-------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/lvis_v1_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | _base_ = "lvis_v0.5_instance.py" 3 | dataset_type = "LVISV1Dataset" 4 | data_root = "data/lvis_v1/" 5 | 6 | train_dataloader = dict( 7 | dataset=dict( 8 | dataset=dict( 9 | type=dataset_type, data_root=data_root, ann_file="annotations/lvis_v1_train.json", data_prefix=dict(img="") 10 | ) 11 | ) 12 | ) 13 | val_dataloader = dict( 14 | dataset=dict( 15 | type=dataset_type, data_root=data_root, ann_file="annotations/lvis_v1_val.json", data_prefix=dict(img="") 16 | ) 17 | ) 18 | test_dataloader = val_dataloader 19 | 20 | val_evaluator = dict(ann_file=data_root + "annotations/lvis_v1_val.json") 21 | test_evaluator = val_evaluator 22 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/objects365v1_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "Objects365V1Dataset" 3 | data_root = "data/Objects365/Obj365_v1/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True), 23 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 30 | # If you don't have a gt annotation, delete the pipeline 31 | dict(type="LoadAnnotations", with_bbox=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type="DefaultSampler", shuffle=True), 39 | batch_sampler=dict(type="AspectRatioBatchSampler"), 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | ann_file="annotations/objects365_train.json", 44 | data_prefix=dict(img="train/"), 45 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 46 | pipeline=train_pipeline, 47 | backend_args=backend_args, 48 | ), 49 | ) 50 | val_dataloader = dict( 51 | batch_size=1, 52 | num_workers=2, 53 | persistent_workers=True, 54 | drop_last=False, 55 | sampler=dict(type="DefaultSampler", shuffle=False), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file="annotations/objects365_val.json", 60 | data_prefix=dict(img="val/"), 61 | test_mode=True, 62 | pipeline=test_pipeline, 63 | backend_args=backend_args, 64 | ), 65 | ) 66 | test_dataloader = val_dataloader 67 | 68 | val_evaluator = dict( 69 | type="CocoMetric", 70 | ann_file=data_root + "annotations/objects365_val.json", 71 | 
metric="bbox", 72 | sort_categories=True, 73 | format_only=False, 74 | backend_args=backend_args, 75 | ) 76 | test_evaluator = val_evaluator 77 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/objects365v2_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "Objects365V2Dataset" 3 | data_root = "data/Objects365/Obj365_v2/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True), 23 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1333, 800), keep_ratio=True), 30 | # If you don't have a gt annotation, delete the pipeline 31 | dict(type="LoadAnnotations", with_bbox=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type="DefaultSampler", shuffle=True), 39 | batch_sampler=dict(type="AspectRatioBatchSampler"), 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | ann_file="annotations/zhiyuan_objv2_train.json", 44 | data_prefix=dict(img="train/"), 45 | filter_cfg=dict(filter_empty_gt=True, min_size=32), 46 | pipeline=train_pipeline, 47 | backend_args=backend_args, 48 | ), 49 | ) 50 | val_dataloader = dict( 51 | batch_size=1, 52 | num_workers=2, 53 | persistent_workers=True, 54 | drop_last=False, 55 | sampler=dict(type="DefaultSampler", shuffle=False), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file="annotations/zhiyuan_objv2_val.json", 60 | data_prefix=dict(img="val/"), 61 | test_mode=True, 62 | pipeline=test_pipeline, 63 | backend_args=backend_args, 64 | ), 65 | ) 66 | test_dataloader = val_dataloader 67 | 68 | val_evaluator = dict( 69 | type="CocoMetric", 70 | ann_file=data_root + "annotations/zhiyuan_objv2_val.json", 71 | metric="bbox", 72 | format_only=False, 73 | backend_args=backend_args, 74 | ) 75 | test_evaluator = val_evaluator 76 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/openimages_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "OpenImagesDataset" 3 | data_root = "data/OpenImages/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/coco/' 10 | 11 | # Method 2: Use `backend_args`, 
`file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/detection/', 16 | # 'data/': 's3://openmmlab/datasets/detection/' 17 | # })) 18 | backend_args = None 19 | 20 | train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True), 23 | dict(type="Resize", scale=(1024, 800), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1024, 800), keep_ratio=True), 30 | # avoid bboxes being resized 31 | dict(type="LoadAnnotations", with_bbox=True), 32 | # TODO: find a better way to collect image_level_labels 33 | dict( 34 | type="PackDetInputs", 35 | meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor", "instances", "image_level_labels"), 36 | ), 37 | ] 38 | 39 | train_dataloader = dict( 40 | batch_size=2, 41 | num_workers=0, # workers_per_gpu > 0 may occur out of memory 42 | persistent_workers=False, 43 | sampler=dict(type="DefaultSampler", shuffle=True), 44 | batch_sampler=dict(type="AspectRatioBatchSampler"), 45 | dataset=dict( 46 | type=dataset_type, 47 | data_root=data_root, 48 | ann_file="annotations/oidv6-train-annotations-bbox.csv", 49 | data_prefix=dict(img="OpenImages/train/"), 50 | label_file="annotations/class-descriptions-boxable.csv", 51 | hierarchy_file="annotations/bbox_labels_600_hierarchy.json", 52 | meta_file="annotations/train-image-metas.pkl", 53 | pipeline=train_pipeline, 54 | backend_args=backend_args, 55 | ), 56 | ) 57 | val_dataloader = dict( 58 | batch_size=1, 59 | num_workers=0, 60 | persistent_workers=False, 61 | drop_last=False, 62 | sampler=dict(type="DefaultSampler", shuffle=False), 63 | dataset=dict( 64 | type=dataset_type, 65 | data_root=data_root, 66 | ann_file="annotations/validation-annotations-bbox.csv", 67 | data_prefix=dict(img="OpenImages/validation/"), 68 | label_file="annotations/class-descriptions-boxable.csv", 69 | hierarchy_file="annotations/bbox_labels_600_hierarchy.json", 70 | meta_file="annotations/validation-image-metas.pkl", 71 | image_level_ann_file="annotations/validation-" "annotations-human-imagelabels-boxable.csv", 72 | pipeline=test_pipeline, 73 | backend_args=backend_args, 74 | ), 75 | ) 76 | test_dataloader = val_dataloader 77 | 78 | val_evaluator = dict(type="OpenImagesMetric", iou_thrs=0.5, ioa_thrs=0.5, use_group_of=True, get_supercategory=True) 79 | test_evaluator = val_evaluator 80 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/voc0712.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "VOCDataset" 3 | data_root = "data/VOCdevkit/" 4 | 5 | # Example to use different file client 6 | # Method 1: simply set the data root and let the file I/O module 7 | # automatically Infer from prefix (not support LMDB and Memcache yet) 8 | 9 | # data_root = 's3://openmmlab/datasets/detection/segmentation/VOCdevkit/' 10 | 11 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 12 | # backend_args = dict( 13 | # backend='petrel', 14 | # path_mapping=dict({ 15 | # './data/': 's3://openmmlab/datasets/segmentation/', 16 | # 'data/': 's3://openmmlab/datasets/segmentation/' 17 | # })) 18 | backend_args = None 19 | 20 | 
train_pipeline = [ 21 | dict(type="LoadImageFromFile", backend_args=backend_args), 22 | dict(type="LoadAnnotations", with_bbox=True), 23 | dict(type="Resize", scale=(1000, 600), keep_ratio=True), 24 | dict(type="RandomFlip", prob=0.5), 25 | dict(type="PackDetInputs"), 26 | ] 27 | test_pipeline = [ 28 | dict(type="LoadImageFromFile", backend_args=backend_args), 29 | dict(type="Resize", scale=(1000, 600), keep_ratio=True), 30 | # avoid bboxes being resized 31 | dict(type="LoadAnnotations", with_bbox=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type="DefaultSampler", shuffle=True), 39 | batch_sampler=dict(type="AspectRatioBatchSampler"), 40 | dataset=dict( 41 | type="RepeatDataset", 42 | times=3, 43 | dataset=dict( 44 | type="ConcatDataset", 45 | # VOCDataset will add different `dataset_type` in dataset.metainfo, 46 | # which will get error if using ConcatDataset. Adding 47 | # `ignore_keys` can avoid this error. 48 | ignore_keys=["dataset_type"], 49 | datasets=[ 50 | dict( 51 | type=dataset_type, 52 | data_root=data_root, 53 | ann_file="VOC2007/ImageSets/Main/trainval.txt", 54 | data_prefix=dict(sub_data_root="VOC2007/"), 55 | filter_cfg=dict(filter_empty_gt=True, min_size=32, bbox_min_size=32), 56 | pipeline=train_pipeline, 57 | backend_args=backend_args, 58 | ), 59 | dict( 60 | type=dataset_type, 61 | data_root=data_root, 62 | ann_file="VOC2012/ImageSets/Main/trainval.txt", 63 | data_prefix=dict(sub_data_root="VOC2012/"), 64 | filter_cfg=dict(filter_empty_gt=True, min_size=32, bbox_min_size=32), 65 | pipeline=train_pipeline, 66 | backend_args=backend_args, 67 | ), 68 | ], 69 | ), 70 | ), 71 | ) 72 | 73 | val_dataloader = dict( 74 | batch_size=1, 75 | num_workers=2, 76 | persistent_workers=True, 77 | drop_last=False, 78 | sampler=dict(type="DefaultSampler", shuffle=False), 79 | dataset=dict( 80 | type=dataset_type, 81 | data_root=data_root, 82 | ann_file="VOC2007/ImageSets/Main/test.txt", 83 | data_prefix=dict(sub_data_root="VOC2007/"), 84 | test_mode=True, 85 | pipeline=test_pipeline, 86 | backend_args=backend_args, 87 | ), 88 | ) 89 | test_dataloader = val_dataloader 90 | 91 | # Pascal VOC2007 uses `11points` as default evaluate mode, while PASCAL 92 | # VOC2012 defaults to use 'area'. 
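# Illustrative alternative (not part of the original config): for a VOC2012-style
# split the same metric could use area-based AP instead of the 11-point rule, e.g.
# val_evaluator = dict(type="VOCMetric", metric="mAP", eval_mode="area")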
93 | val_evaluator = dict(type="VOCMetric", metric="mAP", eval_mode="11points") 94 | test_evaluator = val_evaluator 95 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/datasets/wider_face.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "WIDERFaceDataset" 3 | data_root = "data/WIDERFace/" 4 | # Example to use different file client 5 | # Method 1: simply set the data root and let the file I/O module 6 | # automatically infer from prefix (not support LMDB and Memcache yet) 7 | 8 | # data_root = 's3://openmmlab/datasets/detection/cityscapes/' 9 | 10 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 11 | # backend_args = dict( 12 | # backend='petrel', 13 | # path_mapping=dict({ 14 | # './data/': 's3://openmmlab/datasets/detection/', 15 | # 'data/': 's3://openmmlab/datasets/detection/' 16 | # })) 17 | backend_args = None 18 | 19 | img_scale = (640, 640) # VGA resolution 20 | 21 | train_pipeline = [ 22 | dict(type="LoadImageFromFile", backend_args=backend_args), 23 | dict(type="LoadAnnotations", with_bbox=True), 24 | dict(type="Resize", scale=img_scale, keep_ratio=True), 25 | dict(type="RandomFlip", prob=0.5), 26 | dict(type="PackDetInputs"), 27 | ] 28 | test_pipeline = [ 29 | dict(type="LoadImageFromFile", backend_args=backend_args), 30 | dict(type="Resize", scale=img_scale, keep_ratio=True), 31 | dict(type="LoadAnnotations", with_bbox=True), 32 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 33 | ] 34 | 35 | train_dataloader = dict( 36 | batch_size=2, 37 | num_workers=2, 38 | persistent_workers=True, 39 | drop_last=False, 40 | sampler=dict(type="DefaultSampler", shuffle=True), 41 | batch_sampler=dict(type="AspectRatioBatchSampler"), 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | ann_file="train.txt", 46 | data_prefix=dict(img="WIDER_train"), 47 | filter_cfg=dict(filter_empty_gt=True, bbox_min_size=17, min_size=32), 48 | pipeline=train_pipeline, 49 | ), 50 | ) 51 | 52 | val_dataloader = dict( 53 | batch_size=1, 54 | num_workers=2, 55 | persistent_workers=True, 56 | drop_last=False, 57 | sampler=dict(type="DefaultSampler", shuffle=False), 58 | dataset=dict( 59 | type=dataset_type, 60 | data_root=data_root, 61 | ann_file="val.txt", 62 | data_prefix=dict(img="WIDER_val"), 63 | test_mode=True, 64 | pipeline=test_pipeline, 65 | ), 66 | ) 67 | test_dataloader = val_dataloader 68 | 69 | val_evaluator = dict( 70 | # TODO: support WiderFace-Evaluation for easy, medium, hard cases 71 | type="VOCMetric", 72 | metric="mAP", 73 | eval_mode="11points", 74 | ) 75 | test_evaluator = val_evaluator 76 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | default_scope = "mmdet" 2 | 3 | default_hooks = dict( 4 | timer=dict(type="IterTimerHook"), 5 | logger=dict(type="LoggerHook", interval=50), 6 | param_scheduler=dict(type="ParamSchedulerHook"), 7 | checkpoint=dict(type="CheckpointHook", interval=1), 8 | sampler_seed=dict(type="DistSamplerSeedHook"), 9 | visualization=dict(type="DetVisualizationHook"), 10 | ) 11 | 12 | env_cfg = dict( 13 | cudnn_benchmark=False, 14 | mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0), 15 | dist_cfg=dict(backend="nccl"), 16 | ) 17 | 18 | vis_backends = 
[dict(type="LocalVisBackend")] 19 | visualizer = dict(type="DetLocalVisualizer", vis_backends=vis_backends, name="visualizer") 20 | log_processor = dict(type="LogProcessor", window_size=50, by_epoch=True) 21 | 22 | log_level = "INFO" 23 | load_from = None 24 | resume = False 25 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/fast-rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="FastRCNN", 4 | data_preprocessor=dict( 5 | type="DetDataPreprocessor", 6 | mean=[123.675, 116.28, 103.53], 7 | std=[58.395, 57.12, 57.375], 8 | bgr_to_rgb=True, 9 | pad_size_divisor=32, 10 | ), 11 | backbone=dict( 12 | type="ResNet", 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | frozen_stages=1, 17 | norm_cfg=dict(type="BN", requires_grad=True), 18 | norm_eval=True, 19 | style="pytorch", 20 | init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), 21 | ), 22 | neck=dict(type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), 23 | roi_head=dict( 24 | type="StandardRoIHead", 25 | bbox_roi_extractor=dict( 26 | type="SingleRoIExtractor", 27 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 28 | out_channels=256, 29 | featmap_strides=[4, 8, 16, 32], 30 | ), 31 | bbox_head=dict( 32 | type="Shared2FCBBoxHead", 33 | in_channels=256, 34 | fc_out_channels=1024, 35 | roi_feat_size=7, 36 | num_classes=80, 37 | bbox_coder=dict( 38 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] 39 | ), 40 | reg_class_agnostic=False, 41 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 42 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 43 | ), 44 | ), 45 | # model training and testing settings 46 | train_cfg=dict( 47 | rcnn=dict( 48 | assigner=dict( 49 | type="MaxIoUAssigner", 50 | pos_iou_thr=0.5, 51 | neg_iou_thr=0.5, 52 | min_pos_iou=0.5, 53 | match_low_quality=False, 54 | ignore_iof_thr=-1, 55 | ), 56 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 57 | pos_weight=-1, 58 | debug=False, 59 | ) 60 | ), 61 | test_cfg=dict(rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100)), 62 | ) 63 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/faster-rcnn_r50-caffe-c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type="BN", requires_grad=False) 3 | model = dict( 4 | type="FasterRCNN", 5 | data_preprocessor=dict( 6 | type="DetDataPreprocessor", 7 | mean=[103.530, 116.280, 123.675], 8 | std=[1.0, 1.0, 1.0], 9 | bgr_to_rgb=False, 10 | pad_size_divisor=32, 11 | ), 12 | backbone=dict( 13 | type="ResNet", 14 | depth=50, 15 | num_stages=3, 16 | strides=(1, 2, 2), 17 | dilations=(1, 1, 1), 18 | out_indices=(2,), 19 | frozen_stages=1, 20 | norm_cfg=norm_cfg, 21 | norm_eval=True, 22 | style="caffe", 23 | init_cfg=dict(type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"), 24 | ), 25 | rpn_head=dict( 26 | type="RPNHead", 27 | in_channels=1024, 28 | feat_channels=1024, 29 | anchor_generator=dict(type="AnchorGenerator", scales=[2, 4, 8, 16, 32], ratios=[0.5, 1.0, 2.0], strides=[16]), 30 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 
1.0, 1.0]), 31 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 33 | ), 34 | roi_head=dict( 35 | type="StandardRoIHead", 36 | shared_head=dict( 37 | type="ResLayer", 38 | depth=50, 39 | stage=3, 40 | stride=2, 41 | dilation=1, 42 | style="caffe", 43 | norm_cfg=norm_cfg, 44 | norm_eval=True, 45 | init_cfg=dict(type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"), 46 | ), 47 | bbox_roi_extractor=dict( 48 | type="SingleRoIExtractor", 49 | roi_layer=dict(type="RoIAlign", output_size=14, sampling_ratio=0), 50 | out_channels=1024, 51 | featmap_strides=[16], 52 | ), 53 | bbox_head=dict( 54 | type="BBoxHead", 55 | with_avg_pool=True, 56 | roi_feat_size=7, 57 | in_channels=2048, 58 | num_classes=80, 59 | bbox_coder=dict( 60 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] 61 | ), 62 | reg_class_agnostic=False, 63 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 64 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 65 | ), 66 | ), 67 | # model training and testing settings 68 | train_cfg=dict( 69 | rpn=dict( 70 | assigner=dict( 71 | type="MaxIoUAssigner", 72 | pos_iou_thr=0.7, 73 | neg_iou_thr=0.3, 74 | min_pos_iou=0.3, 75 | match_low_quality=True, 76 | ignore_iof_thr=-1, 77 | ), 78 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 79 | allowed_border=-1, 80 | pos_weight=-1, 81 | debug=False, 82 | ), 83 | rpn_proposal=dict(nms_pre=12000, max_per_img=2000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 84 | rcnn=dict( 85 | assigner=dict( 86 | type="MaxIoUAssigner", 87 | pos_iou_thr=0.5, 88 | neg_iou_thr=0.5, 89 | min_pos_iou=0.5, 90 | match_low_quality=False, 91 | ignore_iof_thr=-1, 92 | ), 93 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 94 | pos_weight=-1, 95 | debug=False, 96 | ), 97 | ), 98 | test_cfg=dict( 99 | rpn=dict(nms_pre=6000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 100 | rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100), 101 | ), 102 | ) 103 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/faster-rcnn_r50-caffe-dc5.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type="BN", requires_grad=False) 3 | model = dict( 4 | type="FasterRCNN", 5 | data_preprocessor=dict( 6 | type="DetDataPreprocessor", 7 | mean=[103.530, 116.280, 123.675], 8 | std=[1.0, 1.0, 1.0], 9 | bgr_to_rgb=False, 10 | pad_size_divisor=32, 11 | ), 12 | backbone=dict( 13 | type="ResNet", 14 | depth=50, 15 | num_stages=4, 16 | strides=(1, 2, 2, 1), 17 | dilations=(1, 1, 1, 2), 18 | out_indices=(3,), 19 | frozen_stages=1, 20 | norm_cfg=norm_cfg, 21 | norm_eval=True, 22 | style="caffe", 23 | init_cfg=dict(type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"), 24 | ), 25 | rpn_head=dict( 26 | type="RPNHead", 27 | in_channels=2048, 28 | feat_channels=2048, 29 | anchor_generator=dict(type="AnchorGenerator", scales=[2, 4, 8, 16, 32], ratios=[0.5, 1.0, 2.0], strides=[16]), 30 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 31 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 32 | 
loss_bbox=dict(type="L1Loss", loss_weight=1.0), 33 | ), 34 | roi_head=dict( 35 | type="StandardRoIHead", 36 | bbox_roi_extractor=dict( 37 | type="SingleRoIExtractor", 38 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 39 | out_channels=2048, 40 | featmap_strides=[16], 41 | ), 42 | bbox_head=dict( 43 | type="Shared2FCBBoxHead", 44 | in_channels=2048, 45 | fc_out_channels=1024, 46 | roi_feat_size=7, 47 | num_classes=80, 48 | bbox_coder=dict( 49 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] 50 | ), 51 | reg_class_agnostic=False, 52 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 53 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 54 | ), 55 | ), 56 | # model training and testing settings 57 | train_cfg=dict( 58 | rpn=dict( 59 | assigner=dict( 60 | type="MaxIoUAssigner", 61 | pos_iou_thr=0.7, 62 | neg_iou_thr=0.3, 63 | min_pos_iou=0.3, 64 | match_low_quality=True, 65 | ignore_iof_thr=-1, 66 | ), 67 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 68 | allowed_border=0, 69 | pos_weight=-1, 70 | debug=False, 71 | ), 72 | rpn_proposal=dict(nms_pre=12000, max_per_img=2000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 73 | rcnn=dict( 74 | assigner=dict( 75 | type="MaxIoUAssigner", 76 | pos_iou_thr=0.5, 77 | neg_iou_thr=0.5, 78 | min_pos_iou=0.5, 79 | match_low_quality=False, 80 | ignore_iof_thr=-1, 81 | ), 82 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 83 | pos_weight=-1, 84 | debug=False, 85 | ), 86 | ), 87 | test_cfg=dict( 88 | rpn=dict(nms=dict(type="nms", iou_threshold=0.7), nms_pre=6000, max_per_img=1000, min_bbox_size=0), 89 | rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100), 90 | ), 91 | ) 92 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/faster-rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="FasterRCNN", 4 | data_preprocessor=dict( 5 | type="DetDataPreprocessor", 6 | mean=[123.675, 116.28, 103.53], 7 | std=[58.395, 57.12, 57.375], 8 | bgr_to_rgb=True, 9 | pad_size_divisor=32, 10 | ), 11 | backbone=dict( 12 | type="ResNet", 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | frozen_stages=1, 17 | norm_cfg=dict(type="BN", requires_grad=True), 18 | norm_eval=True, 19 | style="pytorch", 20 | init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), 21 | ), 22 | neck=dict(type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), 23 | rpn_head=dict( 24 | type="RPNHead", 25 | in_channels=256, 26 | feat_channels=256, 27 | anchor_generator=dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), 28 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 29 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 30 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 31 | ), 32 | roi_head=dict( 33 | type="StandardRoIHead", 34 | bbox_roi_extractor=dict( 35 | type="SingleRoIExtractor", 36 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 37 | out_channels=256, 38 | featmap_strides=[4, 8, 16, 32], 39 | ), 40 | bbox_head=dict( 41 | type="Shared2FCBBoxHead", 42 | in_channels=256, 43 | 
fc_out_channels=1024, 44 | roi_feat_size=7, 45 | num_classes=80, 46 | bbox_coder=dict( 47 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] 48 | ), 49 | reg_class_agnostic=False, 50 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 51 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 52 | ), 53 | ), 54 | # model training and testing settings 55 | train_cfg=dict( 56 | rpn=dict( 57 | assigner=dict( 58 | type="MaxIoUAssigner", 59 | pos_iou_thr=0.7, 60 | neg_iou_thr=0.3, 61 | min_pos_iou=0.3, 62 | match_low_quality=True, 63 | ignore_iof_thr=-1, 64 | ), 65 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 66 | allowed_border=-1, 67 | pos_weight=-1, 68 | debug=False, 69 | ), 70 | rpn_proposal=dict(nms_pre=2000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 71 | rcnn=dict( 72 | assigner=dict( 73 | type="MaxIoUAssigner", 74 | pos_iou_thr=0.5, 75 | neg_iou_thr=0.5, 76 | min_pos_iou=0.5, 77 | match_low_quality=False, 78 | ignore_iof_thr=-1, 79 | ), 80 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 81 | pos_weight=-1, 82 | debug=False, 83 | ), 84 | ), 85 | test_cfg=dict( 86 | rpn=dict(nms_pre=1000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 87 | rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100), 88 | # soft-nms is also supported for rcnn testing 89 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) 90 | ), 91 | ) 92 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/mask-rcnn_r50-caffe-c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type="BN", requires_grad=False) 3 | model = dict( 4 | type="MaskRCNN", 5 | data_preprocessor=dict( 6 | type="DetDataPreprocessor", 7 | mean=[103.530, 116.280, 123.675], 8 | std=[1.0, 1.0, 1.0], 9 | bgr_to_rgb=False, 10 | pad_mask=True, 11 | pad_size_divisor=32, 12 | ), 13 | backbone=dict( 14 | type="ResNet", 15 | depth=50, 16 | num_stages=3, 17 | strides=(1, 2, 2), 18 | dilations=(1, 1, 1), 19 | out_indices=(2,), 20 | frozen_stages=1, 21 | norm_cfg=norm_cfg, 22 | norm_eval=True, 23 | style="caffe", 24 | init_cfg=dict(type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"), 25 | ), 26 | rpn_head=dict( 27 | type="RPNHead", 28 | in_channels=1024, 29 | feat_channels=1024, 30 | anchor_generator=dict(type="AnchorGenerator", scales=[2, 4, 8, 16, 32], ratios=[0.5, 1.0, 2.0], strides=[16]), 31 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 33 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 34 | ), 35 | roi_head=dict( 36 | type="StandardRoIHead", 37 | shared_head=dict( 38 | type="ResLayer", depth=50, stage=3, stride=2, dilation=1, style="caffe", norm_cfg=norm_cfg, norm_eval=True 39 | ), 40 | bbox_roi_extractor=dict( 41 | type="SingleRoIExtractor", 42 | roi_layer=dict(type="RoIAlign", output_size=14, sampling_ratio=0), 43 | out_channels=1024, 44 | featmap_strides=[16], 45 | ), 46 | bbox_head=dict( 47 | type="BBoxHead", 48 | with_avg_pool=True, 49 | roi_feat_size=7, 50 | in_channels=2048, 51 | num_classes=80, 52 | bbox_coder=dict( 53 | type="DeltaXYWHBBoxCoder", 
target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] 54 | ), 55 | reg_class_agnostic=False, 56 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 57 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 58 | ), 59 | mask_roi_extractor=None, 60 | mask_head=dict( 61 | type="FCNMaskHead", 62 | num_convs=0, 63 | in_channels=2048, 64 | conv_out_channels=256, 65 | num_classes=80, 66 | loss_mask=dict(type="CrossEntropyLoss", use_mask=True, loss_weight=1.0), 67 | ), 68 | ), 69 | # model training and testing settings 70 | train_cfg=dict( 71 | rpn=dict( 72 | assigner=dict( 73 | type="MaxIoUAssigner", 74 | pos_iou_thr=0.7, 75 | neg_iou_thr=0.3, 76 | min_pos_iou=0.3, 77 | match_low_quality=True, 78 | ignore_iof_thr=-1, 79 | ), 80 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 81 | allowed_border=0, 82 | pos_weight=-1, 83 | debug=False, 84 | ), 85 | rpn_proposal=dict(nms_pre=12000, max_per_img=2000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 86 | rcnn=dict( 87 | assigner=dict( 88 | type="MaxIoUAssigner", 89 | pos_iou_thr=0.5, 90 | neg_iou_thr=0.5, 91 | min_pos_iou=0.5, 92 | match_low_quality=False, 93 | ignore_iof_thr=-1, 94 | ), 95 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 96 | mask_size=14, 97 | pos_weight=-1, 98 | debug=False, 99 | ), 100 | ), 101 | test_cfg=dict( 102 | rpn=dict(nms_pre=6000, nms=dict(type="nms", iou_threshold=0.7), max_per_img=1000, min_bbox_size=0), 103 | rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5), 104 | ), 105 | ) 106 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/mask-rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="MaskRCNN", 4 | data_preprocessor=dict( 5 | type="DetDataPreprocessor", 6 | mean=[123.675, 116.28, 103.53], 7 | std=[58.395, 57.12, 57.375], 8 | bgr_to_rgb=True, 9 | pad_mask=True, 10 | pad_size_divisor=32, 11 | ), 12 | backbone=dict( 13 | type="ResNet", 14 | depth=50, 15 | num_stages=4, 16 | out_indices=(0, 1, 2, 3), 17 | frozen_stages=1, 18 | norm_cfg=dict(type="BN", requires_grad=True), 19 | norm_eval=True, 20 | style="pytorch", 21 | init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), 22 | ), 23 | neck=dict(type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), 24 | rpn_head=dict( 25 | type="RPNHead", 26 | in_channels=256, 27 | feat_channels=256, 28 | anchor_generator=dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), 29 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 30 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 31 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 32 | ), 33 | roi_head=dict( 34 | type="StandardRoIHead", 35 | bbox_roi_extractor=dict( 36 | type="SingleRoIExtractor", 37 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 38 | out_channels=256, 39 | featmap_strides=[4, 8, 16, 32], 40 | ), 41 | bbox_head=dict( 42 | type="Shared2FCBBoxHead", 43 | in_channels=256, 44 | fc_out_channels=1024, 45 | roi_feat_size=7, 46 | num_classes=80, 47 | bbox_coder=dict( 48 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 
0.2] 49 | ), 50 | reg_class_agnostic=False, 51 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 52 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 53 | ), 54 | mask_roi_extractor=dict( 55 | type="SingleRoIExtractor", 56 | roi_layer=dict(type="RoIAlign", output_size=14, sampling_ratio=0), 57 | out_channels=256, 58 | featmap_strides=[4, 8, 16, 32], 59 | ), 60 | mask_head=dict( 61 | type="FCNMaskHead", 62 | num_convs=4, 63 | in_channels=256, 64 | conv_out_channels=256, 65 | num_classes=80, 66 | loss_mask=dict(type="CrossEntropyLoss", use_mask=True, loss_weight=1.0), 67 | ), 68 | ), 69 | # model training and testing settings 70 | train_cfg=dict( 71 | rpn=dict( 72 | assigner=dict( 73 | type="MaxIoUAssigner", 74 | pos_iou_thr=0.7, 75 | neg_iou_thr=0.3, 76 | min_pos_iou=0.3, 77 | match_low_quality=True, 78 | ignore_iof_thr=-1, 79 | ), 80 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 81 | allowed_border=-1, 82 | pos_weight=-1, 83 | debug=False, 84 | ), 85 | rpn_proposal=dict(nms_pre=2000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 86 | rcnn=dict( 87 | assigner=dict( 88 | type="MaxIoUAssigner", 89 | pos_iou_thr=0.5, 90 | neg_iou_thr=0.5, 91 | min_pos_iou=0.5, 92 | match_low_quality=True, 93 | ignore_iof_thr=-1, 94 | ), 95 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 96 | mask_size=28, 97 | pos_weight=-1, 98 | debug=False, 99 | ), 100 | ), 101 | test_cfg=dict( 102 | rpn=dict(nms_pre=1000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0), 103 | rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5), 104 | ), 105 | ) 106 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/retinanet_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="RetinaNet", 4 | data_preprocessor=dict( 5 | type="DetDataPreprocessor", 6 | mean=[123.675, 116.28, 103.53], 7 | std=[58.395, 57.12, 57.375], 8 | bgr_to_rgb=True, 9 | pad_size_divisor=32, 10 | ), 11 | backbone=dict( 12 | type="ResNet", 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | frozen_stages=1, 17 | norm_cfg=dict(type="BN", requires_grad=True), 18 | norm_eval=True, 19 | style="pytorch", 20 | init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), 21 | ), 22 | neck=dict( 23 | type="FPN", 24 | in_channels=[256, 512, 1024, 2048], 25 | out_channels=256, 26 | start_level=1, 27 | add_extra_convs="on_input", 28 | num_outs=5, 29 | ), 30 | bbox_head=dict( 31 | type="RetinaHead", 32 | num_classes=80, 33 | in_channels=256, 34 | stacked_convs=4, 35 | feat_channels=256, 36 | anchor_generator=dict( 37 | type="AnchorGenerator", 38 | octave_base_scale=4, 39 | scales_per_octave=3, 40 | ratios=[0.5, 1.0, 2.0], 41 | strides=[8, 16, 32, 64, 128], 42 | ), 43 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 44 | loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), 45 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 46 | ), 47 | # model training and testing settings 48 | train_cfg=dict( 49 | assigner=dict(type="MaxIoUAssigner", pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), 50 | sampler=dict(type="PseudoSampler"), 
# Focal loss should use PseudoSampler 51 | allowed_border=-1, 52 | pos_weight=-1, 53 | debug=False, 54 | ), 55 | test_cfg=dict( 56 | nms_pre=1000, min_bbox_size=0, score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100 57 | ), 58 | ) 59 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/rpn_r50-caffe-c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="RPN", 4 | data_preprocessor=dict( 5 | type="DetDataPreprocessor", 6 | mean=[103.530, 116.280, 123.675], 7 | std=[1.0, 1.0, 1.0], 8 | bgr_to_rgb=False, 9 | pad_size_divisor=32, 10 | ), 11 | backbone=dict( 12 | type="ResNet", 13 | depth=50, 14 | num_stages=3, 15 | strides=(1, 2, 2), 16 | dilations=(1, 1, 1), 17 | out_indices=(2,), 18 | frozen_stages=1, 19 | norm_cfg=dict(type="BN", requires_grad=False), 20 | norm_eval=True, 21 | style="caffe", 22 | init_cfg=dict(type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"), 23 | ), 24 | neck=None, 25 | rpn_head=dict( 26 | type="RPNHead", 27 | in_channels=1024, 28 | feat_channels=1024, 29 | anchor_generator=dict(type="AnchorGenerator", scales=[2, 4, 8, 16, 32], ratios=[0.5, 1.0, 2.0], strides=[16]), 30 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 31 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 33 | ), 34 | # model training and testing settings 35 | train_cfg=dict( 36 | rpn=dict( 37 | assigner=dict(type="MaxIoUAssigner", pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, ignore_iof_thr=-1), 38 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 39 | allowed_border=-1, 40 | pos_weight=-1, 41 | debug=False, 42 | ) 43 | ), 44 | test_cfg=dict(rpn=dict(nms_pre=12000, max_per_img=2000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0)), 45 | ) 46 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/rpn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="RPN", 4 | data_preprocessor=dict( 5 | type="DetDataPreprocessor", 6 | mean=[123.675, 116.28, 103.53], 7 | std=[58.395, 57.12, 57.375], 8 | bgr_to_rgb=True, 9 | pad_size_divisor=32, 10 | ), 11 | backbone=dict( 12 | type="ResNet", 13 | depth=50, 14 | num_stages=4, 15 | out_indices=(0, 1, 2, 3), 16 | frozen_stages=1, 17 | norm_cfg=dict(type="BN", requires_grad=True), 18 | norm_eval=True, 19 | style="pytorch", 20 | init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), 21 | ), 22 | neck=dict(type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), 23 | rpn_head=dict( 24 | type="RPNHead", 25 | in_channels=256, 26 | feat_channels=256, 27 | anchor_generator=dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), 28 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 29 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 30 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 31 | ), 32 | # model training and testing settings 33 | train_cfg=dict( 34 | rpn=dict( 35 | assigner=dict(type="MaxIoUAssigner", pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, 
ignore_iof_thr=-1), 36 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 37 | allowed_border=-1, 38 | pos_weight=-1, 39 | debug=False, 40 | ) 41 | ), 42 | test_cfg=dict(rpn=dict(nms_pre=2000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0)), 43 | ) 44 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/models/ssd300.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | input_size = 300 3 | model = dict( 4 | type="SingleStageDetector", 5 | data_preprocessor=dict( 6 | type="DetDataPreprocessor", mean=[123.675, 116.28, 103.53], std=[1, 1, 1], bgr_to_rgb=True, pad_size_divisor=1 7 | ), 8 | backbone=dict( 9 | type="SSDVGG", 10 | depth=16, 11 | with_last_pool=False, 12 | ceil_mode=True, 13 | out_indices=(3, 4), 14 | out_feature_indices=(22, 34), 15 | init_cfg=dict(type="Pretrained", checkpoint="open-mmlab://vgg16_caffe"), 16 | ), 17 | neck=dict( 18 | type="SSDNeck", 19 | in_channels=(512, 1024), 20 | out_channels=(512, 1024, 512, 256, 256, 256), 21 | level_strides=(2, 2, 1, 1), 22 | level_paddings=(1, 1, 0, 0), 23 | l2_norm_scale=20, 24 | ), 25 | bbox_head=dict( 26 | type="SSDHead", 27 | in_channels=(512, 1024, 512, 256, 256, 256), 28 | num_classes=80, 29 | anchor_generator=dict( 30 | type="SSDAnchorGenerator", 31 | scale_major=False, 32 | input_size=input_size, 33 | basesize_ratio_range=(0.15, 0.9), 34 | strides=[8, 16, 32, 64, 100, 300], 35 | ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]], 36 | ), 37 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2]), 38 | ), 39 | # model training and testing settings 40 | train_cfg=dict( 41 | assigner=dict( 42 | type="MaxIoUAssigner", 43 | pos_iou_thr=0.5, 44 | neg_iou_thr=0.5, 45 | min_pos_iou=0.0, 46 | ignore_iof_thr=-1, 47 | gt_max_assign_all=False, 48 | ), 49 | sampler=dict(type="PseudoSampler"), 50 | smoothl1_beta=1.0, 51 | allowed_border=-1, 52 | pos_weight=-1, 53 | neg_pos_ratio=3, 54 | debug=False, 55 | ), 56 | test_cfg=dict( 57 | nms_pre=1000, nms=dict(type="nms", iou_threshold=0.45), min_bbox_size=0, score_thr=0.02, max_per_img=200 58 | ), 59 | ) 60 | cudnn_benchmark = True 61 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # training schedule for 1x 2 | train_cfg = dict(type="EpochBasedTrainLoop", max_epochs=12, val_interval=1) 3 | val_cfg = dict(type="ValLoop") 4 | test_cfg = dict(type="TestLoop") 5 | 6 | # learning rate 7 | param_scheduler = [ 8 | dict(type="LinearLR", start_factor=0.001, by_epoch=False, begin=0, end=500), 9 | dict(type="MultiStepLR", begin=0, end=12, by_epoch=True, milestones=[8, 11], gamma=0.1), 10 | ] 11 | 12 | # optimizer 13 | optim_wrapper = dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001)) 14 | 15 | # Default setting for scaling LR automatically 16 | # - `enable` means enable scaling LR automatically 17 | # or not by default. 18 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 
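# Illustrative note (not part of the original schedule): with enable=True, mmengine
# rescales the SGD lr above linearly by the actual total batch size, e.g. training on
# 16 GPUs x 2 samples per GPU (batch 32) would give lr = 0.02 * 32 / 16 = 0.04.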
19 | auto_scale_lr = dict(enable=False, base_batch_size=16) 20 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/schedules/schedule_20e.py: -------------------------------------------------------------------------------- 1 | # training schedule for 20e 2 | train_cfg = dict(type="EpochBasedTrainLoop", max_epochs=20, val_interval=1) 3 | val_cfg = dict(type="ValLoop") 4 | test_cfg = dict(type="TestLoop") 5 | 6 | # learning rate 7 | param_scheduler = [ 8 | dict(type="LinearLR", start_factor=0.001, by_epoch=False, begin=0, end=500), 9 | dict(type="MultiStepLR", begin=0, end=20, by_epoch=True, milestones=[16, 19], gamma=0.1), 10 | ] 11 | 12 | # optimizer 13 | optim_wrapper = dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001)) 14 | 15 | # Default setting for scaling LR automatically 16 | # - `enable` means enable scaling LR automatically 17 | # or not by default. 18 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 19 | auto_scale_lr = dict(enable=False, base_batch_size=16) 20 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/_base_/schedules/schedule_2x.py: -------------------------------------------------------------------------------- 1 | # training schedule for 2x 2 | train_cfg = dict(type="EpochBasedTrainLoop", max_epochs=24, val_interval=1) 3 | val_cfg = dict(type="ValLoop") 4 | test_cfg = dict(type="TestLoop") 5 | 6 | # learning rate 7 | param_scheduler = [ 8 | dict(type="LinearLR", start_factor=0.001, by_epoch=False, begin=0, end=500), 9 | dict(type="MultiStepLR", begin=0, end=24, by_epoch=True, milestones=[16, 22], gamma=0.1), 10 | ] 11 | 12 | # optimizer 13 | optim_wrapper = dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001)) 14 | 15 | # Default setting for scaling LR automatically 16 | # - `enable` means enable scaling LR automatically 17 | # or not by default. 18 | # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). 
19 | auto_scale_lr = dict(enable=False, base_batch_size=16) 20 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/cascade_mask_rcnn/cascade-mask-rcnn_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/cascade-mask-rcnn_r50_fpn.py", 3 | "../_base_/datasets/coco_instance.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/retinanet/retinanet_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/retinanet_r50_fpn.py", 3 | "../_base_/datasets/coco_detection.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | "./retinanet_tta.py", 7 | ] 8 | 9 | # optimizer 10 | optim_wrapper = dict(optimizer=dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001)) 11 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/retinanet/retinanet_tta.py: -------------------------------------------------------------------------------- 1 | tta_model = dict(type="DetTTAModel", tta_cfg=dict(nms=dict(type="nms", iou_threshold=0.5), max_per_img=100)) 2 | 3 | img_scales = [(1333, 800), (666, 400), (2000, 1200)] 4 | tta_pipeline = [ 5 | dict(type="LoadImageFromFile", backend_args=None), 6 | dict( 7 | type="TestTimeAug", 8 | transforms=[ 9 | [dict(type="Resize", scale=s, keep_ratio=True) for s in img_scales], 10 | [dict(type="RandomFlip", prob=1.0), dict(type="RandomFlip", prob=0.0)], 11 | [dict(type="LoadAnnotations", with_bbox=True)], 12 | [ 13 | dict( 14 | type="PackDetInputs", 15 | meta_keys=( 16 | "img_id", 17 | "img_path", 18 | "ori_shape", 19 | "img_shape", 20 | "scale_factor", 21 | "flip", 22 | "flip_direction", 23 | ), 24 | ) 25 | ], 26 | ], 27 | ), 28 | ] 29 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/yolox/yolox_tiny_8xb8-300e_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = "./yolox_s_8xb8-300e_coco.py" 2 | 3 | # model settings 4 | model = dict( 5 | data_preprocessor=dict( 6 | batch_augments=[dict(type="BatchSyncRandomResize", random_size_range=(320, 640), size_divisor=32, interval=10)] 7 | ), 8 | backbone=dict(deepen_factor=0.33, widen_factor=0.375), 9 | neck=dict(in_channels=[96, 192, 384], out_channels=96), 10 | bbox_head=dict(in_channels=96, feat_channels=96), 11 | ) 12 | 13 | img_scale = (640, 640) # width, height 14 | 15 | train_pipeline = [ 16 | dict(type="Mosaic", img_scale=img_scale, pad_val=114.0), 17 | dict( 18 | type="RandomAffine", 19 | scaling_ratio_range=(0.5, 1.5), 20 | # img_scale is (width, height) 21 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 22 | ), 23 | dict(type="YOLOXHSVRandomAug"), 24 | dict(type="RandomFlip", prob=0.5), 25 | # Resize and Pad are for the last 15 epochs when Mosaic and 26 | # RandomAffine are closed by YOLOXModeSwitchHook. 
27 | dict(type="Resize", scale=img_scale, keep_ratio=True), 28 | dict(type="Pad", pad_to_square=True, pad_val=dict(img=(114.0, 114.0, 114.0))), 29 | dict(type="FilterAnnotations", min_gt_bbox_wh=(1, 1), keep_empty=False), 30 | dict(type="PackDetInputs"), 31 | ] 32 | 33 | test_pipeline = [ 34 | dict(type="LoadImageFromFile", backend_args={{_base_.backend_args}}), 35 | dict(type="Resize", scale=(416, 416), keep_ratio=True), 36 | dict(type="Pad", pad_to_square=True, pad_val=dict(img=(114.0, 114.0, 114.0))), 37 | dict(type="LoadAnnotations", with_bbox=True), 38 | dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), 39 | ] 40 | 41 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) 42 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) 43 | test_dataloader = val_dataloader 44 | -------------------------------------------------------------------------------- /tests/data/models/mmdet/yolox/yolox_tta.py: -------------------------------------------------------------------------------- 1 | tta_model = dict(type="DetTTAModel", tta_cfg=dict(nms=dict(type="nms", iou_threshold=0.65), max_per_img=100)) 2 | 3 | img_scales = [(640, 640), (320, 320), (960, 960)] 4 | tta_pipeline = [ 5 | dict(type="LoadImageFromFile", backend_args=None), 6 | dict( 7 | type="TestTimeAug", 8 | transforms=[ 9 | [dict(type="Resize", scale=s, keep_ratio=True) for s in img_scales], 10 | [ 11 | # ``RandomFlip`` must be placed before ``Pad``, otherwise 12 | # bounding box coordinates after flipping cannot be 13 | # recovered correctly. 14 | dict(type="RandomFlip", prob=1.0), 15 | dict(type="RandomFlip", prob=0.0), 16 | ], 17 | [ 18 | dict(type="Pad", pad_to_square=True, pad_val=dict(img=(114.0, 114.0, 114.0))), 19 | ], 20 | [dict(type="LoadAnnotations", with_bbox=True)], 21 | [ 22 | dict( 23 | type="PackDetInputs", 24 | meta_keys=( 25 | "img_id", 26 | "img_path", 27 | "ori_shape", 28 | "img_shape", 29 | "scale_factor", 30 | "flip", 31 | "flip_direction", 32 | ), 33 | ) 34 | ], 35 | ], 36 | ), 37 | ] 38 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_cascade_mask_rcnn/cascade-mask-rcnn_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/cascade-mask-rcnn_r50_fpn.py", 3 | "../_base_/datasets/coco_instance.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ["cascade_mask_rcnn_r50_fpn.py", "coco_instance.py", "schedule_1x.py", "default_runtime.py"] 2 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco_v280.py: -------------------------------------------------------------------------------- 1 | _base_ = ["cascade_mask_rcnn_r50_fpn_v280.py", "coco_instance.py", "schedule_1x.py", "default_runtime.py"] 2 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_cascade_mask_rcnn/coco_instance.py: -------------------------------------------------------------------------------- 1 | dataset_type = "CocoDataset" 2 | data_root = "data/coco/" 3 | img_norm_cfg = dict(mean=[123.675, 
116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | train_pipeline = [ 5 | dict(type="LoadImageFromFile"), 6 | dict(type="LoadAnnotations", with_bbox=True, with_mask=True), 7 | dict(type="Resize", img_scale=(1333, 800), keep_ratio=True), 8 | dict(type="RandomFlip", flip_ratio=0.5), 9 | dict(type="Normalize", **img_norm_cfg), 10 | dict(type="Pad", size_divisor=32), 11 | dict(type="DefaultFormatBundle"), 12 | dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels", "gt_masks"]), 13 | ] 14 | test_pipeline = [ 15 | dict(type="LoadImageFromFile"), 16 | dict( 17 | type="MultiScaleFlipAug", 18 | img_scale=(1333, 800), 19 | flip=False, 20 | transforms=[ 21 | dict(type="Resize", keep_ratio=True), 22 | dict(type="RandomFlip"), 23 | dict(type="Normalize", **img_norm_cfg), 24 | dict(type="Pad", size_divisor=32), 25 | dict(type="DefaultFormatBundle"), 26 | dict(type="Collect", keys=["img"]), 27 | ], 28 | ), 29 | ] 30 | data = dict( 31 | samples_per_gpu=2, 32 | workers_per_gpu=2, 33 | train=dict( 34 | type=dataset_type, 35 | ann_file=data_root + "annotations/instances_train2017.json", 36 | img_prefix=data_root + "train2017/", 37 | pipeline=train_pipeline, 38 | ), 39 | val=dict( 40 | type=dataset_type, 41 | ann_file=data_root + "annotations/instances_val2017.json", 42 | img_prefix=data_root + "val2017/", 43 | pipeline=test_pipeline, 44 | ), 45 | test=dict( 46 | type=dataset_type, 47 | ann_file=data_root + "annotations/instances_val2017.json", 48 | img_prefix=data_root + "val2017/", 49 | pipeline=test_pipeline, 50 | ), 51 | ) 52 | evaluation = dict(metric=["bbox", "segm"]) 53 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_cascade_mask_rcnn/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type="TextLoggerHook"), 7 | # dict(type='TensorboardLoggerHook') 8 | ], 9 | ) 10 | # yapf:enable 11 | dist_params = dict(backend="nccl") 12 | log_level = "INFO" 13 | load_from = None 14 | resume_from = None 15 | workflow = [("train", 1)] 16 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_cascade_mask_rcnn/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict(policy="step", warmup="linear", warmup_iters=500, warmup_ratio=0.001, step=[8, 11]) 6 | total_epochs = 12 7 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/coco_detection.py: -------------------------------------------------------------------------------- 1 | dataset_type = "CocoDataset" 2 | data_root = "data/coco/" 3 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | train_pipeline = [ 5 | dict(type="LoadImageFromFile"), 6 | dict(type="LoadAnnotations", with_bbox=True), 7 | dict(type="Resize", img_scale=(1333, 800), keep_ratio=True), 8 | dict(type="RandomFlip", flip_ratio=0.5), 9 | dict(type="Normalize", **img_norm_cfg), 10 | dict(type="Pad", size_divisor=32), 11 | dict(type="DefaultFormatBundle"), 12 | dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels"]), 13 | ] 14 | test_pipeline = [ 15 | dict(type="LoadImageFromFile"), 
16 | dict( 17 | type="MultiScaleFlipAug", 18 | img_scale=(1333, 800), 19 | flip=False, 20 | transforms=[ 21 | dict(type="Resize", keep_ratio=True), 22 | dict(type="RandomFlip"), 23 | dict(type="Normalize", **img_norm_cfg), 24 | dict(type="Pad", size_divisor=32), 25 | dict(type="DefaultFormatBundle"), 26 | dict(type="Collect", keys=["img"]), 27 | ], 28 | ), 29 | ] 30 | data = dict( 31 | samples_per_gpu=2, 32 | workers_per_gpu=2, 33 | train=dict( 34 | type=dataset_type, 35 | ann_file=data_root + "annotations/instances_train2017.json", 36 | img_prefix=data_root + "train2017/", 37 | pipeline=train_pipeline, 38 | ), 39 | val=dict( 40 | type=dataset_type, 41 | ann_file=data_root + "annotations/instances_val2017.json", 42 | img_prefix=data_root + "val2017/", 43 | pipeline=test_pipeline, 44 | ), 45 | test=dict( 46 | type=dataset_type, 47 | ann_file=data_root + "annotations/instances_val2017.json", 48 | img_prefix=data_root + "val2017/", 49 | pipeline=test_pipeline, 50 | ), 51 | ) 52 | evaluation = dict(interval=1, metric="bbox") 53 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type="TextLoggerHook"), 7 | # dict(type='TensorboardLoggerHook') 8 | ], 9 | ) 10 | # yapf:enable 11 | dist_params = dict(backend="nccl") 12 | log_level = "INFO" 13 | load_from = None 14 | resume_from = None 15 | workflow = [("train", 1)] 16 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/retinanet_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="RetinaNet", 4 | pretrained="torchvision://resnet50", 5 | backbone=dict( 6 | type="ResNet", 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type="BN", requires_grad=True), 12 | norm_eval=True, 13 | style="pytorch", 14 | ), 15 | neck=dict( 16 | type="FPN", 17 | in_channels=[256, 512, 1024, 2048], 18 | out_channels=256, 19 | start_level=1, 20 | add_extra_convs="on_input", 21 | num_outs=5, 22 | ), 23 | bbox_head=dict( 24 | type="RetinaHead", 25 | num_classes=80, 26 | in_channels=256, 27 | stacked_convs=4, 28 | feat_channels=256, 29 | anchor_generator=dict( 30 | type="AnchorGenerator", 31 | octave_base_scale=4, 32 | scales_per_octave=3, 33 | ratios=[0.5, 1.0, 2.0], 34 | strides=[8, 16, 32, 64, 128], 35 | ), 36 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 37 | loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), 38 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 39 | ), 40 | # training and testing settings 41 | train_cfg=dict( 42 | assigner=dict(type="MaxIoUAssigner", pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), 43 | allowed_border=-1, 44 | pos_weight=-1, 45 | debug=False, 46 | ), 47 | test_cfg=dict( 48 | nms_pre=1000, min_bbox_size=0, score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100 49 | ), 50 | ) 51 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/retinanet_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | 
_base_ = ["retinanet_r50_fpn.py", "coco_detection.py", "schedule_1x.py", "default_runtime.py"] 2 | # optimizer 3 | optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) 4 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/retinanet_r50_fpn_1x_coco_v280.py: -------------------------------------------------------------------------------- 1 | _base_ = ["retinanet_r50_fpn_v280.py", "coco_detection.py", "schedule_1x.py", "default_runtime.py"] 2 | # optimizer 3 | optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) 4 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/retinanet_r50_fpn_v280.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="RetinaNet", 4 | pretrained="torchvision://resnet50", 5 | backbone=dict( 6 | type="ResNet", 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type="BN", requires_grad=True), 12 | norm_eval=True, 13 | style="pytorch", 14 | ), 15 | neck=dict( 16 | type="FPN", 17 | in_channels=[256, 512, 1024, 2048], 18 | out_channels=256, 19 | start_level=1, 20 | add_extra_convs="on_input", 21 | num_outs=5, 22 | ), 23 | bbox_head=dict( 24 | type="RetinaHead", 25 | num_classes=80, 26 | in_channels=256, 27 | stacked_convs=4, 28 | feat_channels=256, 29 | anchor_generator=dict( 30 | type="AnchorGenerator", 31 | octave_base_scale=4, 32 | scales_per_octave=3, 33 | ratios=[0.5, 1.0, 2.0], 34 | strides=[8, 16, 32, 64, 128], 35 | ), 36 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 37 | loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), 38 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 39 | ), 40 | ) 41 | # training and testing settings 42 | train_cfg = dict( 43 | assigner=dict(type="MaxIoUAssigner", pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), 44 | allowed_border=-1, 45 | pos_weight=-1, 46 | debug=False, 47 | ) 48 | test_cfg = dict(nms_pre=1000, min_bbox_size=0, score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100) 49 | -------------------------------------------------------------------------------- /tests/data/models/mmdet_retinanet/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict(policy="step", warmup="linear", warmup_iters=500, warmup_ratio=0.001, step=[8, 11]) 6 | total_epochs = 12 7 | -------------------------------------------------------------------------------- /tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml: -------------------------------------------------------------------------------- 1 | model_name: fasterrcnn_resnet50_fpn 2 | num_classes: 91 -------------------------------------------------------------------------------- /tests/data/models/torchvision/ssd300_vgg16.yaml: -------------------------------------------------------------------------------- 1 | model_name: ssd300_vgg16 2 | num_classes: 91 -------------------------------------------------------------------------------- /tests/data/small-vehicles1.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/sahi/307cba031ee3ab8f540caf87a5a503d6939e80fa/tests/data/small-vehicles1.jpeg -------------------------------------------------------------------------------- /tests/test_annotation.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Fatih C Akyon, 2020. 3 | 4 | import logging 5 | import unittest 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class TestAnnotation(unittest.TestCase): 11 | def test_bounding_box(self): 12 | from sahi.annotation import BoundingBox 13 | 14 | bbox_minmax = [30.0, 30.0, 100.0, 150.0] 15 | shift_amount = [50, 40] 16 | 17 | bbox = BoundingBox(bbox_minmax, shift_amount=[0, 0]) 18 | expanded_bbox = bbox.get_expanded_box(ratio=0.1) 19 | 20 | bbox = BoundingBox(bbox_minmax, shift_amount=shift_amount) 21 | shifted_bbox = bbox.get_shifted_box() 22 | 23 | # compare 24 | self.assertEqual(expanded_bbox.to_xywh(), [18, 23, 94, 134]) 25 | self.assertEqual(expanded_bbox.to_xyxy(), [18, 23, 112, 157]) 26 | self.assertEqual(shifted_bbox.to_xyxy(), [80, 70, 150, 190]) 27 | 28 | def test_category(self): 29 | from sahi.annotation import Category 30 | 31 | category_id = 1 32 | category_name = "car" 33 | category = Category(id=category_id, name=category_name) 34 | self.assertEqual(category.id, category_id) 35 | self.assertEqual(category.name, category_name) 36 | 37 | def test_mask(self): 38 | from sahi.annotation import Mask 39 | 40 | coco_segmentation = [[1.0, 1.0, 325.0, 125.0, 250.0, 200.0, 5.0, 200.0]] 41 | full_shape_height, full_shape_width = 500, 600 42 | full_shape = [full_shape_height, full_shape_width] 43 | 44 | mask = Mask(segmentation=coco_segmentation, full_shape=full_shape) 45 | 46 | self.assertEqual(mask.full_shape_height, full_shape_height) 47 | self.assertEqual(mask.full_shape_width, full_shape_width) 48 | logger.debug(f"{type(mask.bool_mask[11, 2])=} {mask.bool_mask[11, 2]=}") 49 | self.assertEqual(mask.bool_mask[11, 2], True) 50 | 51 | def test_object_annotation(self): 52 | from sahi.annotation import ObjectAnnotation 53 | 54 | bbox = [100, 200, 150, 230] 55 | coco_bbox = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]] 56 | category_id = 2 57 | category_name = "car" 58 | shift_amount = [0, 0] 59 | image_height = 1080 60 | image_width = 1920 61 | full_shape = [image_height, image_width] 62 | 63 | object_annotation1 = ObjectAnnotation( 64 | bbox=bbox, 65 | category_id=category_id, 66 | category_name=category_name, 67 | shift_amount=shift_amount, 68 | full_shape=full_shape, 69 | ) 70 | 71 | object_annotation2 = ObjectAnnotation.from_coco_annotation_dict( 72 | annotation_dict={"bbox": coco_bbox, "category_id": category_id, "segmentation": []}, 73 | category_name=category_name, 74 | full_shape=full_shape, 75 | shift_amount=shift_amount, 76 | ) 77 | 78 | object_annotation3 = ObjectAnnotation.from_coco_bbox( 79 | bbox=coco_bbox, 80 | category_id=category_id, 81 | category_name=category_name, 82 | full_shape=full_shape, 83 | shift_amount=shift_amount, 84 | ) 85 | 86 | self.assertEqual(object_annotation1.bbox.minx, bbox[0]) 87 | self.assertEqual(object_annotation1.bbox.miny, bbox[1]) 88 | self.assertEqual(object_annotation1.bbox.maxx, bbox[2]) 89 | self.assertEqual(object_annotation1.bbox.maxy, bbox[3]) 90 | self.assertEqual(object_annotation1.category.id, category_id) 91 | self.assertEqual(object_annotation1.category.name, category_name) 92 | 93 | 
self.assertEqual(object_annotation2.bbox.minx, bbox[0]) 94 | self.assertEqual(object_annotation2.bbox.miny, bbox[1]) 95 | self.assertEqual(object_annotation2.bbox.maxx, bbox[2]) 96 | self.assertEqual(object_annotation2.bbox.maxy, bbox[3]) 97 | self.assertEqual(object_annotation2.category.id, category_id) 98 | self.assertEqual(object_annotation2.category.name, category_name) 99 | 100 | self.assertEqual(object_annotation3.bbox.minx, bbox[0]) 101 | self.assertEqual(object_annotation3.bbox.miny, bbox[1]) 102 | self.assertEqual(object_annotation3.bbox.maxx, bbox[2]) 103 | self.assertEqual(object_annotation3.bbox.maxy, bbox[3]) 104 | self.assertEqual(object_annotation3.category.id, category_id) 105 | self.assertEqual(object_annotation3.category.name, category_name) 106 | 107 | 108 | if __name__ == "__main__": 109 | unittest.main() 110 | -------------------------------------------------------------------------------- /tests/test_cvutils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from sahi.utils.cv import ( 8 | Colors, 9 | apply_color_mask, 10 | exif_transpose, 11 | get_bbox_from_bool_mask, 12 | get_coco_segmentation_from_bool_mask, 13 | read_image, 14 | ) 15 | 16 | 17 | class TestCvUtils(unittest.TestCase): 18 | def test_hex_to_rgb(self): 19 | colors = Colors() 20 | self.assertEqual(colors.hex_to_rgb("#FF3838"), (255, 56, 56)) 21 | 22 | def test_hex_to_rgb_retrieve(self): 23 | colors = Colors() 24 | self.assertEqual(colors(0), (255, 56, 56)) 25 | 26 | @patch("sahi.utils.cv.cv2.cvtColor") 27 | @patch("sahi.utils.cv.cv2.imread") 28 | def test_read_image(self, mock_imread, mock_cvtColor): 29 | fake_image = "test.jpg" 30 | fake_image_val = np.array([[[10, 20, 30]]], dtype=np.uint8) 31 | fake_image_rgb_val = np.array([[[10, 20, 30]]], dtype=np.uint8) 32 | mock_imread.return_value = fake_image_val 33 | mock_cvtColor.return_value = fake_image_rgb_val 34 | 35 | result = read_image(fake_image) 36 | 37 | # mock_cv2.assert_called_once_with(fake_image) 38 | mock_imread.assert_called_once_with(fake_image) 39 | np.testing.assert_array_equal(result, fake_image_rgb_val) 40 | 41 | def test_apply_color_mask(self): 42 | image = np.array([[0, 1]], dtype=np.uint8) 43 | color = (255, 0, 0) 44 | 45 | expected_output = np.array([[[0, 0, 0], [255, 0, 0]]], dtype=np.uint8) 46 | 47 | result = apply_color_mask(image, color) 48 | 49 | np.testing.assert_array_equal(result, expected_output) 50 | 51 | def test_get_coco_segmentation_from_bool_mask_simple(self): 52 | mask = np.zeros((10, 10), dtype=bool) 53 | result = get_coco_segmentation_from_bool_mask(mask) 54 | self.assertEqual(result, []) 55 | 56 | def test_get_coco_segmentation_from_bool_mask_polygon(self): 57 | mask = np.zeros((10, 20), dtype=bool) 58 | mask[1:4, 1:4] = True 59 | mask[5:8, 5:8] = True 60 | result = get_coco_segmentation_from_bool_mask(mask) 61 | self.assertEqual(len(result), 2) 62 | 63 | def test_get_bbox_from_bool_mask(self): 64 | mask = np.array( 65 | [ 66 | [False, False, False], 67 | [False, True, True], 68 | [False, True, True], 69 | [False, False, False], 70 | ] 71 | ) 72 | expected_result = [1, 1, 2, 2] 73 | result = get_bbox_from_bool_mask(mask) 74 | self.assertEqual(result, expected_result) 75 | 76 | def test_exif_transpose_simple(self): 77 | test_image = Image.new("RGB", (100, 100), color="red") 78 | transposed = exif_transpose(test_image) 79 | self.assertEqual(transposed, test_image) 80 |
81 | def test_exif_transpose_non_standard(self): 82 | test_image = Image.new("RGB", (100, 100), color="red") 83 | exif = test_image.getexif() 84 | exif[0x0112] = 9 85 | test_image.info["exif"] = exif.tobytes() 86 | transposed = exif_transpose(test_image) 87 | self.assertEqual(transposed, test_image) 88 | 89 | 90 | if __name__ == "__main__": 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /tests/test_fileutils.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Fatih C Akyon, 2020. 3 | 4 | import unittest 5 | from unittest.mock import patch 6 | 7 | 8 | class TestFileUtils(unittest.TestCase): 9 | def test_list_files(self): 10 | from sahi.utils.file import list_files 11 | 12 | directory = "tests/data/coco_utils/" 13 | filepath_list = list_files(directory, contains=["json"], verbose=False) 14 | self.assertEqual(len(filepath_list), 11) 15 | 16 | def test_list_files_recursively(self): 17 | from sahi.utils.file import list_files_recursively 18 | 19 | directory = "tests/data/coco_utils/" 20 | relative_filepath_list, abs_filepath_list = list_files_recursively( 21 | directory, contains=["coco.json"], verbose=False 22 | ) 23 | self.assertEqual(len(relative_filepath_list), 7) 24 | self.assertEqual(len(abs_filepath_list), 7) 25 | 26 | def test_increment_path(self): 27 | from sahi.utils.file import increment_path 28 | 29 | with patch("sahi.utils.file.Path.exists", return_value=False): 30 | path = increment_path("test.txt") 31 | self.assertEqual(path, "test.txt") 32 | with patch("sahi.utils.file.Path.exists", return_value=True): 33 | path = increment_path("test.txt", exist_ok=False) 34 | self.assertEqual(path, "test.txt2") 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /tests/test_highlevelapi.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Fatih C Akyon, 2022. 
3 | 4 | import unittest 5 | 6 | 7 | class TestHighLevelApi(unittest.TestCase): 8 | def test_bounding_box(self): 9 | from sahi import BoundingBox 10 | 11 | bbox_minmax = [30.0, 30.0, 100.0, 150.0] 12 | shift_amount = [50, 40] 13 | 14 | bbox = BoundingBox(bbox_minmax, shift_amount=[0, 0]) 15 | expanded_bbox = bbox.get_expanded_box(ratio=0.1) 16 | 17 | bbox = BoundingBox(bbox_minmax, shift_amount=shift_amount) 18 | shifted_bbox = bbox.get_shifted_box() 19 | 20 | # compare 21 | self.assertEqual(expanded_bbox.to_xywh(), [18, 23, 94, 134]) 22 | self.assertEqual(expanded_bbox.to_xyxy(), [18, 23, 112, 157]) 23 | self.assertEqual(shifted_bbox.to_xyxy(), [80, 70, 150, 190]) 24 | 25 | def test_category(self): 26 | from sahi import Category 27 | 28 | category_id = 1 29 | category_name = "car" 30 | category = Category(id=category_id, name=category_name) 31 | self.assertEqual(category.id, category_id) 32 | self.assertEqual(category.name, category_name) 33 | 34 | def test_mask(self): 35 | from sahi import Mask 36 | 37 | coco_segmentation = [[1.0, 1.0, 325.0, 125.0, 250.0, 200.0, 5.0, 200.0]] 38 | full_shape_height, full_shape_width = 500, 600 39 | full_shape = [full_shape_height, full_shape_width] 40 | 41 | mask = Mask(segmentation=coco_segmentation, full_shape=full_shape) 42 | 43 | self.assertEqual(mask.full_shape_height, full_shape_height) 44 | self.assertEqual(mask.full_shape_width, full_shape_width) 45 | self.assertEqual(mask.bool_mask[11, 2], True) 46 | 47 | def test_object_prediction(self): 48 | from sahi import ObjectPrediction 49 | 50 | bbox = [100, 200, 150, 230] 51 | coco_bbox = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]] 52 | category_id = 2 53 | category_name = "car" 54 | shift_amount = [0, 0] 55 | image_height = 1080 56 | image_width = 1920 57 | full_shape = [image_height, image_width] 58 | 59 | object_annotation1 = ObjectPrediction( 60 | bbox=bbox, 61 | category_id=category_id, 62 | category_name=category_name, 63 | shift_amount=shift_amount, 64 | full_shape=full_shape, 65 | ) 66 | 67 | object_annotation2 = ObjectPrediction.from_coco_annotation_dict( 68 | annotation_dict={"bbox": coco_bbox, "category_id": category_id, "segmentation": []}, 69 | category_name=category_name, 70 | full_shape=full_shape, 71 | shift_amount=shift_amount, 72 | ) 73 | 74 | object_annotation3 = ObjectPrediction.from_coco_bbox( 75 | bbox=coco_bbox, 76 | category_id=category_id, 77 | category_name=category_name, 78 | full_shape=full_shape, 79 | shift_amount=shift_amount, 80 | ) 81 | 82 | self.assertEqual(object_annotation1.bbox.minx, bbox[0]) 83 | self.assertEqual(object_annotation1.bbox.miny, bbox[1]) 84 | self.assertEqual(object_annotation1.bbox.maxx, bbox[2]) 85 | self.assertEqual(object_annotation1.bbox.maxy, bbox[3]) 86 | self.assertEqual(object_annotation1.category.id, category_id) 87 | self.assertEqual(object_annotation1.category.name, category_name) 88 | 89 | self.assertEqual(object_annotation2.bbox.minx, bbox[0]) 90 | self.assertEqual(object_annotation2.bbox.miny, bbox[1]) 91 | self.assertEqual(object_annotation2.bbox.maxx, bbox[2]) 92 | self.assertEqual(object_annotation2.bbox.maxy, bbox[3]) 93 | self.assertEqual(object_annotation2.category.id, category_id) 94 | self.assertEqual(object_annotation2.category.name, category_name) 95 | 96 | self.assertEqual(object_annotation3.bbox.minx, bbox[0]) 97 | self.assertEqual(object_annotation3.bbox.miny, bbox[1]) 98 | self.assertEqual(object_annotation3.bbox.maxx, bbox[2]) 99 | self.assertEqual(object_annotation3.bbox.maxy, bbox[3]) 100 | 
self.assertEqual(object_annotation3.category.id, category_id) 101 | self.assertEqual(object_annotation3.category.name, category_name) 102 | 103 | def test_detection_model(self): 104 | from sahi import DetectionModel 105 | 106 | MODEL_PATH = "model_path" 107 | IMAGE_SIZE = 640 108 | detection_model = DetectionModel(model_path="model_path", image_size=IMAGE_SIZE, load_at_init=False) 109 | self.assertEqual(detection_model.model_path, MODEL_PATH) 110 | self.assertEqual(detection_model.image_size, IMAGE_SIZE) 111 | 112 | 113 | if __name__ == "__main__": 114 | unittest.main() 115 | -------------------------------------------------------------------------------- /tests/test_postprocessutils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from sahi.postprocess.utils import ObjectPredictionList 6 | 7 | 8 | class TestPostprocessUtils(unittest.TestCase): 9 | def setUp(self): 10 | self.test_input = [ObjectPredictionList([1, 2, 3, 4])] 11 | 12 | def test_get_item_int(self): 13 | obj = self.test_input[0] 14 | self.assertEqual(obj[0].tolist(), 1) 15 | 16 | def test_len(self): 17 | obj = self.test_input[0] 18 | self.assertEqual(len(obj), 4) 19 | 20 | def test_extend(self): 21 | obj = self.test_input[0] 22 | obj.extend(ObjectPredictionList([torch.randn(1, 2, 3, 4)])) 23 | self.assertEqual(len(obj), 5) 24 | 25 | def test_tostring(self): 26 | obj = self.test_input[0] 27 | self.assertEqual(str(obj), str([1, 2, 3, 4])) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/test_prediction.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from sahi.prediction import PredictionScore 6 | 7 | 8 | class TestPrediction(unittest.TestCase): 9 | def test_prediction_score(self): 10 | prediction_score = PredictionScore(np.array(0.6)) 11 | self.assertEqual(type(prediction_score.value), float) 12 | self.assertEqual(prediction_score.is_greater_than_threshold(0.5), True) 13 | self.assertEqual(prediction_score.is_greater_than_threshold(0.7), False) 14 | 15 | 16 | if __name__ == "__main__": 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /tests/test_rtdetr.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Fatih C Akyon (2020), Devrim Çavuşoğlu (2024). 
3 | 4 | import unittest 5 | 6 | from sahi.prediction import ObjectPrediction 7 | from sahi.utils.cv import read_image 8 | from sahi.utils.rtdetr import RTDETRTestConstants, download_rtdetrl_model 9 | 10 | MODEL_DEVICE = "cpu" 11 | CONFIDENCE_THRESHOLD = 0.3 12 | IMAGE_SIZE = 640 13 | 14 | 15 | class TestRTDetrDetectionModel(unittest.TestCase): 16 | def test_load_model(self): 17 | from sahi.models.rtdetr import RTDetrDetectionModel 18 | 19 | download_rtdetrl_model() 20 | 21 | rtdetr_detection_model = RTDetrDetectionModel( 22 | model_path=RTDETRTestConstants.RTDETRL_MODEL_PATH, 23 | confidence_threshold=CONFIDENCE_THRESHOLD, 24 | device=MODEL_DEVICE, 25 | category_remapping=None, 26 | load_at_init=True, 27 | ) 28 | 29 | self.assertNotEqual(rtdetr_detection_model.model, None) 30 | 31 | def test_set_model(self): 32 | from ultralytics import RTDETR 33 | 34 | from sahi.models.rtdetr import RTDetrDetectionModel 35 | 36 | download_rtdetrl_model() 37 | 38 | rtdetr_model = RTDETR(RTDETRTestConstants.RTDETRL_MODEL_PATH) 39 | 40 | rtdetr_detection_model = RTDetrDetectionModel( 41 | model=rtdetr_model, 42 | confidence_threshold=CONFIDENCE_THRESHOLD, 43 | device=MODEL_DEVICE, 44 | category_remapping=None, 45 | load_at_init=True, 46 | ) 47 | 48 | self.assertNotEqual(rtdetr_detection_model.model, None) 49 | 50 | def test_perform_inference(self): 51 | from sahi.models.rtdetr import RTDetrDetectionModel 52 | 53 | # init model 54 | download_rtdetrl_model() 55 | 56 | rtdetr_detection_model = RTDetrDetectionModel( 57 | model_path=RTDETRTestConstants.RTDETRL_MODEL_PATH, 58 | confidence_threshold=CONFIDENCE_THRESHOLD, 59 | device=MODEL_DEVICE, 60 | category_remapping=None, 61 | load_at_init=True, 62 | image_size=IMAGE_SIZE, 63 | ) 64 | 65 | # prepare image 66 | image_path = "tests/data/small-vehicles1.jpeg" 67 | image = read_image(image_path) 68 | 69 | # perform inference 70 | rtdetr_detection_model.perform_inference(image) 71 | original_predictions = rtdetr_detection_model.original_predictions 72 | 73 | boxes = original_predictions 74 | assert boxes is not None 75 | 76 | # find box of first car detection with conf greater than 0.5 77 | for box in boxes[0]: # type: ignore 78 | if box[5].item() == 2: # if category car 79 | if box[4].item() > 0.5: 80 | break 81 | 82 | # compare 83 | desired_bbox = [321, 322, 384, 362] 84 | predicted_bbox = list(map(round, box[:4].tolist())) 85 | margin = 2 86 | for ind, point in enumerate(predicted_bbox): 87 | assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin 88 | self.assertEqual(len(rtdetr_detection_model.category_names), 80) 89 | for box in boxes[0]: # type: ignore 90 | self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD) 91 | 92 | def test_convert_original_predictions(self): 93 | from sahi.models.rtdetr import RTDetrDetectionModel 94 | 95 | # init model 96 | download_rtdetrl_model() 97 | 98 | rtdetr_detection_model = RTDetrDetectionModel( 99 | model_path=RTDETRTestConstants.RTDETRL_MODEL_PATH, 100 | confidence_threshold=CONFIDENCE_THRESHOLD, 101 | device=MODEL_DEVICE, 102 | category_remapping=None, 103 | load_at_init=True, 104 | image_size=IMAGE_SIZE, 105 | ) 106 | 107 | # prepare image 108 | image_path = "tests/data/small-vehicles1.jpeg" 109 | image = read_image(image_path) 110 | 111 | # get raw predictions for reference 112 | original_results = rtdetr_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes 113 | num_results = len(original_results) 114 | 115 | # perform inference 116 | 
rtdetr_detection_model.perform_inference(image) 117 | 118 | # convert predictions to ObjectPrediction list 119 | rtdetr_detection_model.convert_original_predictions() 120 | object_prediction_list = rtdetr_detection_model.object_prediction_list 121 | 122 | # compare 123 | self.assertEqual(len(object_prediction_list), num_results) 124 | 125 | # loop through predictions and check that they are equal 126 | for i in range(num_results): 127 | desired_bbox = [ 128 | original_results[i].xyxy[0][0], 129 | original_results[i].xyxy[0][1], 130 | original_results[i].xywh[0][2], 131 | original_results[i].xywh[0][3], 132 | ] 133 | desired_cat_id = int(original_results[i].cls[0]) 134 | objectprd = object_prediction_list[i] 135 | assert isinstance(objectprd, ObjectPrediction) 136 | self.assertEqual(objectprd.category.id, desired_cat_id) 137 | predicted_bbox = objectprd.bbox.to_xywh() 138 | margin = 2 139 | for ind, point in enumerate(predicted_bbox): 140 | assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin 141 | for object_prediction in object_prediction_list: 142 | assert isinstance(object_prediction, ObjectPrediction) 143 | self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD) 144 | 145 | 146 | if __name__ == "__main__": 147 | unittest.main() 148 | -------------------------------------------------------------------------------- /tests/test_torchutils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from sahi.utils.torch import empty_cuda_cache, to_float_tensor, torch_to_numpy 7 | 8 | 9 | class TestTorchUtils(unittest.TestCase): 10 | def test_empty_cuda_cache(self): 11 | if torch.cuda.is_available(): 12 | self.assertIsNone(empty_cuda_cache()) 13 | 14 | def test_to_float_tensor(self): 15 | img = to_float_tensor(np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)) 16 | self.assertEqual(img.shape, (3, 10, 10)) 17 | 18 | def test_torch_to_numpy(self): 19 | img_t = torch.tensor(np.random.rand(3, 10, 10)) 20 | img = torch_to_numpy(img_t) 21 | self.assertEqual(img.shape, (10, 10, 3)) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /tests/test_yolov8onnx.py: -------------------------------------------------------------------------------- 1 | # OBSS SAHI Tool 2 | # Code written by Karl-Joan Alesma, 2023 3 | 4 | import unittest 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from sahi.utils.yolov8onnx import Yolov8ONNXTestConstants, download_yolov8n_onnx_model 10 | 11 | MODEL_DEVICE = "cpu" 12 | CONFIDENCE_THRESHOLD = 0.3 13 | IOU_THRESHOLD = 0.7 14 | IMAGE_SIZE = 640 15 | 16 | 17 | class TestYolov8OnnxDetectionModel(unittest.TestCase): 18 | def test_load_model(self): 19 | from sahi.models.yolov8onnx import Yolov8OnnxDetectionModel 20 | 21 | download_yolov8n_onnx_model() 22 | 23 | yolov8_onnx_detection_model = Yolov8OnnxDetectionModel( 24 | model_path=Yolov8ONNXTestConstants.YOLOV8N_ONNX_MODEL_PATH, 25 | confidence_threshold=CONFIDENCE_THRESHOLD, 26 | iou_threshold=IOU_THRESHOLD, 27 | device=MODEL_DEVICE, 28 | category_mapping={"0": "something"}, 29 | load_at_init=False, 30 | ) 31 | 32 | # Test setting options for onnxruntime 33 | yolov8_onnx_detection_model.load_model({"enable_mem_pattern": False}) 34 | 35 | self.assertNotEqual(yolov8_onnx_detection_model.model, None) 36 | 37 | def test_set_model(self): 38 | import onnxruntime 39 | 40 | from 
sahi.models.yolov8onnx import Yolov8OnnxDetectionModel 41 | 42 | download_yolov8n_onnx_model() 43 | 44 | yolo_model = onnxruntime.InferenceSession(Yolov8ONNXTestConstants.YOLOV8N_ONNX_MODEL_PATH) 45 | 46 | yolov8_onnx_detection_model = Yolov8OnnxDetectionModel( 47 | model=yolo_model, 48 | confidence_threshold=CONFIDENCE_THRESHOLD, 49 | iou_threshold=IOU_THRESHOLD, 50 | device=MODEL_DEVICE, 51 | category_mapping={"0": "something"}, 52 | load_at_init=True, 53 | ) 54 | 55 | self.assertNotEqual(yolov8_onnx_detection_model.model, None) 56 | 57 | def test_perform_inference(self): 58 | from sahi.models.yolov8onnx import Yolov8OnnxDetectionModel 59 | 60 | # Init model 61 | download_yolov8n_onnx_model() 62 | 63 | yolov8_onnx_detection_model = Yolov8OnnxDetectionModel( 64 | model_path=Yolov8ONNXTestConstants.YOLOV8N_ONNX_MODEL_PATH, 65 | confidence_threshold=CONFIDENCE_THRESHOLD, 66 | iou_threshold=IOU_THRESHOLD, 67 | device=MODEL_DEVICE, 68 | category_mapping={"0": "something"}, 69 | load_at_init=True, 70 | image_size=IMAGE_SIZE, 71 | ) 72 | 73 | # Prepare image 74 | image_path = "tests/data/small-vehicles1.jpeg" 75 | image = cv2.imread(image_path) 76 | 77 | # Perform inference 78 | yolov8_onnx_detection_model.perform_inference(image) 79 | original_predictions = yolov8_onnx_detection_model.original_predictions 80 | assert original_predictions 81 | 82 | boxes = original_predictions[0] 83 | 84 | # Find most confident bbox for car 85 | best_box_index = np.argmax(boxes[boxes[:, 5] == 2][:, 4]) 86 | best_bbox = boxes[best_box_index] 87 | 88 | # Compare 89 | desired_bbox = [603, 239, 629, 259] 90 | predicted_bbox = best_bbox.tolist() 91 | margin = 2 92 | 93 | for ind, point in enumerate(predicted_bbox[:4]): 94 | assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin 95 | 96 | for box in boxes: # pyright: ignore [reportGeneralTypeIssues] 97 | self.assertGreaterEqual(box[4], CONFIDENCE_THRESHOLD) 98 | 99 | 100 | if __name__ == "__main__": 101 | unittest.main() 102 | --------------------------------------------------------------------------------