├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── README.zh-CN.md
├── mkdocs.yml
├── requirements.txt
├── setup.cfg
├── setup.py
├── ultralytics
│   ├── __init__.py
│   ├── cfg
│   │   ├── __init__.py
│   │   ├── default.yaml
│   │   ├── models
│   │   │   ├── README.md
│   │   │   ├── rt-detr
│   │   │   │   ├── rtdetr-l.yaml
│   │   │   │   └── rtdetr-x.yaml
│   │   │   ├── v3
│   │   │   │   ├── yolov3-spp.yaml
│   │   │   │   ├── yolov3-tiny.yaml
│   │   │   │   └── yolov3.yaml
│   │   │   ├── v5
│   │   │   │   ├── yolov5-p6.yaml
│   │   │   │   └── yolov5.yaml
│   │   │   ├── v6
│   │   │   │   └── yolov6.yaml
│   │   │   └── v8
│   │   │       ├── yolov8-cls.yaml
│   │   │       ├── yolov8-p2.yaml
│   │   │       ├── yolov8-p6.yaml
│   │   │       ├── yolov8-pose-p6.yaml
│   │   │       ├── yolov8-pose.yaml
│   │   │       ├── yolov8-rtdetr.yaml
│   │   │       ├── yolov8-seg.yaml
│   │   │       └── yolov8.yaml
│   │   └── trackers
│   │       ├── botsort.yaml
│   │       └── bytetrack.yaml
│   ├── data
│   │   ├── __init__.py
│   │   ├── annotator.py
│   │   ├── augment.py
│   │   ├── base.py
│   │   ├── build.py
│   │   ├── converter.py
│   │   ├── dataset.py
│   │   ├── loaders.py
│   │   ├── scripts
│   │   │   ├── download_weights.sh
│   │   │   ├── get_coco.sh
│   │   │   ├── get_coco128.sh
│   │   │   └── get_imagenet.sh
│   │   └── utils.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── exporter.py
│   │   ├── model.py
│   │   ├── predictor.py
│   │   ├── results.py
│   │   ├── trainer.py
│   │   └── validator.py
│   ├── hub
│   │   ├── __init__.py
│   │   ├── auth.py
│   │   ├── session.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── fastsam
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── predict.py
│   │   │   ├── prompt.py
│   │   │   ├── utils.py
│   │   │   └── val.py
│   │   ├── nas
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── predict.py
│   │   │   └── val.py
│   │   ├── rtdetr
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── predict.py
│   │   │   ├── train.py
│   │   │   └── val.py
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   ├── amg.py
│   │   │   ├── build.py
│   │   │   ├── model.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── decoders.py
│   │   │   │   ├── encoders.py
│   │   │   │   ├── sam.py
│   │   │   │   ├── tiny_encoder.py
│   │   │   │   └── transformer.py
│   │   │   └── predict.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── loss.py
│   │   │   └── ops.py
│   │   └── yolo
│   │       ├── __init__.py
│   │       ├── classify
│   │       │   ├── __init__.py
│   │       │   ├── predict.py
│   │       │   ├── train.py
│   │       │   └── val.py
│   │       ├── detect
│   │       │   ├── __init__.py
│   │       │   ├── predict.py
│   │       │   ├── train.py
│   │       │   └── val.py
│   │       ├── model.py
│   │       ├── pose
│   │       │   ├── __init__.py
│   │       │   ├── predict.py
│   │       │   ├── train.py
│   │       │   └── val.py
│   │       └── segment
│   │           ├── __init__.py
│   │           ├── predict.py
│   │           ├── train.py
│   │           └── val.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── autobackend.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── block.py
│   │   │   ├── conv.py
│   │   │   ├── head.py
│   │   │   ├── transformer.py
│   │   │   └── utils.py
│   │   └── tasks.py
│   ├── trackers
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── basetrack.py
│   │   ├── bot_sort.py
│   │   ├── byte_tracker.py
│   │   ├── track.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── gmc.py
│   │       ├── kalman_filter.py
│   │       └── matching.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── autobatch.py
│   │   ├── benchmarks.py
│   │   ├── callbacks
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── clearml.py
│   │   │   ├── comet.py
│   │   │   ├── dvc.py
│   │   │   ├── hub.py
│   │   │   ├── mlflow.py
│   │   │   ├── neptune.py
│   │   │   ├── raytune.py
│   │   │   ├── tensorboard.py
│   │   │   └── wb.py
│   │   ├── checks.py
│   │   ├── dist.py
│   │   ├── downloads.py
│   │   ├── errors.py
│   │   ├── files.py
│   │   ├── instance.py
│   │   ├── loss.py
│   │   ├── metrics.py
│   │   ├── ops.py
│   │   ├── patches.py
│   │   ├── plotting.py
│   │   ├── tal.py
│   │   ├── torch_utils.py
│   │   └── tuner.py
│   └── yolo
│       ├── __init__.py
│       ├── cfg
│       │   └── __init__.py
│       ├── data
│       │   └── __init__.py
│       ├── engine
│       │   └── __init__.py
│       ├── utils
│       │   └── __init__.py
│       └── v8
│           └── __init__.py
└── y_prune
    ├── prune1.py
    ├── prune2.py
    └── y_train.py

/.gitignore:
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # Profiling 85 | *.pclprof 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | .idea 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | mkdocs_github_authors.yaml 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # datasets and projects 137 | datasets/ 138 | runs/ 139 | wandb/ 140 | .DS_Store 141 | 142 | # Neural Network weights ----------------------------------------------------------------------------------------------- 143 | weights/ 144 | *.weights 145 | *.pt 146 | *.pb 147 | *.onnx 148 | *.engine 149 | *.mlmodel 150 | *.mlpackage 151 | *.torchscript 152 | *.tflite 153 | *.h5 154 | *_saved_model/ 155 | *_web_model/ 156 | *_openvino_model/ 157 | *_paddle_model/ 158 | 159 | # Autogenerated files for tests 160 | /ultralytics/assets/ 161 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Pre-commit hooks. 
For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md 3 | 4 | exclude: 'docs/' 5 | # Define bot property if installed via https://github.com/marketplace/pre-commit-ci 6 | ci: 7 | autofix_prs: true 8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 9 | autoupdate_schedule: monthly 10 | # submodules: true 11 | 12 | repos: 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v4.4.0 15 | hooks: 16 | - id: end-of-file-fixer 17 | - id: trailing-whitespace 18 | - id: check-case-conflict 19 | # - id: check-yaml 20 | - id: check-docstring-first 21 | - id: double-quote-string-fixer 22 | - id: detect-private-key 23 | 24 | - repo: https://github.com/asottile/pyupgrade 25 | rev: v3.10.1 26 | hooks: 27 | - id: pyupgrade 28 | name: Upgrade code 29 | 30 | - repo: https://github.com/PyCQA/isort 31 | rev: 5.12.0 32 | hooks: 33 | - id: isort 34 | name: Sort imports 35 | 36 | - repo: https://github.com/google/yapf 37 | rev: v0.40.0 38 | hooks: 39 | - id: yapf 40 | name: YAPF formatting 41 | 42 | - repo: https://github.com/executablebooks/mdformat 43 | rev: 0.7.16 44 | hooks: 45 | - id: mdformat 46 | name: MD formatting 47 | additional_dependencies: 48 | - mdformat-gfm 49 | - mdformat-black 50 | # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md" 51 | 52 | - repo: https://github.com/PyCQA/flake8 53 | rev: 6.1.0 54 | hooks: 55 | - id: flake8 56 | name: PEP8 57 | 58 | - repo: https://github.com/codespell-project/codespell 59 | rev: v2.2.5 60 | hooks: 61 | - id: codespell 62 | args: 63 | - --ignore-words-list=crate,nd,strack,dota,ane,segway,fo 64 | 65 | # - repo: https://github.com/asottile/yesqa 66 | # rev: v1.4.0 67 | # hooks: 68 | # - id: yesqa 69 | 70 | # - repo: https://github.com/asottile/dead 71 | # rev: v1.5.0 72 | # hooks: 73 | # - id: dead 74 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | preferred-citation: 3 | type: software 4 | message: If you use this software, please cite it as below. 5 | authors: 6 | - family-names: Jocher 7 | given-names: Glenn 8 | orcid: "https://orcid.org/0000-0001-5950-6979" 9 | - family-names: Chaurasia 10 | given-names: Ayush 11 | orcid: "https://orcid.org/0000-0002-7603-6750" 12 | - family-names: Qiu 13 | given-names: Jing 14 | orcid: "https://orcid.org/0000-0003-3783-7069" 15 | title: "YOLO by Ultralytics" 16 | version: 8.0.0 17 | # doi: 10.5281/zenodo.3908559 # TODO 18 | date-released: 2023-1-10 19 | license: AGPL-3.0 20 | url: "https://github.com/ultralytics/ultralytics" 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to YOLOv8 🚀 2 | 3 | We love your input! We want to make contributing to YOLOv8 as easy and transparent as possible, whether it's: 4 | 5 | - Reporting a bug 6 | - Discussing the current state of the code 7 | - Submitting a fix 8 | - Proposing a new feature 9 | - Becoming a maintainer 10 | 11 | YOLOv8 works so well due to our combined community effort, and for every small improvement you contribute you will be 12 | helping push the frontiers of what's possible in AI 😃! 13 | 14 | ## Submitting a Pull Request (PR) 🛠️ 15 | 16 | Submitting a PR is easy! This example shows how to submit a PR for updating `requirements.txt` in 4 steps: 17 | 18 | ### 1. 
Select File to Update 19 | 20 | Select `requirements.txt` to update by clicking on it in GitHub. 21 | 22 |

[screenshot: PR_step1]

23 | 24 | ### 2. Click 'Edit this file' 25 | 26 | Button is in top-right corner. 27 | 28 |

[screenshot: PR_step2]

29 | 30 | ### 3. Make Changes 31 | 32 | Change `matplotlib` version from `3.2.2` to `3.3`. 33 | 34 |

[screenshot: PR_step3]
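For this walkthrough the entire contribution is a one-line edit, so the change you are about to preview should look roughly like this diff (assuming nothing else on the line was touched):

```diff
-matplotlib>=3.2.2
+matplotlib>=3.3
```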

35 | 36 | ### 4. Preview Changes and Submit PR 37 | 38 | Click on the **Preview changes** tab to verify your updates. At the bottom of the screen select 'Create a **new branch** 39 | for this commit', assign your branch a descriptive name such as `fix/matplotlib_version` and click the green **Propose 40 | changes** button. All done, your PR is now submitted to YOLOv8 for review and approval 😃! 41 | 42 |

[screenshot: PR_step4]
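The four steps above use the GitHub web editor. If you prefer working from a local clone, a roughly equivalent flow is sketched below; the fork URL and branch name are placeholders, not prescribed values:

```bash
git clone https://github.com/YOUR_USERNAME/ultralytics  # your fork of ultralytics/ultralytics
cd ultralytics
git checkout -b fix/matplotlib_version  # descriptive branch name
# edit requirements.txt, then commit and push
git commit -am "Update matplotlib requirement from 3.2.2 to 3.3"
git push -u origin fix/matplotlib_version  # open the PR from this branch on GitHub
```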

43 | 44 | ### PR recommendations 45 | 46 | To allow your work to be integrated as seamlessly as possible, we advise you to: 47 | 48 | - ✅ Verify your PR is **up-to-date** with `ultralytics/ultralytics` `main` branch. If your PR is behind you can update 49 | your code by clicking the 'Update branch' button or by running `git pull` and `git merge main` locally. 50 | 51 |
[screenshot: the 'Update branch' button]
52 | 53 | - ✅ Verify all YOLOv8 Continuous Integration (CI) **checks are passing**. 54 | 55 |
[screenshot: passing CI checks]
56 | 57 | - ✅ Reduce changes to the absolute **minimum** required for your bug fix or feature addition. _"It is not daily increase 58 | but daily decrease, hack away the unessential. The closer to the source, the less wastage there is."_ — Bruce Lee 59 | 60 | ### Docstrings 61 | 62 | Not all functions or classes require docstrings but when they do, we 63 | follow [google-style docstrings format](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). 64 | Here is an example: 65 | 66 | ```python 67 | """ 68 | What the function does. Performs NMS on given detection predictions. 69 | 70 | Args: 71 | arg1: The description of the 1st argument 72 | arg2: The description of the 2nd argument 73 | 74 | Returns: 75 | What the function returns. Empty if nothing is returned. 76 | 77 | Raises: 78 | Exception Class: When and why this exception can be raised by the function. 79 | """ 80 | ``` 81 | 82 | ## Submitting a Bug Report 🐛 83 | 84 | If you spot a problem with YOLOv8 please submit a Bug Report! 85 | 86 | For us to start investigating a possible problem we need to be able to reproduce it ourselves first. We've created a few 87 | short guidelines below to help users provide what we need in order to get started. 88 | 89 | When asking a question, people will be better able to provide help if you provide **code** that they can easily 90 | understand and use to **reproduce** the problem. This is referred to by community members as creating 91 | a [minimum reproducible example](https://docs.ultralytics.com/help/minimum_reproducible_example/). Your code that reproduces 92 | the problem should be: 93 | 94 | - ✅ **Minimal** – Use as little code as possible that still produces the same problem 95 | - ✅ **Complete** – Provide **all** parts someone else needs to reproduce your problem in the question itself 96 | - ✅ **Reproducible** – Test the code you're about to provide to make sure it reproduces the problem 97 | 98 | In addition to the above requirements, for [Ultralytics](https://ultralytics.com/) to provide assistance your code 99 | should be: 100 | 101 | - ✅ **Current** – Verify that your code is up-to-date with current 102 | GitHub [main](https://github.com/ultralytics/ultralytics/tree/main) branch, and if necessary `git pull` or `git clone` 103 | a new copy to ensure your problem has not already been resolved by previous commits. 104 | - ✅ **Unmodified** – Your problem must be reproducible without any modifications to the codebase in this 105 | repository. [Ultralytics](https://ultralytics.com/) does not provide support for custom code ⚠️. 106 | 107 | If you believe your problem meets all of the above criteria, please close this issue and raise a new one using the 🐛 108 | **Bug Report** [template](https://github.com/ultralytics/ultralytics/issues/new/choose) and providing 109 | a [minimum reproducible example](https://docs.ultralytics.com/help/minimum_reproducible_example/) to help us better 110 | understand and diagnose your problem. 
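As a reference point, a minimum reproducible example often fits in a handful of lines. The sketch below assumes a detection issue with official weights and a public image; substitute whatever model, source and arguments actually trigger your problem:

```python
import ultralytics
from ultralytics import YOLO

ultralytics.checks()  # print environment info; include this output in your report

model = YOLO('yolov8n.pt')  # unmodified official model
results = model.predict('https://ultralytics.com/images/bus.jpg')  # publicly available image
print(results[0].boxes)  # output demonstrating the unexpected behavior
```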
111 | 112 | ## License 113 | 114 | By contributing, you agree that your contributions will be licensed under 115 | the [AGPL-3.0 license](https://choosealicense.com/licenses/agpl-3.0/) 116 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.md 2 | include requirements.txt 3 | include LICENSE 4 | include setup.py 5 | include ultralytics/assets/bus.jpg 6 | include ultralytics/assets/zidane.jpg 7 | recursive-include ultralytics *.yaml 8 | recursive-exclude __pycache__ * 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Ultralytics requirements 2 | # Usage: pip install -r requirements.txt 3 | 4 | # Base ---------------------------------------- 5 | matplotlib>=3.2.2 6 | numpy>=1.22.2 # pinned by Snyk to avoid a vulnerability 7 | opencv-python>=4.6.0 8 | pillow>=7.1.2 9 | pyyaml>=5.3.1 10 | requests>=2.23.0 11 | scipy>=1.4.1 12 | torch>=1.8.0 13 | torchvision>=0.9.0 14 | tqdm>=4.64.0 15 | 16 | # Logging ------------------------------------- 17 | # tensorboard>=2.13.0 18 | # dvclive>=2.12.0 19 | # clearml 20 | # comet 21 | 22 | # Plotting ------------------------------------ 23 | pandas>=1.1.4 24 | seaborn>=0.11.0 25 | 26 | # Export -------------------------------------- 27 | # coremltools>=7.0.b1 # CoreML export 28 | # onnx>=1.12.0 # ONNX export 29 | # onnxsim>=0.4.1 # ONNX simplifier 30 | # nvidia-pyindex # TensorRT export 31 | # nvidia-tensorrt # TensorRT export 32 | # scikit-learn==0.19.2 # CoreML quantization 33 | # tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos) 34 | # tflite-support 35 | # tensorflowjs>=3.9.0 # TF.js export 36 | # openvino-dev>=2023.0 # OpenVINO export 37 | 38 | # Extras -------------------------------------- 39 | psutil # system utilization 40 | py-cpuinfo # display CPU info 41 | # thop>=0.1.1 # FLOPs computation 42 | # ipython # interactive notebook 43 | # albumentations>=1.0.3 # training augmentations 44 | # pycocotools>=2.0.6 # COCO mAP 45 | # roboflow 46 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Project-wide configuration file, can be used for package metadata and other toll configurations 2 | # Example usage: global configuration for PEP8 (via flake8) setting or default pytest arguments 3 | # Local usage: pip install pre-commit, pre-commit run --all-files 4 | 5 | [metadata] 6 | license_files = LICENSE 7 | description_file = README.md 8 | 9 | [tool:pytest] 10 | norecursedirs = 11 | .git 12 | dist 13 | build 14 | addopts = 15 | --doctest-modules 16 | --durations=25 17 | --color=yes 18 | 19 | [coverage:run] 20 | source = ultralytics/ 21 | data_file = tests/.coverage 22 | omit = 23 | ultralytics/utils/callbacks/* 24 | 25 | [flake8] 26 | max-line-length = 120 27 | exclude = .tox,*.egg,build,temp 28 | select = E,W,F 29 | doctests = True 30 | verbose = 2 31 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 32 | format = pylint 33 | # see: https://www.flake8rules.com/ 34 | ignore = E731,F405,E402,W504,E501 35 | # E731: Do not assign a lambda expression, use a def 36 | # F405: name may be undefined, or defined from star imports: module 37 | # E402: module level import not at top of file 38 | # W504: line break after binary operator 39 | # 
E501: line too long 40 | # removed: 41 | # F401: module imported but unused 42 | # E231: missing whitespace after ‘,’, ‘;’, or ‘:’ 43 | # E127: continuation line over-indented for visual indent 44 | # F403: ‘from module import *’ used; unable to detect undefined names 45 | 46 | 47 | [isort] 48 | # https://pycqa.github.io/isort/docs/configuration/options.html 49 | line_length = 120 50 | # see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html 51 | multi_line_output = 0 52 | 53 | [yapf] 54 | based_on_style = pep8 55 | spaces_before_comment = 2 56 | COLUMN_LIMIT = 120 57 | COALESCE_BRACKETS = True 58 | SPACES_AROUND_POWER_OPERATOR = True 59 | SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True 60 | SPLIT_BEFORE_CLOSING_BRACKET = False 61 | SPLIT_BEFORE_FIRST_ARGUMENT = False 62 | # EACH_DICT_ENTRY_ON_SEPARATE_LINE = False 63 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import re 4 | from pathlib import Path 5 | 6 | import pkg_resources as pkg 7 | from setuptools import find_packages, setup 8 | 9 | # Settings 10 | FILE = Path(__file__).resolve() 11 | PARENT = FILE.parent # root directory 12 | README = (PARENT / 'README.md').read_text(encoding='utf-8') 13 | REQUIREMENTS = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements((PARENT / 'requirements.txt').read_text())] 14 | 15 | 16 | def get_version(): 17 | file = PARENT / 'ultralytics/__init__.py' 18 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding='utf-8'), re.M)[1] 19 | 20 | 21 | setup( 22 | name='ultralytics', # name of pypi package 23 | version=get_version(), # version of pypi package 24 | python_requires='>=3.8', 25 | license='AGPL-3.0', 26 | description=('Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, ' 27 | 'pose estimation and image classification.'), 28 | long_description=README, 29 | long_description_content_type='text/markdown', 30 | url='https://github.com/ultralytics/ultralytics', 31 | project_urls={ 32 | 'Bug Reports': 'https://github.com/ultralytics/ultralytics/issues', 33 | 'Funding': 'https://ultralytics.com', 34 | 'Source': 'https://github.com/ultralytics/ultralytics'}, 35 | author='Ultralytics', 36 | author_email='hello@ultralytics.com', 37 | packages=find_packages(), # required 38 | include_package_data=True, 39 | install_requires=REQUIREMENTS, 40 | extras_require={ 41 | 'dev': [ 42 | 'ipython', 43 | 'check-manifest', 44 | 'pytest', 45 | 'pytest-cov', 46 | 'coverage', 47 | 'mkdocs-material', 48 | 'mkdocstrings[python]', 49 | 'mkdocs-redirects', # for 301 redirects 50 | 'mkdocs-ultralytics-plugin>=0.0.25', # for meta descriptions and images, dates and authors 51 | ], 52 | 'export': [ 53 | 'coremltools>=7.0.b1', 54 | 'openvino-dev>=2023.0', 55 | 'tensorflowjs', # automatically installs tensorflow 56 | ], }, 57 | classifiers=[ 58 | 'Development Status :: 4 - Beta', 59 | 'Intended Audience :: Developers', 60 | 'Intended Audience :: Education', 61 | 'Intended Audience :: Science/Research', 62 | 'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)', 63 | 'Programming Language :: Python :: 3', 64 | 'Programming Language :: Python :: 3.8', 65 | 'Programming Language :: Python :: 3.9', 66 | 'Programming Language :: Python :: 3.10', 67 | 'Programming Language :: Python :: 3.11', 68 | 'Topic :: Software Development', 
69 | 'Topic :: Scientific/Engineering', 70 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 71 | 'Topic :: Scientific/Engineering :: Image Recognition', 72 | 'Operating System :: POSIX :: Linux', 73 | 'Operating System :: MacOS', 74 | 'Operating System :: Microsoft :: Windows', ], 75 | keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics', 76 | entry_points={'console_scripts': ['yolo = ultralytics.cfg:entrypoint', 'ultralytics = ultralytics.cfg:entrypoint']}) 77 | -------------------------------------------------------------------------------- /ultralytics/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | __version__ = '8.0.155' 4 | 5 | from ultralytics.hub import start 6 | from ultralytics.models import RTDETR, SAM, YOLO 7 | from ultralytics.models.fastsam import FastSAM 8 | from ultralytics.models.nas import NAS 9 | from ultralytics.utils import SETTINGS as settings 10 | from ultralytics.utils.checks import check_yolo as checks 11 | from ultralytics.utils.downloads import download 12 | 13 | __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import 14 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/README.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration 4 | files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted 5 | and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image 6 | segmentation tasks. 7 | 8 | These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like 9 | instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, 10 | from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this 11 | directory provides a great starting point for your custom model development needs. 12 | 13 | To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've 14 | selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full 15 | details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free 16 | to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now! 
17 | 18 | ### Usage 19 | 20 | Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command: 21 | 22 | ```bash 23 | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 24 | ``` 25 | 26 | They may also be used directly in a Python environment, and accepts the same 27 | [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: 28 | 29 | ```python 30 | from ultralytics import YOLO 31 | 32 | model = YOLO("model.yaml") # build a YOLOv8n model from scratch 33 | # YOLO("model.pt") use pre-trained model if available 34 | model.info() # display model information 35 | model.train(data="coco128.yaml", epochs=100) # train the model 36 | ``` 37 | 38 | ## Pre-trained Model Architectures 39 | 40 | Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information 41 | and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available. 42 | 43 | ## Contributing New Models 44 | 45 | If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing). 46 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/rt-detr/rtdetr-l.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | l: [1.00, 1.00, 1024] 9 | 10 | backbone: 11 | # [from, repeats, module, args] 12 | - [-1, 1, HGStem, [32, 48]] # 0-P2/4 13 | - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 14 | 15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 16 | - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 17 | 18 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 19 | - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut 20 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] 21 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 22 | 23 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 24 | - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 25 | 26 | head: 27 | - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 28 | - [-1, 1, AIFI, [1024, 8]] 29 | - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 30 | 31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 32 | - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 33 | - [[-2, -1], 1, Concat, [1]] 34 | - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 35 | - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 36 | 37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 38 | - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0 39 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4 40 | - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 41 | 42 | - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 43 | - [[-1, 17], 1, Concat, [1]] # cat Y4 44 | - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 45 | 46 | - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 47 | - [[-1, 12], 1, Concat, [1]] # cat Y5 48 | - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 49 | 50 | - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 51 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/rt-detr/rtdetr-x.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | x: [1.00, 1.00, 2048] 9 | 10 | backbone: 11 | # [from, repeats, module, args] 12 | - [-1, 1, HGStem, [32, 64]] # 0-P2/4 13 | - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 14 | 15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 16 | - [-1, 6, HGBlock, [128, 512, 3]] 17 | - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 18 | 19 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 20 | - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut 21 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 22 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 23 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 24 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 25 | 26 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 27 | - [-1, 6, HGBlock, [512, 2048, 5, True, False]] 28 | - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 29 | 30 | head: 31 | - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 32 | - [-1, 1, AIFI, [2048, 8]] 33 | - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 34 | 35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 36 | - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 37 | - [[-2, -1], 1, Concat, [1]] 38 | - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 39 | - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 40 | 41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 42 | - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 43 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4 44 | - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 45 | 46 | - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 47 | - [[-1, 21], 1, Concat, [1]] # cat Y4 48 | - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 49 | 50 | - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 51 | - [[-1, 16], 1, Concat, [1]] # cat Y5 52 | - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 53 | 54 | - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 55 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v3/yolov3-spp.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3-SPP object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # darknet53 backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [32, 3, 1]], # 0 13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 14 | [-1, 1, Bottleneck, [64]], 15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 16 | [-1, 2, Bottleneck, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 18 | [-1, 8, Bottleneck, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 20 | [-1, 8, Bottleneck, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 22 | [-1, 4, Bottleneck, [1024]], # 10 23 | ] 24 | 25 | # YOLOv3-SPP head 26 | head: 27 | [[-1, 1, Bottleneck, [1024, False]], 28 | [-1, 1, SPP, [512, [5, 9, 13]]], 29 | [-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [512, 1, 1]], 31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Bottleneck, [512, False]], 37 | [-1, 1, Bottleneck, [512, False]], 38 | [-1, 1, Conv, [256, 1, 1]], 39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) 40 | 41 | [-2, 1, Conv, [128, 1, 1]], 42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3 44 | [-1, 1, Bottleneck, [256, False]], 45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) 46 | 47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) 48 | ] 49 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v3/yolov3-tiny.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # YOLOv3-tiny backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [16, 3, 1]], # 0 13 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2 14 | [-1, 1, Conv, [32, 3, 1]], 15 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4 16 | [-1, 1, Conv, [64, 3, 1]], 17 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8 18 | [-1, 1, Conv, [128, 3, 1]], 19 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16 20 | [-1, 1, Conv, [256, 3, 1]], 21 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32 22 | [-1, 1, Conv, [512, 3, 1]], 23 | [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11 24 | [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12 25 | ] 26 | 27 | # YOLOv3-tiny head 28 | head: 29 | [[-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [256, 1, 1]], 31 | [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [128, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium) 37 | 38 | [[19, 15], 1, Detect, [nc]], # Detect(P4, P5) 39 | ] 40 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v3/yolov3.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3 object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # darknet53 backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [32, 3, 1]], # 0 13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 14 | [-1, 1, Bottleneck, [64]], 15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 16 | [-1, 2, Bottleneck, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 18 | [-1, 8, Bottleneck, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 20 | [-1, 8, Bottleneck, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 22 | [-1, 4, Bottleneck, [1024]], # 10 23 | ] 24 | 25 | # YOLOv3 head 26 | head: 27 | [[-1, 1, Bottleneck, [1024, False]], 28 | [-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [512, 1, 1]], 31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Bottleneck, [512, False]], 37 | [-1, 1, Bottleneck, [512, False]], 38 | [-1, 1, Conv, [256, 1, 1]], 39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) 40 | 41 | [-2, 1, Conv, [128, 1, 1]], 42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3 44 | [-1, 1, Bottleneck, [256, False]], 45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) 46 | 47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) 48 | ] 49 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v5/yolov5-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.33, 1.25, 1024] 13 | 14 | # YOLOv5 v6.0 backbone 15 | backbone: 16 | # [from, number, module, args] 17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 19 | [-1, 3, C3, [128]], 20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 21 | [-1, 6, C3, [256]], 22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 23 | [-1, 9, C3, [512]], 24 | [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 25 | [-1, 3, C3, [768]], 26 | [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 27 | [-1, 3, C3, [1024]], 28 | [-1, 1, SPPF, [1024, 5]], # 11 29 | ] 30 | 31 | # YOLOv5 v6.0 head 32 | head: 33 | [[-1, 1, Conv, [768, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P5 36 | [-1, 3, C3, [768, False]], # 15 37 | 38 | [-1, 1, Conv, [512, 1, 1]], 39 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 40 | [[-1, 6], 1, Concat, [1]], # cat backbone P4 41 | [-1, 3, C3, [512, False]], # 19 42 | 43 | [-1, 1, Conv, [256, 1, 1]], 44 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 45 | [[-1, 4], 1, Concat, [1]], # cat backbone P3 46 | [-1, 3, C3, [256, False]], # 23 (P3/8-small) 47 | 48 | [-1, 1, Conv, [256, 3, 2]], 49 | [[-1, 20], 1, Concat, [1]], # cat head P4 50 | [-1, 3, C3, [512, False]], # 26 (P4/16-medium) 51 | 52 | [-1, 1, Conv, [512, 3, 2]], 53 | [[-1, 16], 1, Concat, [1]], # cat head P5 54 | [-1, 3, C3, [768, False]], # 29 (P5/32-large) 55 | 56 | [-1, 1, Conv, [768, 3, 2]], 57 | [[-1, 12], 1, Concat, [1]], # cat head P6 58 | [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) 59 | 60 | [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) 61 | ] 62 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v5/yolov5.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov5n.yaml' will call yolov5.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.33, 1.25, 1024] 13 | 14 | # YOLOv5 v6.0 backbone 15 | backbone: 16 | # [from, number, module, args] 17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 19 | [-1, 3, C3, [128]], 20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 21 | [-1, 6, C3, [256]], 22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 23 | [-1, 9, C3, [512]], 24 | [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 25 | [-1, 3, C3, [1024]], 26 | [-1, 1, SPPF, [1024, 5]], # 9 27 | ] 28 | 29 | # YOLOv5 v6.0 head 30 | head: 31 | [[-1, 1, Conv, [512, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 6], 1, Concat, [1]], # cat backbone P4 34 | [-1, 3, C3, [512, False]], # 13 35 | 36 | [-1, 1, Conv, [256, 1, 1]], 37 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 38 | [[-1, 4], 1, Concat, [1]], # cat backbone P3 39 | [-1, 3, C3, [256, False]], # 17 (P3/8-small) 40 | 41 | [-1, 1, Conv, [256, 3, 2]], 42 | [[-1, 14], 1, Concat, [1]], # cat head P4 43 | [-1, 3, C3, [512, False]], # 20 (P4/16-medium) 44 | 45 | [-1, 1, Conv, [512, 3, 2]], 46 | [[-1, 10], 1, Concat, [1]], # cat head P5 47 | [-1, 3, C3, [1024, False]], # 23 (P5/32-large) 48 | 49 | [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5) 50 | ] 51 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v6/yolov6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | activation: nn.ReLU() # (optional) model default activation function 7 | scales: # model compound scaling constants, i.e. 
'model=yolov6n.yaml' will call yolov6.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv6-3.0s backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 6, Conv, [128, 3, 1]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 12, Conv, [256, 3, 1]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 18, Conv, [512, 3, 1]] 25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 26 | - [-1, 6, Conv, [1024, 3, 1]] 27 | - [-1, 1, SPPF, [1024, 5]] # 9 28 | 29 | # YOLOv6-3.0s head 30 | head: 31 | - [-1, 1, Conv, [256, 1, 1]] 32 | - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] 33 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 34 | - [-1, 1, Conv, [256, 3, 1]] 35 | - [-1, 9, Conv, [256, 3, 1]] # 14 36 | 37 | - [-1, 1, Conv, [128, 1, 1]] 38 | - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] 39 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 40 | - [-1, 1, Conv, [128, 3, 1]] 41 | - [-1, 9, Conv, [128, 3, 1]] # 19 42 | 43 | - [-1, 1, Conv, [128, 3, 2]] 44 | - [[-1, 15], 1, Concat, [1]] # cat head P4 45 | - [-1, 1, Conv, [256, 3, 1]] 46 | - [-1, 9, Conv, [256, 3, 1]] # 23 47 | 48 | - [-1, 1, Conv, [256, 3, 2]] 49 | - [[-1, 10], 1, Concat, [1]] # cat head P5 50 | - [-1, 1, Conv, [512, 3, 1]] 51 | - [-1, 9, Conv, [512, 3, 1]] # 27 52 | 53 | - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) 54 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-cls.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify 3 | 4 | # Parameters 5 | nc: 1000 # number of classes 6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.00, 1.25, 1024] 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | 27 | # YOLOv8.0n head 28 | head: 29 | - [-1, 1, Classify, [nc]] # Classify 30 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-p2.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e.
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0 backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0-p2 head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 39 | - [[-1, 2], 1, Concat, [1]] # cat backbone P2 40 | - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) 41 | 42 | - [-1, 1, Conv, [128, 3, 2]] 43 | - [[-1, 15], 1, Concat, [1]] # cat head P3 44 | - [-1, 3, C2f, [256]] # 21 (P3/8-small) 45 | 46 | - [-1, 1, Conv, [256, 3, 2]] 47 | - [[-1, 12], 1, Concat, [1]] # cat head P4 48 | - [-1, 3, C2f, [512]] # 24 (P4/16-medium) 49 | 50 | - [-1, 1, Conv, [512, 3, 2]] 51 | - [[-1, 9], 1, Concat, [1]] # cat head P5 52 | - [-1, 3, C2f, [1024]] # 27 (P5/32-large) 53 | 54 | - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) 55 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0x6 backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [768, True]] 26 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 27 | - [-1, 3, C2f, [1024, True]] 28 | - [-1, 1, SPPF, [1024, 5]] # 11 29 | 30 | # YOLOv8.0x6 head 31 | head: 32 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 33 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5 34 | - [-1, 3, C2, [768, False]] # 14 35 | 36 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 37 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 38 | - [-1, 3, C2, [512, False]] # 17 39 | 40 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 41 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 42 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small) 43 | 44 | - [-1, 1, Conv, [256, 3, 2]] 45 | - [[-1, 17], 1, Concat, [1]] # cat head P4 46 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) 47 | 48 | - [-1, 1, Conv, [512, 3, 2]] 49 | - [[-1, 14], 1, Concat, [1]] # cat head P5 50 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large) 51 | 52 | - [-1, 1, Conv, [768, 3, 2]] 53 | - [[-1, 11], 1, Concat, [1]] # cat head P6 54 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) 55 | 56 | - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) 57 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-pose-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose 3 | 4 | # Parameters 5 | nc: 1 # number of classes 6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 7 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv8.0x6 backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 3, C2f, [128, True]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 6, C2f, [256, True]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 6, C2f, [512, True]] 25 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 26 | - [-1, 3, C2f, [768, True]] 27 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 28 | - [-1, 3, C2f, [1024, True]] 29 | - [-1, 1, SPPF, [1024, 5]] # 11 30 | 31 | # YOLOv8.0x6 head 32 | head: 33 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 34 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5 35 | - [-1, 3, C2, [768, False]] # 14 36 | 37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 38 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 39 | - [-1, 3, C2, [512, False]] # 17 40 | 41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 42 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 43 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small) 44 | 45 | - [-1, 1, Conv, [256, 3, 2]] 46 | - [[-1, 17], 1, Concat, [1]] # cat head P4 47 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) 48 | 49 | - [-1, 1, Conv, [512, 3, 2]] 50 | - [[-1, 14], 1, Concat, [1]] # cat head P5 51 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large) 52 | 53 | - [-1, 1, Conv, [768, 3, 2]] 54 | - [[-1, 11], 1, Concat, [1]] # cat head P6 55 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) 56 | 57 | - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) 58 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-pose.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose 3 | 4 | # Parameters 5 | nc: 1 # number of classes 6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 7 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv8.0n backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 3, C2f, [128, True]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 6, C2f, [256, True]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 6, C2f, [512, True]] 25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 26 | - [-1, 3, C2f, [1024, True]] 27 | - [-1, 1, SPPF, [1024, 5]] # 9 28 | 29 | # YOLOv8.0n head 30 | head: 31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 32 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 33 | - [-1, 3, C2f, [512]] # 12 34 | 35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 36 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 37 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 38 | 39 | - [-1, 1, Conv, [256, 3, 2]] 40 | - [[-1, 12], 1, Concat, [1]] # cat head P4 41 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 42 | 43 | - [-1, 1, Conv, [512, 3, 2]] 44 | - [[-1, 9], 1, Concat, [1]] # cat head P5 45 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 46 | 47 | - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) 48 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-rtdetr.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs 9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs 10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs 11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs 12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8-seg.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /ultralytics/cfg/models/v8/yolov8.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs 9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs 10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs 11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs 12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /ultralytics/cfg/trackers/botsort.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT 3 | 4 | tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] 5 | track_high_thresh: 0.5 # threshold for the first association 6 | track_low_thresh: 0.1 # threshold for the second association 7 | new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks 8 | track_buffer: 30 # buffer to calculate the time when to remove tracks 9 | match_thresh: 0.8 # threshold for matching tracks 10 | # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) 11 | # mot20: False # for tracker evaluation(not used for now) 12 | 13 | # BoT-SORT settings 14 | cmc_method: sparseOptFlow # method of global motion compensation 15 | # ReID model related thresh (not supported yet) 16 | proximity_thresh: 0.5 17 | appearance_thresh: 0.25 18 | with_reid: False 19 | -------------------------------------------------------------------------------- /ultralytics/cfg/trackers/bytetrack.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack 3 | 4 | tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] 5 | track_high_thresh: 0.5 # threshold for the first association 6 | track_low_thresh: 0.1 # threshold for the second association 7 | new_track_thresh: 0.6 # threshold for init new track if the detection 
does not match any tracks 8 | track_buffer: 30 # buffer to calculate the time when to remove tracks 9 | match_thresh: 0.8 # threshold for matching tracks 10 | # min_box_area: 10 # threshold for min box areas (for tracker evaluation, not used for now) 11 | # mot20: False # for tracker evaluation (not used for now) 12 | -------------------------------------------------------------------------------- /ultralytics/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .base import BaseDataset 4 | from .build import build_dataloader, build_yolo_dataset, load_inference_source 5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset 6 | 7 | __all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset', 8 | 'build_dataloader', 'load_inference_source') 9 | -------------------------------------------------------------------------------- /ultralytics/data/annotator.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from pathlib import Path 4 | 5 | from ultralytics import SAM, YOLO 6 | 7 | 8 | def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): 9 | """ 10 | Automatically annotates images using a YOLO object detection model and a SAM segmentation model. 11 | Args: 12 | data (str): Path to a folder containing images to be annotated. 13 | det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. 14 | sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. 15 | device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available). 16 | output_dir (str | None, optional): Directory to save the annotated results. 17 | Defaults to a 'labels' folder in the same directory as 'data'. 18 | """ 19 | det_model = YOLO(det_model) 20 | sam_model = SAM(sam_model) 21 | 22 | if not output_dir: 23 | output_dir = Path(str(data)).parent / 'labels' 24 | Path(output_dir).mkdir(exist_ok=True, parents=True) 25 | 26 | det_results = det_model(data, stream=True, device=device) 27 | 28 | for result in det_results: 29 | class_ids = result.boxes.cls.int().tolist() # noqa 30 | if len(class_ids): 31 | boxes = result.boxes.xyxy # Boxes object for bbox outputs 32 | sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device) 33 | segments = sam_results[0].masks.xyn # noqa 34 | 35 | with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f: 36 | for i in range(len(segments)): 37 | s = segments[i] 38 | if len(s) == 0: 39 | continue 40 | segment = map(str, segments[i].reshape(-1).tolist()) 41 | f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') 42 | -------------------------------------------------------------------------------- /ultralytics/data/scripts/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ultralytics YOLO 🚀, AGPL-3.0 license 3 | # Download latest models from https://github.com/ultralytics/assets/releases 4 | # Example usage: bash ultralytics/data/scripts/download_weights.sh 5 | # parent 6 | # └── weights 7 | # ├── yolov8n.pt ← downloads here 8 | # ├── yolov8s.pt 9 | # └── ... 10 | 11 | python - < bool: 69 | """ 70 | Attempt to authenticate with the server using either id_token or API key.
71 | 72 | Returns: 73 | bool: True if authentication is successful, False otherwise. 74 | """ 75 | try: 76 | header = self.get_auth_header() 77 | if header: 78 | r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header) 79 | if not r.json().get('success', False): 80 | raise ConnectionError('Unable to authenticate.') 81 | return True 82 | raise ConnectionError('User has not authenticated locally.') 83 | except ConnectionError: 84 | self.id_token = self.api_key = False # reset invalid 85 | LOGGER.warning(f'{PREFIX}Invalid API key ⚠️') 86 | return False 87 | 88 | def auth_with_cookies(self) -> bool: 89 | """ 90 | Attempt to fetch authentication via cookies and set id_token. 91 | User must be logged in to HUB and running in a supported browser. 92 | 93 | Returns: 94 | bool: True if authentication is successful, False otherwise. 95 | """ 96 | if not is_colab(): 97 | return False # Currently only works with Colab 98 | try: 99 | authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto') 100 | if authn.get('success', False): 101 | self.id_token = authn.get('data', {}).get('idToken', None) 102 | self.authenticate() 103 | return True 104 | raise ConnectionError('Unable to fetch browser authentication details.') 105 | except ConnectionError: 106 | self.id_token = False # reset invalid 107 | return False 108 | 109 | def get_auth_header(self): 110 | """ 111 | Get the authentication header for making API requests. 112 | 113 | Returns: 114 | (dict): The authentication header if id_token or API key is set, None otherwise. 115 | """ 116 | if self.id_token: 117 | return {'authorization': f'Bearer {self.id_token}'} 118 | elif self.api_key: 119 | return {'x-api-key': self.api_key} 120 | else: 121 | return None 122 | 123 | def get_state(self) -> bool: 124 | """ 125 | Get the authentication state. 126 | 127 | Returns: 128 | bool: True if either id_token or API key is set, False otherwise. 129 | """ 130 | return self.id_token or self.api_key 131 | 132 | def set_api_key(self, key: str): 133 | """ 134 | Set the API key for authentication. 135 | 136 | Args: 137 | key (str): The API key string. 138 | """ 139 | self.api_key = key 140 | -------------------------------------------------------------------------------- /ultralytics/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .rtdetr import RTDETR 4 | from .sam import SAM 5 | from .yolo import YOLO 6 | 7 | __all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import 8 | -------------------------------------------------------------------------------- /ultralytics/models/fastsam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import FastSAM 4 | from .predict import FastSAMPredictor 5 | from .prompt import FastSAMPrompt 6 | from .val import FastSAMValidator 7 | 8 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator' 9 | -------------------------------------------------------------------------------- /ultralytics/models/fastsam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from pathlib import Path 4 | 5 | from ultralytics.engine.model import Model 6 | 7 | from .predict import FastSAMPredictor 8 | from .val import FastSAMValidator 9 | 10 | 11 | class FastSAM(Model): 12 | """ 13 | FastSAM model interface. 
14 | 15 | Example: 16 | ```python 17 | from ultralytics import FastSAM 18 | 19 | model = FastSAM('last.pt') 20 | results = model.predict('ultralytics/assets/bus.jpg') 21 | ``` 22 | """ 23 | 24 | def __init__(self, model='FastSAM-x.pt'): 25 | """Call the __init__ method of the parent class (Model) with the updated default model.""" 26 | if str(model) == 'FastSAM.pt': 27 | model = 'FastSAM-x.pt' 28 | assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.' 29 | super().__init__(model=model, task='segment') 30 | 31 | @property 32 | def task_map(self): 33 | return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}} 34 | -------------------------------------------------------------------------------- /ultralytics/models/fastsam/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.engine.results import Results 6 | from ultralytics.models.fastsam.utils import bbox_iou 7 | from ultralytics.models.yolo.detect.predict import DetectionPredictor 8 | from ultralytics.utils import DEFAULT_CFG, ops 9 | 10 | 11 | class FastSAMPredictor(DetectionPredictor): 12 | 13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 14 | super().__init__(cfg, overrides, _callbacks) 15 | self.args.task = 'segment' 16 | 17 | def postprocess(self, preds, img, orig_imgs): 18 | """TODO: filter by classes.""" 19 | p = ops.non_max_suppression(preds[0], 20 | self.args.conf, 21 | self.args.iou, 22 | agnostic=self.args.agnostic_nms, 23 | max_det=self.args.max_det, 24 | nc=len(self.model.names), 25 | classes=self.args.classes) 26 | full_box = torch.zeros_like(p[0][0]) 27 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 28 | full_box = full_box.view(1, -1) 29 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) 30 | if critical_iou_index.numel() != 0: 31 | full_box[0][4] = p[0][critical_iou_index][:, 4] 32 | full_box[0][6:] = p[0][critical_iou_index][:, 6:] 33 | p[0][critical_iou_index] = full_box 34 | results = [] 35 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 36 | for i, pred in enumerate(p): 37 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 38 | path = self.batch[0] 39 | img_path = path[i] if isinstance(path, list) else path 40 | if not len(pred): # save empty boxes 41 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 42 | continue 43 | if self.args.retina_masks: 44 | if not isinstance(orig_imgs, torch.Tensor): 45 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 46 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 47 | else: 48 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 49 | if not isinstance(orig_imgs, torch.Tensor): 50 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 51 | results.append( 52 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 53 | return results 54 | -------------------------------------------------------------------------------- /ultralytics/models/fastsam/utils.py: -------------------------------------------------------------------------------- 1 | #
Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | 6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): 7 | """ 8 | Adjust bounding boxes to stick to image border if they are within a certain threshold. 9 | 10 | Args: 11 | boxes (torch.Tensor): (n, 4) 12 | image_shape (tuple): (height, width) 13 | threshold (int): pixel threshold 14 | 15 | Returns: 16 | adjusted_boxes (torch.Tensor): adjusted bounding boxes 17 | """ 18 | 19 | # Image dimensions 20 | h, w = image_shape 21 | 22 | # Adjust boxes 23 | boxes[boxes[:, 0] < threshold, 0] = 0 # x1 24 | boxes[boxes[:, 1] < threshold, 1] = 0 # y1 25 | boxes[boxes[:, 2] > w - threshold, 2] = w # x2 26 | boxes[boxes[:, 3] > h - threshold, 3] = h # y2 27 | return boxes 28 | 29 | 30 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): 31 | """ 32 | Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. 33 | 34 | Args: 35 | box1 (torch.Tensor): (4, ) 36 | boxes (torch.Tensor): (n, 4) 37 | iou_thres (float): IoU threshold 38 | image_shape (tuple): (height, width) 39 | raw_output (bool): If True, return the raw IoU values instead of the indices 40 | 41 | Returns: 42 | high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres 43 | """ 44 | boxes = adjust_bboxes_to_image_border(boxes, image_shape) 45 | # obtain coordinates for intersections 46 | x1 = torch.max(box1[0], boxes[:, 0]) 47 | y1 = torch.max(box1[1], boxes[:, 1]) 48 | x2 = torch.min(box1[2], boxes[:, 2]) 49 | y2 = torch.min(box1[3], boxes[:, 3]) 50 | 51 | # compute the area of intersection 52 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) 53 | 54 | # compute the area of both individual boxes 55 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 56 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 57 | 58 | # compute the area of union 59 | union = box1_area + box2_area - intersection 60 | 61 | # compute the IoU 62 | iou = intersection / union # Should be shape (n, ) 63 | if raw_output: 64 | return 0 if iou.numel() == 0 else iou 65 | 66 | # return indices of boxes with IoU > thres 67 | return torch.nonzero(iou > iou_thres).flatten() 68 | -------------------------------------------------------------------------------- /ultralytics/models/fastsam/val.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.models.yolo.segment import SegmentationValidator 4 | from ultralytics.utils.metrics import SegmentMetrics 5 | 6 | 7 | class FastSAMValidator(SegmentationValidator): 8 | 9 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): 10 | """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" 11 | super().__init__(dataloader, save_dir, pbar, args, _callbacks) 12 | self.args.task = 'segment' 13 | self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors 14 | self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) 15 | -------------------------------------------------------------------------------- /ultralytics/models/nas/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import NAS 4 | from .predict import NASPredictor 5 | from .val import NASValidator 6 | 7 | __all__ = 'NASPredictor', 'NASValidator', 'NAS' 8 | 
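A quick illustrative check of `bbox_iou` above (a sketch, not part of the repository; the toy coordinates are made up). Note that `adjust_bboxes_to_image_border` first snaps any box lying within 20 px of the image border, which can change the computed IoU, so the boxes below are kept away from the borders:

```python
# Illustrative sketch: exercising ultralytics.models.fastsam.utils.bbox_iou on toy boxes.
import torch

from ultralytics.models.fastsam.utils import bbox_iou

box1 = torch.tensor([100., 100., 200., 200.])     # query box in xyxy format
boxes = torch.tensor([[102., 102., 198., 198.],   # near-duplicate -> IoU ~0.92
                      [400., 400., 500., 500.]])  # disjoint -> IoU 0
idx = bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640))
print(idx)  # tensor([0]): only the first box clears the 0.9 threshold
```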
-------------------------------------------------------------------------------- /ultralytics/models/nas/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | YOLO-NAS model interface. 4 | 5 | Example: 6 | ```python 7 | from ultralytics import NAS 8 | 9 | model = NAS('yolo_nas_s') 10 | results = model.predict('ultralytics/assets/bus.jpg') 11 | ``` 12 | """ 13 | 14 | from pathlib import Path 15 | 16 | import torch 17 | 18 | from ultralytics.engine.model import Model 19 | from ultralytics.utils.torch_utils import model_info, smart_inference_mode 20 | 21 | from .predict import NASPredictor 22 | from .val import NASValidator 23 | 24 | 25 | class NAS(Model): 26 | 27 | def __init__(self, model='yolo_nas_s.pt') -> None: 28 | assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.' 29 | super().__init__(model, task='detect') 30 | 31 | @smart_inference_mode() 32 | def _load(self, weights: str, task: str): 33 | # Load or create new NAS model 34 | import super_gradients 35 | suffix = Path(weights).suffix 36 | if suffix == '.pt': 37 | self.model = torch.load(weights) 38 | elif suffix == '': 39 | self.model = super_gradients.training.models.get(weights, pretrained_weights='coco') 40 | # Standardize model 41 | self.model.fuse = lambda verbose=True: self.model 42 | self.model.stride = torch.tensor([32]) 43 | self.model.names = dict(enumerate(self.model._class_names)) 44 | self.model.is_fused = lambda: False # for info() 45 | self.model.yaml = {} # for info() 46 | self.model.pt_path = weights # for export() 47 | self.model.task = 'detect' # for export() 48 | 49 | def info(self, detailed=False, verbose=True): 50 | """ 51 | Logs model info. 52 | 53 | Args: 54 | detailed (bool): Show detailed information about model. 55 | verbose (bool): Controls verbosity. 
56 | """ 57 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 58 | 59 | @property 60 | def task_map(self): 61 | return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}} 62 | -------------------------------------------------------------------------------- /ultralytics/models/nas/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.engine.predictor import BasePredictor 6 | from ultralytics.engine.results import Results 7 | from ultralytics.utils import ops 8 | from ultralytics.utils.ops import xyxy2xywh 9 | 10 | 11 | class NASPredictor(BasePredictor): 12 | 13 | def postprocess(self, preds_in, img, orig_imgs): 14 | """Postprocess predictions and returns a list of Results objects.""" 15 | 16 | # Cat boxes and class scores 17 | boxes = xyxy2xywh(preds_in[0][0]) 18 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) 19 | 20 | preds = ops.non_max_suppression(preds, 21 | self.args.conf, 22 | self.args.iou, 23 | agnostic=self.args.agnostic_nms, 24 | max_det=self.args.max_det, 25 | classes=self.args.classes) 26 | 27 | results = [] 28 | for i, pred in enumerate(preds): 29 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 30 | if not isinstance(orig_imgs, torch.Tensor): 31 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 32 | path = self.batch[0] 33 | img_path = path[i] if isinstance(path, list) else path 34 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 35 | return results 36 | -------------------------------------------------------------------------------- /ultralytics/models/nas/val.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.models.yolo.detect import DetectionValidator 6 | from ultralytics.utils import ops 7 | from ultralytics.utils.ops import xyxy2xywh 8 | 9 | __all__ = ['NASValidator'] 10 | 11 | 12 | class NASValidator(DetectionValidator): 13 | 14 | def postprocess(self, preds_in): 15 | """Apply Non-maximum suppression to prediction outputs.""" 16 | boxes = xyxy2xywh(preds_in[0][0]) 17 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) 18 | return ops.non_max_suppression(preds, 19 | self.args.conf, 20 | self.args.iou, 21 | labels=self.lb, 22 | multi_label=False, 23 | agnostic=self.args.single_cls, 24 | max_det=self.args.max_det, 25 | max_time_img=0.5) 26 | -------------------------------------------------------------------------------- /ultralytics/models/rtdetr/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import RTDETR 4 | from .predict import RTDETRPredictor 5 | from .val import RTDETRValidator 6 | 7 | __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR' 8 | -------------------------------------------------------------------------------- /ultralytics/models/rtdetr/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | RT-DETR model interface 4 | """ 5 | from ultralytics.engine.model import Model 6 | from ultralytics.nn.tasks import RTDETRDetectionModel 7 | 8 | from .predict import RTDETRPredictor 9 | from .train import RTDETRTrainer 10 | from .val import 
RTDETRValidator 11 | 12 | 13 | class RTDETR(Model): 14 | """ 15 | RTDETR model interface. 16 | """ 17 | 18 | def __init__(self, model='rtdetr-l.pt') -> None: 19 | if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'): 20 | raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml or *.yml files.') 21 | super().__init__(model=model, task='detect') 22 | 23 | @property 24 | def task_map(self): 25 | return { 26 | 'detect': { 27 | 'predictor': RTDETRPredictor, 28 | 'validator': RTDETRValidator, 29 | 'trainer': RTDETRTrainer, 30 | 'model': RTDETRDetectionModel}} 31 | -------------------------------------------------------------------------------- /ultralytics/models/rtdetr/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.data.augment import LetterBox 6 | from ultralytics.engine.predictor import BasePredictor 7 | from ultralytics.engine.results import Results 8 | from ultralytics.utils import ops 9 | 10 | 11 | class RTDETRPredictor(BasePredictor): 12 | 13 | def postprocess(self, preds, img, orig_imgs): 14 | """Postprocess predictions and returns a list of Results objects.""" 15 | nd = preds[0].shape[-1] 16 | bboxes, scores = preds[0].split((4, nd - 4), dim=-1) 17 | results = [] 18 | for i, bbox in enumerate(bboxes): # (300, 4) 19 | bbox = ops.xywh2xyxy(bbox) 20 | score, cls = scores[i].max(-1, keepdim=True) # (300, 1) 21 | idx = score.squeeze(-1) > self.args.conf # (300, ) 22 | if self.args.classes is not None: 23 | idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx 24 | pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter 25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 26 | oh, ow = orig_img.shape[:2] 27 | if not isinstance(orig_imgs, torch.Tensor): 28 | pred[..., [0, 2]] *= ow 29 | pred[..., [1, 3]] *= oh 30 | path = self.batch[0] 31 | img_path = path[i] if isinstance(path, list) else path 32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 33 | return results 34 | 35 | def pre_transform(self, im): 36 | """Pre-transform input image before inference. 37 | 38 | Args: 39 | im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. 40 | 41 | Returns: A list of transformed images. 42 | """ 43 | # The input size must be square (640) and use scaleFill resizing.
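# With auto=False and scaleFill=True (see the return below), LetterBox stretches each image
# to exactly imgsz x imgsz with no padding, so the aspect ratio is not preserved. This is also
# why postprocess() above can map the normalized predictions back to the source image by
# simply multiplying by the original width and height.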
44 | return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im] 45 | -------------------------------------------------------------------------------- /ultralytics/models/rtdetr/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | import torch 6 | 7 | from ultralytics.models.yolo.detect import DetectionTrainer 8 | from ultralytics.nn.tasks import RTDETRDetectionModel 9 | from ultralytics.utils import DEFAULT_CFG, RANK, colorstr 10 | 11 | from .val import RTDETRDataset, RTDETRValidator 12 | 13 | 14 | class RTDETRTrainer(DetectionTrainer): 15 | 16 | def get_model(self, cfg=None, weights=None, verbose=True): 17 | """Return a YOLO detection model.""" 18 | model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) 19 | if weights: 20 | model.load(weights) 21 | return model 22 | 23 | def build_dataset(self, img_path, mode='val', batch=None): 24 | """Build RTDETR Dataset. 25 | 26 | Args: 27 | img_path (str): Path to the folder containing images. 28 | mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. 29 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 30 | """ 31 | return RTDETRDataset( 32 | img_path=img_path, 33 | imgsz=self.args.imgsz, 34 | batch_size=batch, 35 | augment=mode == 'train', # augment only in 'train' mode 36 | hyp=self.args, 37 | rect=False, # no rect 38 | cache=self.args.cache or None, 39 | prefix=colorstr(f'{mode}: '), 40 | data=self.data) 41 | 42 | def get_validator(self): 43 | """Returns a DetectionValidator for RTDETR model validation.""" 44 | self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' 45 | return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 46 | 47 | def preprocess_batch(self, batch): 48 | """Preprocesses a batch of images by scaling and converting to float. NOTE: gt_bbox/gt_class are collected per image below but are not currently attached to the batch.""" 49 | batch = super().preprocess_batch(batch) 50 | bs = len(batch['img']) 51 | batch_idx = batch['batch_idx'] 52 | gt_bbox, gt_class = [], [] 53 | for i in range(bs): 54 | gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) 55 | gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) 56 | return batch 57 | 58 | 59 | def train(cfg=DEFAULT_CFG, use_python=False): 60 | """Train and optimize RTDETR model given training data and device.""" 61 | model = 'rtdetr-l.yaml' 62 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist") 63 | device = cfg.device if cfg.device is not None else '' 64 | 65 | # NOTE: F.grid_sample, which is used in RT-DETR, does not support deterministic=True 66 | # NOTE: amp training causes nan outputs and ends with an error while doing bipartite graph matching 67 | args = dict(model=model, 68 | data=data, 69 | device=device, 70 | imgsz=640, 71 | exist_ok=True, 72 | batch=4, 73 | deterministic=False, 74 | amp=False) 75 | trainer = RTDETRTrainer(overrides=args) 76 | trainer.train() 77 | 78 | 79 | if __name__ == '__main__': 80 | train() 81 | -------------------------------------------------------------------------------- /ultralytics/models/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import SAM 4 | from .predict import Predictor 5 | 6 | # from .build import build_sam 7 | 8 | __all__ = 'SAM', 'Predictor' # tuple or list 9 |
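The two exports above are the whole public surface of the SAM package; a minimal usage sketch for prompt-based inference follows (the image path and prompt coordinates are illustrative assumptions, not values from the repository):

```python
# Illustrative sketch: prompt-based SAM inference through the exported interface.
from ultralytics import SAM

model = SAM('sam_b.pt')  # resolved to build_sam_vit_b via sam_model_map (see build.py below)
# Segment the region selected by a box prompt (xyxy pixel coordinates) ...
results = model('ultralytics/assets/bus.jpg', bboxes=[100, 100, 500, 600])
# ... or by a point prompt; label 1 marks the point as foreground.
results = model('ultralytics/assets/bus.jpg', points=[300, 400], labels=[1])
```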
-------------------------------------------------------------------------------- /ultralytics/models/sam/build.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from functools import partial 10 | 11 | import torch 12 | 13 | from ultralytics.utils.downloads import attempt_download_asset 14 | 15 | from .modules.decoders import MaskDecoder 16 | from .modules.encoders import ImageEncoderViT, PromptEncoder 17 | from .modules.sam import Sam 18 | from .modules.tiny_encoder import TinyViT 19 | from .modules.transformer import TwoWayTransformer 20 | 21 | 22 | def build_sam_vit_h(checkpoint=None): 23 | """Build and return a Segment Anything Model (SAM) h-size model.""" 24 | return _build_sam( 25 | encoder_embed_dim=1280, 26 | encoder_depth=32, 27 | encoder_num_heads=16, 28 | encoder_global_attn_indexes=[7, 15, 23, 31], 29 | checkpoint=checkpoint, 30 | ) 31 | 32 | 33 | def build_sam_vit_l(checkpoint=None): 34 | """Build and return a Segment Anything Model (SAM) l-size model.""" 35 | return _build_sam( 36 | encoder_embed_dim=1024, 37 | encoder_depth=24, 38 | encoder_num_heads=16, 39 | encoder_global_attn_indexes=[5, 11, 17, 23], 40 | checkpoint=checkpoint, 41 | ) 42 | 43 | 44 | def build_sam_vit_b(checkpoint=None): 45 | """Build and return a Segment Anything Model (SAM) b-size model.""" 46 | return _build_sam( 47 | encoder_embed_dim=768, 48 | encoder_depth=12, 49 | encoder_num_heads=12, 50 | encoder_global_attn_indexes=[2, 5, 8, 11], 51 | checkpoint=checkpoint, 52 | ) 53 | 54 | 55 | def build_mobile_sam(checkpoint=None): 56 | """Build and return Mobile Segment Anything Model (Mobile-SAM).""" 57 | return _build_sam( 58 | encoder_embed_dim=[64, 128, 160, 320], 59 | encoder_depth=[2, 2, 6, 2], 60 | encoder_num_heads=[2, 4, 5, 10], 61 | encoder_global_attn_indexes=None, 62 | mobile_sam=True, 63 | checkpoint=checkpoint, 64 | ) 65 | 66 | 67 | def _build_sam(encoder_embed_dim, 68 | encoder_depth, 69 | encoder_num_heads, 70 | encoder_global_attn_indexes, 71 | checkpoint=None, 72 | mobile_sam=False): 73 | """Builds the selected SAM model architecture.""" 74 | prompt_embed_dim = 256 75 | image_size = 1024 76 | vit_patch_size = 16 77 | image_embedding_size = image_size // vit_patch_size 78 | image_encoder = (TinyViT( 79 | img_size=1024, 80 | in_chans=3, 81 | num_classes=1000, 82 | embed_dims=encoder_embed_dim, 83 | depths=encoder_depth, 84 | num_heads=encoder_num_heads, 85 | window_sizes=[7, 7, 14, 7], 86 | mlp_ratio=4.0, 87 | drop_rate=0.0, 88 | drop_path_rate=0.0, 89 | use_checkpoint=False, 90 | mbconv_expand_ratio=4.0, 91 | local_conv_size=3, 92 | layer_lr_decay=0.8, 93 | ) if mobile_sam else ImageEncoderViT( 94 | depth=encoder_depth, 95 | embed_dim=encoder_embed_dim, 96 | img_size=image_size, 97 | mlp_ratio=4, 98 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 99 | num_heads=encoder_num_heads, 100 | patch_size=vit_patch_size, 101 | qkv_bias=True, 102 | use_rel_pos=True, 103 | global_attn_indexes=encoder_global_attn_indexes, 104 | window_size=14, 105 | out_chans=prompt_embed_dim, 106 | )) 107 | sam = Sam( 108 | image_encoder=image_encoder, 109 | prompt_encoder=PromptEncoder( 110 | embed_dim=prompt_embed_dim, 111 | image_embedding_size=(image_embedding_size, image_embedding_size), 112 | 
input_image_size=(image_size, image_size), 113 | mask_in_chans=16, 114 | ), 115 | mask_decoder=MaskDecoder( 116 | num_multimask_outputs=3, 117 | transformer=TwoWayTransformer( 118 | depth=2, 119 | embedding_dim=prompt_embed_dim, 120 | mlp_dim=2048, 121 | num_heads=8, 122 | ), 123 | transformer_dim=prompt_embed_dim, 124 | iou_head_depth=3, 125 | iou_head_hidden_dim=256, 126 | ), 127 | pixel_mean=[123.675, 116.28, 103.53], 128 | pixel_std=[58.395, 57.12, 57.375], 129 | ) 130 | if checkpoint is not None: 131 | checkpoint = attempt_download_asset(checkpoint) 132 | with open(checkpoint, 'rb') as f: 133 | state_dict = torch.load(f) 134 | sam.load_state_dict(state_dict) 135 | sam.eval() 136 | # sam.load_state_dict(torch.load(checkpoint), strict=True) 137 | # sam.eval() 138 | return sam 139 | 140 | 141 | sam_model_map = { 142 | 'sam_h.pt': build_sam_vit_h, 143 | 'sam_l.pt': build_sam_vit_l, 144 | 'sam_b.pt': build_sam_vit_b, 145 | 'mobile_sam.pt': build_mobile_sam, } 146 | 147 | 148 | def build_sam(ckpt='sam_b.pt'): 149 | """Build a SAM model specified by ckpt.""" 150 | model_builder = None 151 | for k in sam_model_map.keys(): 152 | if ckpt.endswith(k): 153 | model_builder = sam_model_map.get(k) 154 | 155 | if not model_builder: 156 | raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}') 157 | 158 | return model_builder(ckpt) 159 | -------------------------------------------------------------------------------- /ultralytics/models/sam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | SAM model interface 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | from ultralytics.engine.model import Model 9 | from ultralytics.utils.torch_utils import model_info 10 | 11 | from .build import build_sam 12 | from .predict import Predictor 13 | 14 | 15 | class SAM(Model): 16 | """ 17 | SAM model interface. 18 | """ 19 | 20 | def __init__(self, model='sam_b.pt') -> None: 21 | if model and Path(model).suffix not in ('.pt', '.pth'): 22 | raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.') 23 | super().__init__(model=model, task='segment') 24 | 25 | def _load(self, weights: str, task=None): 26 | self.model = build_sam(weights) 27 | 28 | def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): 29 | """Predicts and returns segmentation masks for given image or video source.""" 30 | overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) 31 | kwargs.update(overrides) 32 | prompts = dict(bboxes=bboxes, points=points, labels=labels) 33 | return super().predict(source, stream, prompts=prompts, **kwargs) 34 | 35 | def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): 36 | """Calls the 'predict' function with given arguments to perform object detection.""" 37 | return self.predict(source, stream, bboxes, points, labels, **kwargs) 38 | 39 | def info(self, detailed=False, verbose=True): 40 | """ 41 | Logs model info. 42 | 43 | Args: 44 | detailed (bool): Show detailed information about model. 45 | verbose (bool): Controls verbosity. 
46 | """ 47 | return model_info(self.model, detailed=detailed, verbose=verbose) 48 | 49 | @property 50 | def task_map(self): 51 | return {'segment': {'predictor': Predictor}} 52 | -------------------------------------------------------------------------------- /ultralytics/models/sam/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /ultralytics/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.models.yolo import classify, detect, pose, segment 4 | 5 | from .model import YOLO 6 | 7 | __all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO' 8 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/classify/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict 4 | from ultralytics.models.yolo.classify.train import ClassificationTrainer, train 5 | from ultralytics.models.yolo.classify.val import ClassificationValidator, val 6 | 7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/classify/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.engine.predictor import BasePredictor 6 | from ultralytics.engine.results import Results 7 | from ultralytics.utils import DEFAULT_CFG, ROOT 8 | 9 | 10 | class ClassificationPredictor(BasePredictor): 11 | 12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 13 | super().__init__(cfg, overrides, _callbacks) 14 | self.args.task = 'classify' 15 | 16 | def preprocess(self, img): 17 | """Converts input image to model-compatible data type.""" 18 | if not isinstance(img, torch.Tensor): 19 | img = torch.stack([self.transforms(im) for im in img], dim=0) 20 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) 21 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 22 | 23 | def postprocess(self, preds, img, orig_imgs): 24 | """Post-processes predictions to return Results objects.""" 25 | results = [] 26 | for i, pred in enumerate(preds): 27 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 28 | path = self.batch[0] 29 | img_path = path[i] if isinstance(path, list) else path 30 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) 31 | 32 | return results 33 | 34 | 35 | def predict(cfg=DEFAULT_CFG, use_python=False): 36 | """Run YOLO model predictions on input images/videos.""" 37 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 38 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 39 | else 
'https://ultralytics.com/images/bus.jpg' 40 | 41 | args = dict(model=model, source=source) 42 | if use_python: 43 | from ultralytics import YOLO 44 | YOLO(model)(**args) 45 | else: 46 | predictor = ClassificationPredictor(overrides=args) 47 | predictor.predict_cli() 48 | 49 | 50 | if __name__ == '__main__': 51 | predict() 52 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/classify/val.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.data import ClassificationDataset, build_dataloader 6 | from ultralytics.engine.validator import BaseValidator 7 | from ultralytics.utils import DEFAULT_CFG, LOGGER 8 | from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix 9 | from ultralytics.utils.plotting import plot_images 10 | 11 | 12 | class ClassificationValidator(BaseValidator): 13 | 14 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): 15 | """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar.""" 16 | super().__init__(dataloader, save_dir, pbar, args, _callbacks) 17 | self.targets = None 18 | self.pred = None 19 | self.args.task = 'classify' 20 | self.metrics = ClassifyMetrics() 21 | 22 | def get_desc(self): 23 | """Returns a formatted string summarizing classification metrics.""" 24 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') 25 | 26 | def init_metrics(self, model): 27 | """Initialize confusion matrix, class names, and top-1 and top-5 accuracy.""" 28 | self.names = model.names 29 | self.nc = len(model.names) 30 | self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify') 31 | self.pred = [] 32 | self.targets = [] 33 | 34 | def preprocess(self, batch): 35 | """Preprocesses input batch and returns it.""" 36 | batch['img'] = batch['img'].to(self.device, non_blocking=True) 37 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() 38 | batch['cls'] = batch['cls'].to(self.device) 39 | return batch 40 | 41 | def update_metrics(self, preds, batch): 42 | """Updates running metrics with model predictions and batch targets.""" 43 | n5 = min(len(self.model.names), 5) 44 | self.pred.append(preds.argsort(1, descending=True)[:, :n5]) 45 | self.targets.append(batch['cls']) 46 | 47 | def finalize_metrics(self, *args, **kwargs): 48 | """Finalizes metrics of the model such as confusion_matrix and speed.""" 49 | self.confusion_matrix.process_cls_preds(self.pred, self.targets) 50 | if self.args.plots: 51 | for normalize in True, False: 52 | self.confusion_matrix.plot(save_dir=self.save_dir, 53 | names=self.names.values(), 54 | normalize=normalize, 55 | on_plot=self.on_plot) 56 | self.metrics.speed = self.speed 57 | self.metrics.confusion_matrix = self.confusion_matrix 58 | 59 | def get_stats(self): 60 | """Returns a dictionary of metrics obtained by processing targets and predictions.""" 61 | self.metrics.process(self.targets, self.pred) 62 | return self.metrics.results_dict 63 | 64 | def build_dataset(self, img_path): 65 | return ClassificationDataset(root=img_path, args=self.args, augment=False) 66 | 67 | def get_dataloader(self, dataset_path, batch_size): 68 | """Builds and returns a data loader for classification tasks with given parameters.""" 69 | dataset = self.build_dataset(dataset_path) 70 | return build_dataloader(dataset, batch_size, self.args.workers, rank=-1) 71 | 
72 | def print_results(self): 73 | """Prints evaluation metrics for the YOLO classification model.""" 74 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format 75 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) 76 | 77 | def plot_val_samples(self, batch, ni): 78 | """Plot validation image samples.""" 79 | plot_images( 80 | images=batch['img'], 81 | batch_idx=torch.arange(len(batch['img'])), 82 | cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models 83 | fname=self.save_dir / f'val_batch{ni}_labels.jpg', 84 | names=self.names, 85 | on_plot=self.on_plot) 86 | 87 | def plot_predictions(self, batch, preds, ni): 88 | """Plots predicted class labels on input images and saves the result.""" 89 | plot_images(batch['img'], 90 | batch_idx=torch.arange(len(batch['img'])), 91 | cls=torch.argmax(preds, dim=1), 92 | fname=self.save_dir / f'val_batch{ni}_pred.jpg', 93 | names=self.names, 94 | on_plot=self.on_plot) # pred 95 | 96 | 97 | def val(cfg=DEFAULT_CFG, use_python=False): 98 | """Validate YOLO model using custom data.""" 99 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 100 | data = cfg.data or 'mnist160' 101 | 102 | args = dict(model=model, data=data) 103 | if use_python: 104 | from ultralytics import YOLO 105 | YOLO(model).val(**args) 106 | else: 107 | validator = ClassificationValidator(args=args) 108 | validator(model=args['model']) 109 | 110 | 111 | if __name__ == '__main__': 112 | val() 113 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/detect/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import DetectionPredictor, predict 4 | from .train import DetectionTrainer, train 5 | from .val import DetectionValidator, val 6 | 7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/detect/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.engine.predictor import BasePredictor 6 | from ultralytics.engine.results import Results 7 | from ultralytics.utils import DEFAULT_CFG, ROOT, ops 8 | 9 | 10 | class DetectionPredictor(BasePredictor): 11 | 12 | def postprocess(self, preds, img, orig_imgs): 13 | """Post-processes predictions and returns a list of Results objects.""" 14 | preds = ops.non_max_suppression(preds, 15 | self.args.conf, 16 | self.args.iou, 17 | agnostic=self.args.agnostic_nms, 18 | max_det=self.args.max_det, 19 | classes=self.args.classes) 20 | 21 | results = [] 22 | for i, pred in enumerate(preds): 23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 24 | if not isinstance(orig_imgs, torch.Tensor): 25 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 26 | path = self.batch[0] 27 | img_path = path[i] if isinstance(path, list) else path 28 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 29 | return results 30 | 31 | 32 | def predict(cfg=DEFAULT_CFG, use_python=False): 33 | """Runs YOLO model inference on input image(s).""" 34 | model = cfg.model or 'yolov8n.pt' 35 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 36 | else
'https://ultralytics.com/images/bus.jpg' 37 | 38 | args = dict(model=model, source=source) 39 | if use_python: 40 | from ultralytics import YOLO 41 | YOLO(model)(**args) 42 | else: 43 | predictor = DetectionPredictor(overrides=args) 44 | predictor.predict_cli() 45 | 46 | 47 | if __name__ == '__main__': 48 | predict() 49 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/detect/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | import numpy as np 6 | 7 | from ultralytics.data import build_dataloader, build_yolo_dataset 8 | from ultralytics.engine.trainer import BaseTrainer 9 | from ultralytics.models import yolo 10 | from ultralytics.nn.tasks import DetectionModel 11 | from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK 12 | from ultralytics.utils.plotting import plot_images, plot_labels, plot_results 13 | from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first 14 | 15 | 16 | class DetectionTrainer(BaseTrainer): 17 | 18 | def build_dataset(self, img_path, mode='train', batch=None): 19 | """ 20 | Build YOLO Dataset. 21 | 22 | Args: 23 | img_path (str): Path to the folder containing images. 24 | mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. 25 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 26 | """ 27 | gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) 28 | return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs) 29 | 30 | def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): 31 | """Construct and return dataloader.""" 32 | assert mode in ['train', 'val'] 33 | with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP 34 | dataset = self.build_dataset(dataset_path, mode, batch_size) 35 | shuffle = mode == 'train' 36 | if getattr(dataset, 'rect', False) and shuffle: 37 | LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") 38 | shuffle = False 39 | workers = self.args.workers if mode == 'train' else self.args.workers * 2 40 | return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader 41 | 42 | def preprocess_batch(self, batch): 43 | """Preprocesses a batch of images by scaling and converting to float.""" 44 | batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255 45 | return batch 46 | 47 | def set_model_attributes(self): 48 | """Set model attributes (nc, names, args) from the dataset.""" # nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps) 49 | # self.args.box *= 3 / nl # scale to layers 50 | # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers 51 | # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers 52 | self.model.nc = self.data['nc'] # attach number of classes to model 53 | self.model.names = self.data['names'] # attach class names to model 54 | self.model.args = self.args # attach hyperparameters to model 55 | # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc 56 | 57 | def get_model(self, cfg=None, weights=None, verbose=True): 58 | """Return a YOLO detection model.""" 59 | model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) 60 | if weights: 61 |
model.load(weights) 62 | return model 63 | 64 | def get_validator(self): 65 | """Returns a DetectionValidator for YOLO model validation.""" 66 | self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' 67 | return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 68 | 69 | def label_loss_items(self, loss_items=None, prefix='train'): 70 | """ 71 | Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for 72 | segmentation & detection 73 | """ 74 | keys = [f'{prefix}/{x}' for x in self.loss_names] 75 | if loss_items is not None: 76 | loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats 77 | return dict(zip(keys, loss_items)) 78 | else: 79 | return keys 80 | 81 | def progress_string(self): 82 | """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" 83 | return ('\n' + '%11s' * 84 | (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') 85 | 86 | def plot_training_samples(self, batch, ni): 87 | """Plots training samples with their annotations.""" 88 | plot_images(images=batch['img'], 89 | batch_idx=batch['batch_idx'], 90 | cls=batch['cls'].squeeze(-1), 91 | bboxes=batch['bboxes'], 92 | paths=batch['im_file'], 93 | fname=self.save_dir / f'train_batch{ni}.jpg', 94 | on_plot=self.on_plot) 95 | 96 | def plot_metrics(self): 97 | """Plots metrics from a CSV file.""" 98 | plot_results(file=self.csv, on_plot=self.on_plot) # save results.png 99 | 100 | def plot_training_labels(self): 101 | """Create a labeled training plot of the YOLO model.""" 102 | boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0) 103 | cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0) 104 | plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) 105 | 106 | 107 | def train(cfg=DEFAULT_CFG, use_python=False): 108 | """Train and optimize YOLO model given training data and device.""" 109 | model = cfg.model or 'yolov8n.pt' 110 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist") 111 | device = cfg.device if cfg.device is not None else '' 112 | 113 | args = dict(model=model, data=data, device=device) 114 | if use_python: 115 | from ultralytics import YOLO 116 | YOLO(model).train(**args) 117 | else: 118 | trainer = DetectionTrainer(overrides=args) 119 | trainer.train() 120 | 121 | 122 | if __name__ == '__main__': 123 | train() 124 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.engine.model import Model 4 | from ultralytics.models import yolo # noqa 5 | from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel 6 | 7 | 8 | class YOLO(Model): 9 | """ 10 | YOLO (You Only Look Once) object detection model. 
11 | """ 12 | 13 | @property 14 | def task_map(self): 15 | """Map head to model, trainer, validator, and predictor classes""" 16 | return { 17 | 'classify': { 18 | 'model': ClassificationModel, 19 | 'trainer': yolo.classify.ClassificationTrainer, 20 | 'validator': yolo.classify.ClassificationValidator, 21 | 'predictor': yolo.classify.ClassificationPredictor, }, 22 | 'detect': { 23 | 'model': DetectionModel, 24 | 'trainer': yolo.detect.DetectionTrainer, 25 | 'validator': yolo.detect.DetectionValidator, 26 | 'predictor': yolo.detect.DetectionPredictor, }, 27 | 'segment': { 28 | 'model': SegmentationModel, 29 | 'trainer': yolo.segment.SegmentationTrainer, 30 | 'validator': yolo.segment.SegmentationValidator, 31 | 'predictor': yolo.segment.SegmentationPredictor, }, 32 | 'pose': { 33 | 'model': PoseModel, 34 | 'trainer': yolo.pose.PoseTrainer, 35 | 'validator': yolo.pose.PoseValidator, 36 | 'predictor': yolo.pose.PosePredictor, }, } 37 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/pose/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import PosePredictor, predict 4 | from .train import PoseTrainer, train 5 | from .val import PoseValidator, val 6 | 7 | __all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict' 8 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/pose/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.engine.results import Results 4 | from ultralytics.models.yolo.detect.predict import DetectionPredictor 5 | from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, ops 6 | 7 | 8 | class PosePredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'pose' 13 | if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': 14 | LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. 
" 15 | 'See https://github.com/ultralytics/ultralytics/issues/4031.') 16 | 17 | def postprocess(self, preds, img, orig_imgs): 18 | """Return detection results for a given input image or list of images.""" 19 | preds = ops.non_max_suppression(preds, 20 | self.args.conf, 21 | self.args.iou, 22 | agnostic=self.args.agnostic_nms, 23 | max_det=self.args.max_det, 24 | classes=self.args.classes, 25 | nc=len(self.model.names)) 26 | 27 | results = [] 28 | for i, pred in enumerate(preds): 29 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 30 | shape = orig_img.shape 31 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() 32 | pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:] 33 | pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape) 34 | path = self.batch[0] 35 | img_path = path[i] if isinstance(path, list) else path 36 | results.append( 37 | Results(orig_img=orig_img, 38 | path=img_path, 39 | names=self.model.names, 40 | boxes=pred[:, :6], 41 | keypoints=pred_kpts)) 42 | return results 43 | 44 | 45 | def predict(cfg=DEFAULT_CFG, use_python=False): 46 | """Runs YOLO to predict objects in an image or video.""" 47 | model = cfg.model or 'yolov8n-pose.pt' 48 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 49 | else 'https://ultralytics.com/images/bus.jpg' 50 | 51 | args = dict(model=model, source=source) 52 | if use_python: 53 | from ultralytics import YOLO 54 | YOLO(model)(**args) 55 | else: 56 | predictor = PosePredictor(overrides=args) 57 | predictor.predict_cli() 58 | 59 | 60 | if __name__ == '__main__': 61 | predict() 62 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/pose/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | from ultralytics.models import yolo 6 | from ultralytics.nn.tasks import PoseModel 7 | from ultralytics.utils import DEFAULT_CFG, LOGGER 8 | from ultralytics.utils.plotting import plot_images, plot_results 9 | 10 | 11 | class PoseTrainer(yolo.detect.DetectionTrainer): 12 | 13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 14 | """Initialize a PoseTrainer object with specified configurations and overrides.""" 15 | if overrides is None: 16 | overrides = {} 17 | overrides['task'] = 'pose' 18 | super().__init__(cfg, overrides, _callbacks) 19 | 20 | if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': 21 | LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. 
" 22 | 'See https://github.com/ultralytics/ultralytics/issues/4031.') 23 | 24 | def get_model(self, cfg=None, weights=None, verbose=True): 25 | """Get pose estimation model with specified configuration and weights.""" 26 | model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) 27 | if weights: 28 | model.load(weights) 29 | 30 | return model 31 | 32 | def set_model_attributes(self): 33 | """Sets keypoints shape attribute of PoseModel.""" 34 | super().set_model_attributes() 35 | self.model.kpt_shape = self.data['kpt_shape'] 36 | 37 | def get_validator(self): 38 | """Returns an instance of the PoseValidator class for validation.""" 39 | self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' 40 | return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 41 | 42 | def plot_training_samples(self, batch, ni): 43 | """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" 44 | images = batch['img'] 45 | kpts = batch['keypoints'] 46 | cls = batch['cls'].squeeze(-1) 47 | bboxes = batch['bboxes'] 48 | paths = batch['im_file'] 49 | batch_idx = batch['batch_idx'] 50 | plot_images(images, 51 | batch_idx, 52 | cls, 53 | bboxes, 54 | kpts=kpts, 55 | paths=paths, 56 | fname=self.save_dir / f'train_batch{ni}.jpg', 57 | on_plot=self.on_plot) 58 | 59 | def plot_metrics(self): 60 | """Plots training/val metrics.""" 61 | plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png 62 | 63 | 64 | def train(cfg=DEFAULT_CFG, use_python=False): 65 | """Train the YOLO model on the given data and device.""" 66 | model = cfg.model or 'yolov8n-pose.yaml' 67 | data = cfg.data or 'coco8-pose.yaml' 68 | device = cfg.device if cfg.device is not None else '' 69 | 70 | args = dict(model=model, data=data, device=device) 71 | if use_python: 72 | from ultralytics import YOLO 73 | YOLO(model).train(**args) 74 | else: 75 | trainer = PoseTrainer(overrides=args) 76 | trainer.train() 77 | 78 | 79 | if __name__ == '__main__': 80 | train() 81 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/segment/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import SegmentationPredictor, predict 4 | from .train import SegmentationTrainer, train 5 | from .val import SegmentationValidator, val 6 | 7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/segment/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.engine.results import Results 6 | from ultralytics.models.yolo.detect.predict import DetectionPredictor 7 | from ultralytics.utils import DEFAULT_CFG, ROOT, ops 8 | 9 | 10 | class SegmentationPredictor(DetectionPredictor): 11 | 12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 13 | super().__init__(cfg, overrides, _callbacks) 14 | self.args.task = 'segment' 15 | 16 | def postprocess(self, preds, img, orig_imgs): 17 | """TODO: filter by classes.""" 18 | p = ops.non_max_suppression(preds[0], 19 | self.args.conf, 20 | self.args.iou, 21 | agnostic=self.args.agnostic_nms, 22 | 
max_det=self.args.max_det, 23 | nc=len(self.model.names), 24 | classes=self.args.classes) 25 | results = [] 26 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 27 | for i, pred in enumerate(p): 28 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 29 | path = self.batch[0] 30 | img_path = path[i] if isinstance(path, list) else path 31 | if not len(pred): # save empty boxes 32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 33 | continue 34 | if self.args.retina_masks: 35 | if not isinstance(orig_imgs, torch.Tensor): 36 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 37 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 38 | else: 39 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 40 | if not isinstance(orig_imgs, torch.Tensor): 41 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 42 | results.append( 43 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 44 | return results 45 | 46 | 47 | def predict(cfg=DEFAULT_CFG, use_python=False): 48 | """Runs YOLO object detection on an image or video source.""" 49 | model = cfg.model or 'yolov8n-seg.pt' 50 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 51 | else 'https://ultralytics.com/images/bus.jpg' 52 | 53 | args = dict(model=model, source=source) 54 | if use_python: 55 | from ultralytics import YOLO 56 | YOLO(model)(**args) 57 | else: 58 | predictor = SegmentationPredictor(overrides=args) 59 | predictor.predict_cli() 60 | 61 | 62 | if __name__ == '__main__': 63 | predict() 64 | -------------------------------------------------------------------------------- /ultralytics/models/yolo/segment/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | from ultralytics.models import yolo 6 | from ultralytics.nn.tasks import SegmentationModel 7 | from ultralytics.utils import DEFAULT_CFG, RANK 8 | from ultralytics.utils.plotting import plot_images, plot_results 9 | 10 | 11 | class SegmentationTrainer(yolo.detect.DetectionTrainer): 12 | 13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 14 | """Initialize a SegmentationTrainer object with given arguments.""" 15 | if overrides is None: 16 | overrides = {} 17 | overrides['task'] = 'segment' 18 | super().__init__(cfg, overrides, _callbacks) 19 | 20 | def get_model(self, cfg=None, weights=None, verbose=True): 21 | """Return SegmentationModel initialized with specified config and weights.""" 22 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) 23 | if weights: 24 | model.load(weights) 25 | 26 | return model 27 | 28 | def get_validator(self): 29 | """Return an instance of SegmentationValidator for validation of YOLO model.""" 30 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' 31 | return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 32 | 33 | def plot_training_samples(self, batch, ni): 34 | """Creates a plot of training sample images with labels and box coordinates.""" 35 | plot_images(batch['img'], 36 | batch['batch_idx'], 37 | batch['cls'].squeeze(-1), 38 | batch['bboxes'], 
39 | batch['masks'], 40 | paths=batch['im_file'], 41 | fname=self.save_dir / f'train_batch{ni}.jpg', 42 | on_plot=self.on_plot) 43 | 44 | def plot_metrics(self): 45 | """Plots training/val metrics.""" 46 | plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png 47 | 48 | 49 | def train(cfg=DEFAULT_CFG, use_python=False): 50 | """Train a YOLO segmentation model based on passed arguments.""" 51 | model = cfg.model or 'yolov8n-seg.pt' 52 | data = cfg.data or 'coco128-seg.yaml' 53 | device = cfg.device if cfg.device is not None else '' 54 | 55 | args = dict(model=model, data=data, device=device) 56 | if use_python: 57 | from ultralytics import YOLO 58 | YOLO(model).train(**args) 59 | else: 60 | trainer = SegmentationTrainer(overrides=args) 61 | trainer.train() 62 | 63 | 64 | if __name__ == '__main__': 65 | train() 66 | -------------------------------------------------------------------------------- /ultralytics/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, 4 | attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, 5 | yaml_model_load) 6 | 7 | __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task', 8 | 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel', 9 | 'BaseModel') 10 | -------------------------------------------------------------------------------- /ultralytics/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Ultralytics modules. 
Visualize with: 4 | 5 | from ultralytics.nn.modules import * 6 | import torch 7 | import os 8 | 9 | x = torch.ones(1, 128, 40, 40) 10 | m = Conv(128, 128) 11 | f = f'{m._get_name()}.onnx' 12 | torch.onnx.export(m, x, f) 13 | os.system(f'onnxsim {f} {f} && open {f}') 14 | """ 15 | 16 | from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, 17 | HGBlock, HGStem, Proto, RepC3) 18 | from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, 19 | GhostConv, LightConv, RepConv, SpatialAttention) 20 | from .head import Classify, Detect, Pose, RTDETRDecoder, Segment 21 | from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, 22 | MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) 23 | 24 | __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 25 | 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 26 | 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 27 | 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 28 | 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', 29 | 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') 30 | -------------------------------------------------------------------------------- /ultralytics/nn/modules/utils.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Module utils 4 | """ 5 | 6 | import copy 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.init import uniform_ 14 | 15 | __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid' 16 | 17 | 18 | def _get_clones(module, n): 19 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) 20 | 21 | 22 | def bias_init_with_prob(prior_prob=0.01): 23 | """initialize conv/fc bias value according to a given probability value.""" 24 | return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init 25 | 26 | 27 | def linear_init_(module): 28 | bound = 1 / math.sqrt(module.weight.shape[0]) 29 | uniform_(module.weight, -bound, bound) 30 | if hasattr(module, 'bias') and module.bias is not None: 31 | uniform_(module.bias, -bound, bound) 32 | 33 | 34 | def inverse_sigmoid(x, eps=1e-5): 35 | x = x.clamp(min=0, max=1) 36 | x1 = x.clamp(min=eps) 37 | x2 = (1 - x).clamp(min=eps) 38 | return torch.log(x1 / x2) 39 | 40 | 41 | def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor, 42 | sampling_locations: torch.Tensor, 43 | attention_weights: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Multi-scale deformable attention. 
46 | https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py 47 | """ 48 | 49 | bs, _, num_heads, embed_dims = value.shape 50 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape 51 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 52 | sampling_grids = 2 * sampling_locations - 1 53 | sampling_value_list = [] 54 | for level, (H_, W_) in enumerate(value_spatial_shapes): 55 | # bs, H_*W_, num_heads, embed_dims -> 56 | # bs, H_*W_, num_heads*embed_dims -> 57 | # bs, num_heads*embed_dims, H_*W_ -> 58 | # bs*num_heads, embed_dims, H_, W_ 59 | value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)) 60 | # bs, num_queries, num_heads, num_points, 2 -> 61 | # bs, num_heads, num_queries, num_points, 2 -> 62 | # bs*num_heads, num_queries, num_points, 2 63 | sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) 64 | # bs*num_heads, embed_dims, num_queries, num_points 65 | sampling_value_l_ = F.grid_sample(value_l_, 66 | sampling_grid_l_, 67 | mode='bilinear', 68 | padding_mode='zeros', 69 | align_corners=False) 70 | sampling_value_list.append(sampling_value_l_) 71 | # (bs, num_queries, num_heads, num_levels, num_points) -> 72 | # (bs, num_heads, num_queries, num_levels, num_points) -> 73 | # (bs, num_heads, 1, num_queries, num_levels*num_points) 74 | attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, 75 | num_levels * num_points) 76 | output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view( 77 | bs, num_heads * embed_dims, num_queries)) 78 | return output.transpose(1, 2).contiguous() 79 | -------------------------------------------------------------------------------- /ultralytics/trackers/README.md: -------------------------------------------------------------------------------- 1 | # Tracker 2 | 3 | ## Supported Trackers 4 | 5 | - [x] ByteTracker 6 | - [x] BoT-SORT 7 | 8 | ## Usage 9 | 10 | ### Python interface 11 | 12 | You can use the Python interface to track objects using the YOLO model. 13 | 14 | ```python 15 | from ultralytics import YOLO 16 | 17 | model = YOLO("yolov8n.pt") # or a segmentation model, e.g. yolov8n-seg.pt 18 | model.track( 19 | source="video/streams", 20 | stream=True, 21 | tracker="botsort.yaml", # or 'bytetrack.yaml' 22 | show=True, 23 | ) 24 | ``` 25 | 26 | You can get the IDs of the tracked objects using the following code: 27 | 28 | ```python 29 | from ultralytics import YOLO 30 | 31 | model = YOLO("yolov8n.pt") 32 | 33 | for result in model.track(source="video.mp4"): 34 | print( 35 | result.boxes.id.cpu().numpy().astype(int) 36 | ) # this will print the IDs of the tracked objects in the frame 37 | ``` 38 | 39 | If you want to use the tracker with a folder of images, or when you loop over video frames yourself, use the `persist` parameter to tell the model that consecutive frames are related, so the IDs stay fixed for the same objects. Otherwise the IDs change from frame to frame, because each call creates a new tracker object; `persist=True` makes the model reuse the same tracker across calls, as shown below. 
40 | 41 | ```python 42 | import cv2 43 | from ultralytics import YOLO 44 | 45 | cap = cv2.VideoCapture("video.mp4") 46 | model = YOLO("yolov8n.pt") 47 | while True: 48 | ret, frame = cap.read() 49 | if not ret: 50 | break 51 | results = model.track(frame, persist=True) 52 | boxes = results[0].boxes.xyxy.cpu().numpy().astype(int) 53 | ids = results[0].boxes.id.cpu().numpy().astype(int) 54 | for box, id in zip(boxes, ids): 55 | cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2) 56 | cv2.putText( 57 | frame, 58 | f"Id {id}", 59 | (box[0], box[1]), 60 | cv2.FONT_HERSHEY_SIMPLEX, 61 | 1, 62 | (0, 0, 255), 63 | 2, 64 | ) 65 | cv2.imshow("frame", frame) 66 | if cv2.waitKey(1) & 0xFF == ord("q"): 67 | break 68 | ``` 69 | 70 | ## Change Tracker Parameters 71 | 72 | You can change the tracker parameters by editing the tracker config files (`botsort.yaml` or `bytetrack.yaml`) located in the `ultralytics/cfg/trackers` folder. 73 | 74 | ## Command Line Interface (CLI) 75 | 76 | You can also use the command line interface to track objects using the YOLO model. 77 | 78 | ```bash 79 | yolo detect track source=... tracker=... 80 | yolo segment track source=... tracker=... 81 | yolo pose track source=... tracker=... 82 | ``` 83 | 84 | By default, trackers use the configuration files in `ultralytics/cfg/trackers`. 85 | We also support a modified tracker config file: copy one of the files in `ultralytics/cfg/trackers`, 86 | adjust its parameters, and pass it via the `tracker` argument, as sketched below.
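For example, here is a minimal sketch of that workflow. The filename `custom_tracker.yaml` is hypothetical: it stands for your own modified copy of `bytetrack.yaml` or `botsort.yaml`.

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")

# 'custom_tracker.yaml' is a hypothetical modified copy of
# ultralytics/cfg/trackers/bytetrack.yaml or botsort.yaml
model.track(source="video.mp4", tracker="custom_tracker.yaml", show=True)
```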
87 | -------------------------------------------------------------------------------- /ultralytics/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .bot_sort import BOTSORT 4 | from .byte_tracker import BYTETracker 5 | from .track import register_tracker 6 | 7 | __all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import 8 | -------------------------------------------------------------------------------- /ultralytics/trackers/basetrack.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | 7 | 8 | class TrackState: 9 | """Enumeration of possible object tracking states.""" 10 | 11 | New = 0 12 | Tracked = 1 13 | Lost = 2 14 | Removed = 3 15 | 16 | 17 | class BaseTrack: 18 | """Base class for object tracking, handling basic track attributes and operations.""" 19 | 20 | _count = 0 21 | 22 | track_id = 0 23 | is_activated = False 24 | state = TrackState.New 25 | 26 | history = OrderedDict() 27 | features = [] 28 | curr_feature = None 29 | score = 0 30 | start_frame = 0 31 | frame_id = 0 32 | time_since_update = 0 33 | 34 | # Multi-camera 35 | location = (np.inf, np.inf) 36 | 37 | @property 38 | def end_frame(self): 39 | """Return the last frame ID of the track.""" 40 | return self.frame_id 41 | 42 | @staticmethod 43 | def next_id(): 44 | """Increment and return the global track ID counter.""" 45 | BaseTrack._count += 1 46 | return BaseTrack._count 47 | 48 | def activate(self, *args): 49 | """Activate the track with the provided arguments.""" 50 | raise NotImplementedError 51 | 52 | def predict(self): 53 | """Predict the next state of the track.""" 54 | raise NotImplementedError 55 | 56 | def update(self, *args, **kwargs): 57 | """Update the track with new observations.""" 58 | raise NotImplementedError 59 | 60 | def mark_lost(self): 61 | """Mark the track as lost.""" 62 | self.state = TrackState.Lost 63 | 64 | def mark_removed(self): 65 | """Mark the track as removed.""" 66 | self.state = TrackState.Removed 67 | 68 | @staticmethod 69 | def reset_id(): 70 | """Reset the global track ID counter.""" 71 | BaseTrack._count = 0 72 | -------------------------------------------------------------------------------- /ultralytics/trackers/bot_sort.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from collections import deque 4 | 5 | import numpy as np 6 | 7 | from .basetrack import TrackState 8 | from .byte_tracker import BYTETracker, STrack 9 | from .utils import matching 10 | from .utils.gmc import GMC 11 | from .utils.kalman_filter import KalmanFilterXYWH 12 | 13 | 14 | class BOTrack(STrack): 15 | shared_kalman = KalmanFilterXYWH() 16 | 17 | def __init__(self, tlwh, score, cls, feat=None, feat_history=50): 18 | """Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features.""" 19 | super().__init__(tlwh, score, cls) 20 | 21 | self.smooth_feat = None 22 | self.curr_feat = None 23 | if feat is not None: 24 | self.update_features(feat) 25 | self.features = deque([], maxlen=feat_history) 26 | self.alpha = 0.9 27 | 28 | def update_features(self, feat): 29 | """Update features vector and smooth it using exponential moving average.""" 30 | feat /= np.linalg.norm(feat) 31 | self.curr_feat = feat 32 | if 
self.smooth_feat is None: 33 | self.smooth_feat = feat 34 | else: 35 | self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat 36 | self.features.append(feat) 37 | self.smooth_feat /= np.linalg.norm(self.smooth_feat) 38 | 39 | def predict(self): 40 | """Predicts the mean and covariance using Kalman filter.""" 41 | mean_state = self.mean.copy() 42 | if self.state != TrackState.Tracked: 43 | mean_state[6] = 0 44 | mean_state[7] = 0 45 | 46 | self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) 47 | 48 | def re_activate(self, new_track, frame_id, new_id=False): 49 | """Reactivates a track with updated features and optionally assigns a new ID.""" 50 | if new_track.curr_feat is not None: 51 | self.update_features(new_track.curr_feat) 52 | super().re_activate(new_track, frame_id, new_id) 53 | 54 | def update(self, new_track, frame_id): 55 | """Update the YOLOv8 instance with new track and frame ID.""" 56 | if new_track.curr_feat is not None: 57 | self.update_features(new_track.curr_feat) 58 | super().update(new_track, frame_id) 59 | 60 | @property 61 | def tlwh(self): 62 | """Get current position in bounding box format `(top left x, top left y, 63 | width, height)`. 64 | """ 65 | if self.mean is None: 66 | return self._tlwh.copy() 67 | ret = self.mean[:4].copy() 68 | ret[:2] -= ret[2:] / 2 69 | return ret 70 | 71 | @staticmethod 72 | def multi_predict(stracks): 73 | """Predicts the mean and covariance of multiple object tracks using shared Kalman filter.""" 74 | if len(stracks) <= 0: 75 | return 76 | multi_mean = np.asarray([st.mean.copy() for st in stracks]) 77 | multi_covariance = np.asarray([st.covariance for st in stracks]) 78 | for i, st in enumerate(stracks): 79 | if st.state != TrackState.Tracked: 80 | multi_mean[i][6] = 0 81 | multi_mean[i][7] = 0 82 | multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance) 83 | for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): 84 | stracks[i].mean = mean 85 | stracks[i].covariance = cov 86 | 87 | def convert_coords(self, tlwh): 88 | """Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format.""" 89 | return self.tlwh_to_xywh(tlwh) 90 | 91 | @staticmethod 92 | def tlwh_to_xywh(tlwh): 93 | """Convert bounding box to format `(center x, center y, width, 94 | height)`. 
95 | """ 96 | ret = np.asarray(tlwh).copy() 97 | ret[:2] += ret[2:] / 2 98 | return ret 99 | 100 | 101 | class BOTSORT(BYTETracker): 102 | 103 | def __init__(self, args, frame_rate=30): 104 | """Initialize YOLOv8 object with ReID module and GMC algorithm.""" 105 | super().__init__(args, frame_rate) 106 | # ReID module 107 | self.proximity_thresh = args.proximity_thresh 108 | self.appearance_thresh = args.appearance_thresh 109 | 110 | if args.with_reid: 111 | # Haven't supported BoT-SORT(reid) yet 112 | self.encoder = None 113 | # self.gmc = GMC(method=args.cmc_method, verbose=[args.name, args.ablation]) 114 | self.gmc = GMC(method=args.cmc_method) 115 | 116 | def get_kalmanfilter(self): 117 | """Returns an instance of KalmanFilterXYWH for object tracking.""" 118 | return KalmanFilterXYWH() 119 | 120 | def init_track(self, dets, scores, cls, img=None): 121 | """Initialize track with detections, scores, and classes.""" 122 | if len(dets) == 0: 123 | return [] 124 | if self.args.with_reid and self.encoder is not None: 125 | features_keep = self.encoder.inference(img, dets) 126 | return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections 127 | else: 128 | return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections 129 | 130 | def get_dists(self, tracks, detections): 131 | """Get distances between tracks and detections using IoU and (optionally) ReID embeddings.""" 132 | dists = matching.iou_distance(tracks, detections) 133 | dists_mask = (dists > self.proximity_thresh) 134 | 135 | # TODO: mot20 136 | # if not self.args.mot20: 137 | dists = matching.fuse_score(dists, detections) 138 | 139 | if self.args.with_reid and self.encoder is not None: 140 | emb_dists = matching.embedding_distance(tracks, detections) / 2.0 141 | emb_dists[emb_dists > self.appearance_thresh] = 1.0 142 | emb_dists[dists_mask] = 1.0 143 | dists = np.minimum(dists, emb_dists) 144 | return dists 145 | 146 | def multi_predict(self, tracks): 147 | """Predict and track multiple objects with YOLOv8 model.""" 148 | BOTrack.multi_predict(tracks) 149 | -------------------------------------------------------------------------------- /ultralytics/trackers/track.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from functools import partial 4 | 5 | import torch 6 | 7 | from ultralytics.utils import IterableSimpleNamespace, yaml_load 8 | from ultralytics.utils.checks import check_yaml 9 | 10 | from .bot_sort import BOTSORT 11 | from .byte_tracker import BYTETracker 12 | 13 | TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT} 14 | 15 | 16 | def on_predict_start(predictor, persist=False): 17 | """ 18 | Initialize trackers for object tracking during prediction. 19 | 20 | Args: 21 | predictor (object): The predictor object to initialize trackers for. 22 | persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. 23 | 24 | Raises: 25 | AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. 
26 | """ 27 | if hasattr(predictor, 'trackers') and persist: 28 | return 29 | tracker = check_yaml(predictor.args.tracker) 30 | cfg = IterableSimpleNamespace(**yaml_load(tracker)) 31 | assert cfg.tracker_type in ['bytetrack', 'botsort'], \ 32 | f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'" 33 | trackers = [] 34 | for _ in range(predictor.dataset.bs): 35 | tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) 36 | trackers.append(tracker) 37 | predictor.trackers = trackers 38 | 39 | 40 | def on_predict_postprocess_end(predictor): 41 | """Postprocess detected boxes and update with object tracking.""" 42 | bs = predictor.dataset.bs 43 | im0s = predictor.batch[1] 44 | for i in range(bs): 45 | det = predictor.results[i].boxes.cpu().numpy() 46 | if len(det) == 0: 47 | continue 48 | tracks = predictor.trackers[i].update(det, im0s[i]) 49 | if len(tracks) == 0: 50 | continue 51 | idx = tracks[:, -1].astype(int) 52 | predictor.results[i] = predictor.results[i][idx] 53 | predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1])) 54 | 55 | 56 | def register_tracker(model, persist): 57 | """ 58 | Register tracking callbacks to the model for object tracking during prediction. 59 | 60 | Args: 61 | model (object): The model object to register tracking callbacks for. 62 | persist (bool): Whether to persist the trackers if they already exist. 63 | 64 | """ 65 | model.add_callback('on_predict_start', partial(on_predict_start, persist=persist)) 66 | model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end) 67 | -------------------------------------------------------------------------------- /ultralytics/trackers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /ultralytics/trackers/utils/matching.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import numpy as np 4 | import scipy 5 | from scipy.spatial.distance import cdist 6 | 7 | from ultralytics.utils.metrics import bbox_ioa 8 | 9 | try: 10 | import lap # for linear_assignment 11 | 12 | assert lap.__version__ # verify package is not directory 13 | except (ImportError, AssertionError, AttributeError): 14 | from ultralytics.utils.checks import check_requirements 15 | 16 | check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx 17 | import lap 18 | 19 | 20 | def linear_assignment(cost_matrix, thresh, use_lap=True): 21 | """ 22 | Perform linear assignment using scipy or lap.lapjv. 23 | 24 | Args: 25 | cost_matrix (np.ndarray): The matrix containing cost values for assignments. 26 | thresh (float): Threshold for considering an assignment valid. 27 | use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True. 28 | 29 | Returns: 30 | (tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'. 
31 | """ 32 | 33 | if cost_matrix.size == 0: 34 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) 35 | 36 | if use_lap: 37 | _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) 38 | matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0] 39 | unmatched_a = np.where(x < 0)[0] 40 | unmatched_b = np.where(y < 0)[0] 41 | else: 42 | # Scipy linear sum assignment is NOT working correctly, DO NOT USE 43 | y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x 44 | matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh]) 45 | unmatched = np.ones(cost_matrix.shape) 46 | for i, xi in matches: 47 | unmatched[i, xi] = 0.0 48 | unmatched_a = np.where(unmatched.all(1))[0] 49 | unmatched_b = np.where(unmatched.all(0))[0] 50 | 51 | return matches, unmatched_a, unmatched_b 52 | 53 | 54 | def iou_distance(atracks, btracks): 55 | """ 56 | Compute cost based on Intersection over Union (IoU) between tracks. 57 | 58 | Args: 59 | atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes. 60 | btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes. 61 | 62 | Returns: 63 | (np.ndarray): Cost matrix computed based on IoU. 64 | """ 65 | 66 | if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \ 67 | or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): 68 | atlbrs = atracks 69 | btlbrs = btracks 70 | else: 71 | atlbrs = [track.tlbr for track in atracks] 72 | btlbrs = [track.tlbr for track in btracks] 73 | 74 | ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) 75 | if len(atlbrs) and len(btlbrs): 76 | ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32), 77 | np.ascontiguousarray(btlbrs, dtype=np.float32), 78 | iou=True) 79 | return 1 - ious # cost matrix 80 | 81 | 82 | def embedding_distance(tracks, detections, metric='cosine'): 83 | """ 84 | Compute distance between tracks and detections based on embeddings. 85 | 86 | Args: 87 | tracks (list[STrack]): List of tracks. 88 | detections (list[BaseTrack]): List of detections. 89 | metric (str, optional): Metric for distance computation. Defaults to 'cosine'. 90 | 91 | Returns: 92 | (np.ndarray): Cost matrix computed based on embeddings. 93 | """ 94 | 95 | cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) 96 | if cost_matrix.size == 0: 97 | return cost_matrix 98 | det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32) 99 | # for i, track in enumerate(tracks): 100 | # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) 101 | track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32) 102 | cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features 103 | return cost_matrix 104 | 105 | 106 | def fuse_score(cost_matrix, detections): 107 | """ 108 | Fuses cost matrix with detection scores to produce a single similarity matrix. 109 | 110 | Args: 111 | cost_matrix (np.ndarray): The matrix containing cost values for assignments. 112 | detections (list[BaseTrack]): List of detections with scores. 113 | 114 | Returns: 115 | (np.ndarray): Fused similarity matrix. 
116 | """ 117 | 118 | if cost_matrix.size == 0: 119 | return cost_matrix 120 | iou_sim = 1 - cost_matrix 121 | det_scores = np.array([det.score for det in detections]) 122 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) 123 | fuse_sim = iou_sim * det_scores 124 | return 1 - fuse_sim # fuse_cost 125 | -------------------------------------------------------------------------------- /ultralytics/utils/autobatch.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch. 4 | """ 5 | 6 | from copy import deepcopy 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr 12 | from ultralytics.utils.torch_utils import profile 13 | 14 | 15 | def check_train_batch_size(model, imgsz=640, amp=True): 16 | """ 17 | Check YOLO training batch size using the autobatch() function. 18 | 19 | Args: 20 | model (torch.nn.Module): YOLO model to check batch size for. 21 | imgsz (int): Image size used for training. 22 | amp (bool): If True, use automatic mixed precision (AMP) for training. 23 | 24 | Returns: 25 | (int): Optimal batch size computed using the autobatch() function. 26 | """ 27 | 28 | with torch.cuda.amp.autocast(amp): 29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size 30 | 31 | 32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=DEFAULT_CFG.batch): 33 | """ 34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. 35 | 36 | Args: 37 | model (torch.nn.module): YOLO model to compute batch size for. 38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640. 39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67. 40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16. 41 | 42 | Returns: 43 | (int): The optimal batch size. 
44 | """ 45 | 46 | # Check device 47 | prefix = colorstr('AutoBatch: ') 48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}') 49 | device = next(model.parameters()).device # get model device 50 | if device.type == 'cpu': 51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') 52 | return batch_size 53 | if torch.backends.cudnn.benchmark: 54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}') 55 | return batch_size 56 | 57 | # Inspect CUDA memory 58 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 59 | d = str(device).upper() # 'CUDA:0' 60 | properties = torch.cuda.get_device_properties(device) # device properties 61 | t = properties.total_memory / gb # GiB total 62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved 63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated 64 | f = t - (r + a) # GiB free 65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') 66 | 67 | # Profile batch sizes 68 | batch_sizes = [1, 2, 4, 8, 16] 69 | try: 70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] 71 | results = profile(img, model, n=3, device=device) 72 | 73 | # Fit a solution 74 | y = [x[2] for x in results if x] # memory [2] 75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit 76 | b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) 77 | if None in results: # some sizes failed 78 | i = results.index(None) # first fail index 79 | if b >= batch_sizes[i]: # y intercept above failure point 80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point 81 | if b < 1 or b > 1024: # b outside of safe range 82 | b = batch_size 83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') 84 | 85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted 86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') 87 | return b 88 | except Exception as e: 89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') 90 | return batch_size 91 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .base import add_integration_callbacks, default_callbacks, get_default_callbacks 4 | 5 | __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks' 6 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/base.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Base callbacks 4 | """ 5 | 6 | from collections import defaultdict 7 | from copy import deepcopy 8 | 9 | # Trainer callbacks ---------------------------------------------------------------------------------------------------- 10 | 11 | 12 | def on_pretrain_routine_start(trainer): 13 | """Called before the pretraining routine starts.""" 14 | pass 15 | 16 | 17 | def on_pretrain_routine_end(trainer): 18 | """Called after the pretraining routine ends.""" 19 | pass 20 | 21 | 22 | def on_train_start(trainer): 23 | """Called when the training starts.""" 24 | pass 25 | 26 | 27 | def 
on_train_epoch_start(trainer): 28 | """Called at the start of each training epoch.""" 29 | pass 30 | 31 | 32 | def on_train_batch_start(trainer): 33 | """Called at the start of each training batch.""" 34 | pass 35 | 36 | 37 | def optimizer_step(trainer): 38 | """Called when the optimizer takes a step.""" 39 | pass 40 | 41 | 42 | def on_before_zero_grad(trainer): 43 | """Called before the gradients are set to zero.""" 44 | pass 45 | 46 | 47 | def on_train_batch_end(trainer): 48 | """Called at the end of each training batch.""" 49 | pass 50 | 51 | 52 | def on_train_epoch_end(trainer): 53 | """Called at the end of each training epoch.""" 54 | pass 55 | 56 | 57 | def on_fit_epoch_end(trainer): 58 | """Called at the end of each fit epoch (train + val).""" 59 | pass 60 | 61 | 62 | def on_model_save(trainer): 63 | """Called when the model is saved.""" 64 | pass 65 | 66 | 67 | def on_train_end(trainer): 68 | """Called when the training ends.""" 69 | pass 70 | 71 | 72 | def on_params_update(trainer): 73 | """Called when the model parameters are updated.""" 74 | pass 75 | 76 | 77 | def teardown(trainer): 78 | """Called during the teardown of the training process.""" 79 | pass 80 | 81 | 82 | # Validator callbacks -------------------------------------------------------------------------------------------------- 83 | 84 | 85 | def on_val_start(validator): 86 | """Called when the validation starts.""" 87 | pass 88 | 89 | 90 | def on_val_batch_start(validator): 91 | """Called at the start of each validation batch.""" 92 | pass 93 | 94 | 95 | def on_val_batch_end(validator): 96 | """Called at the end of each validation batch.""" 97 | pass 98 | 99 | 100 | def on_val_end(validator): 101 | """Called when the validation ends.""" 102 | pass 103 | 104 | 105 | # Predictor callbacks -------------------------------------------------------------------------------------------------- 106 | 107 | 108 | def on_predict_start(predictor): 109 | """Called when the prediction starts.""" 110 | pass 111 | 112 | 113 | def on_predict_batch_start(predictor): 114 | """Called at the start of each prediction batch.""" 115 | pass 116 | 117 | 118 | def on_predict_batch_end(predictor): 119 | """Called at the end of each prediction batch.""" 120 | pass 121 | 122 | 123 | def on_predict_postprocess_end(predictor): 124 | """Called after the post-processing of the prediction ends.""" 125 | pass 126 | 127 | 128 | def on_predict_end(predictor): 129 | """Called when the prediction ends.""" 130 | pass 131 | 132 | 133 | # Exporter callbacks --------------------------------------------------------------------------------------------------- 134 | 135 | 136 | def on_export_start(exporter): 137 | """Called when the model export starts.""" 138 | pass 139 | 140 | 141 | def on_export_end(exporter): 142 | """Called when the model export ends.""" 143 | pass 144 | 145 | 146 | default_callbacks = { 147 | # Run in trainer 148 | 'on_pretrain_routine_start': [on_pretrain_routine_start], 149 | 'on_pretrain_routine_end': [on_pretrain_routine_end], 150 | 'on_train_start': [on_train_start], 151 | 'on_train_epoch_start': [on_train_epoch_start], 152 | 'on_train_batch_start': [on_train_batch_start], 153 | 'optimizer_step': [optimizer_step], 154 | 'on_before_zero_grad': [on_before_zero_grad], 155 | 'on_train_batch_end': [on_train_batch_end], 156 | 'on_train_epoch_end': [on_train_epoch_end], 157 | 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val 158 | 'on_model_save': [on_model_save], 159 | 'on_train_end': [on_train_end], 160 | 'on_params_update': 
[on_params_update], 161 | 'teardown': [teardown], 162 | 163 | # Run in validator 164 | 'on_val_start': [on_val_start], 165 | 'on_val_batch_start': [on_val_batch_start], 166 | 'on_val_batch_end': [on_val_batch_end], 167 | 'on_val_end': [on_val_end], 168 | 169 | # Run in predictor 170 | 'on_predict_start': [on_predict_start], 171 | 'on_predict_batch_start': [on_predict_batch_start], 172 | 'on_predict_postprocess_end': [on_predict_postprocess_end], 173 | 'on_predict_batch_end': [on_predict_batch_end], 174 | 'on_predict_end': [on_predict_end], 175 | 176 | # Run in exporter 177 | 'on_export_start': [on_export_start], 178 | 'on_export_end': [on_export_end]} 179 | 180 | 181 | def get_default_callbacks(): 182 | """ 183 | Return a copy of the default_callbacks dictionary with lists as default values. 184 | 185 | Returns: 186 | (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values. 187 | """ 188 | return defaultdict(list, deepcopy(default_callbacks)) 189 | 190 | 191 | def add_integration_callbacks(instance): 192 | """ 193 | Add integration callbacks from various sources to the instance's callbacks. 194 | 195 | Args: 196 | instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary 197 | of callback lists. 198 | """ 199 | from .clearml import callbacks as clearml_cb 200 | from .comet import callbacks as comet_cb 201 | from .dvc import callbacks as dvc_cb 202 | from .hub import callbacks as hub_cb 203 | from .mlflow import callbacks as mlflow_cb 204 | from .neptune import callbacks as neptune_cb 205 | from .raytune import callbacks as tune_cb 206 | from .tensorboard import callbacks as tensorboard_cb 207 | from .wb import callbacks as wb_cb 208 | 209 | for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb: 210 | for k, v in x.items(): 211 | if v not in instance.callbacks[k]: # prevent duplicate callbacks addition 212 | instance.callbacks[k].append(v) # callback[name].append(func) 213 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/clearml.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import re 4 | 5 | import matplotlib.image as mpimg 6 | import matplotlib.pyplot as plt 7 | 8 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING 9 | from ultralytics.utils.torch_utils import model_info_for_loggers 10 | 11 | try: 12 | import clearml 13 | from clearml import Task 14 | from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO 15 | from clearml.binding.matplotlib_bind import PatchedMatplotlib 16 | 17 | assert hasattr(clearml, '__version__') # verify package is not directory 18 | assert not TESTS_RUNNING # do not log pytest 19 | assert SETTINGS['clearml'] is True # verify integration is enabled 20 | except (ImportError, AssertionError): 21 | clearml = None 22 | 23 | 24 | def _log_debug_samples(files, title='Debug Samples') -> None: 25 | """ 26 | Log files (images) as debug samples in the ClearML task. 27 | 28 | Args: 29 | files (list): A list of file paths in PosixPath format. 30 | title (str): A title that groups together images with the same values. 
31 | """ 32 | task = Task.current_task() 33 | if task: 34 | for f in files: 35 | if f.exists(): 36 | it = re.search(r'_batch(\d+)', f.name) 37 | iteration = int(it.groups()[0]) if it else 0 38 | task.get_logger().report_image(title=title, 39 | series=f.name.replace(it.group(), ''), 40 | local_path=str(f), 41 | iteration=iteration) 42 | 43 | 44 | def _log_plot(title, plot_path) -> None: 45 | """ 46 | Log an image as a plot in the plot section of ClearML. 47 | 48 | Args: 49 | title (str): The title of the plot. 50 | plot_path (str): The path to the saved image file. 51 | """ 52 | img = mpimg.imread(plot_path) 53 | fig = plt.figure() 54 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks 55 | ax.imshow(img) 56 | 57 | Task.current_task().get_logger().report_matplotlib_figure(title=title, 58 | series='', 59 | figure=fig, 60 | report_interactive=False) 61 | 62 | 63 | def on_pretrain_routine_start(trainer): 64 | """Runs at start of pretraining routine; initializes and connects/ logs task to ClearML.""" 65 | try: 66 | task = Task.current_task() 67 | if task: 68 | # Make sure the automatic pytorch and matplotlib bindings are disabled! 69 | # We are logging these plots and model files manually in the integration 70 | PatchPyTorchModelIO.update_current_task(None) 71 | PatchedMatplotlib.update_current_task(None) 72 | else: 73 | task = Task.init(project_name=trainer.args.project or 'YOLOv8', 74 | task_name=trainer.args.name, 75 | tags=['YOLOv8'], 76 | output_uri=True, 77 | reuse_last_task_id=False, 78 | auto_connect_frameworks={ 79 | 'pytorch': False, 80 | 'matplotlib': False}) 81 | LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, ' 82 | 'please add clearml-init and connect your arguments before initializing YOLO.') 83 | task.connect(vars(trainer.args), name='General') 84 | except Exception as e: 85 | LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. 
{e}') 86 | 87 | 88 | def on_train_epoch_end(trainer): 89 | """Logs debug samples for the first epoch and reports training progress to ClearML.""" 90 | task = Task.current_task() 91 | if task: 92 | # Log debug samples for the first epoch of YOLO training 93 | if trainer.epoch == 1: 94 | _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic') 95 | # Report the current training progress 96 | for k, v in trainer.validator.metrics.results_dict.items(): 97 | task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch) 98 | 99 | 100 | def on_fit_epoch_end(trainer): 101 | """Reports model information to logger at the end of an epoch.""" 102 | task = Task.current_task() 103 | if task: 104 | # You should have access to the validation bboxes under jdict 105 | task.get_logger().report_scalar(title='Epoch Time', 106 | series='Epoch Time', 107 | value=trainer.epoch_time, 108 | iteration=trainer.epoch) 109 | if trainer.epoch == 0: 110 | for k, v in model_info_for_loggers(trainer).items(): 111 | task.get_logger().report_single_value(k, v) 112 | 113 | 114 | def on_val_end(validator): 115 | """Logs validation results including labels and predictions.""" 116 | if Task.current_task(): 117 | # Log val_labels and val_pred 118 | _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation') 119 | 120 | 121 | def on_train_end(trainer): 122 | """Logs final model and its name on training completion.""" 123 | task = Task.current_task() 124 | if task: 125 | # Log final results, CM matrix + PR plots 126 | files = [ 127 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', 128 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] 129 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter 130 | for f in files: 131 | _log_plot(title=f.stem, plot_path=f) 132 | # Report final metrics 133 | for k, v in trainer.validator.metrics.results_dict.items(): 134 | task.get_logger().report_single_value(k, v) 135 | # Log the final model 136 | task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False) 137 | 138 | 139 | callbacks = { 140 | 'on_pretrain_routine_start': on_pretrain_routine_start, 141 | 'on_train_epoch_end': on_train_epoch_end, 142 | 'on_fit_epoch_end': on_fit_epoch_end, 143 | 'on_val_end': on_val_end, 144 | 'on_train_end': on_train_end} if clearml else {} 145 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/dvc.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | from pathlib import Path 6 | 7 | import pkg_resources as pkg 8 | 9 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING 10 | from ultralytics.utils.torch_utils import model_info_for_loggers 11 | 12 | try: 13 | from importlib.metadata import version 14 | 15 | import dvclive 16 | 17 | assert not TESTS_RUNNING # do not log pytest 18 | assert SETTINGS['dvc'] is True # verify integration is enabled 19 | 20 | ver = version('dvclive') 21 | if pkg.parse_version(ver) < pkg.parse_version('2.11.0'): 22 | LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).') 23 | dvclive = None # noqa: F811 24 | except (ImportError, AssertionError, TypeError): 25 | dvclive = None 26 | 27 | # DVCLive logger instance 28 | live = None 29 | _processed_plots = {} 30 | 31 | # `on_fit_epoch_end` is called on final validation (probably needs to be fixed) 32 | # for now this is the way we distinguish 
final evaluation of the best model vs 33 | # last epoch validation 34 | _training_epoch = False 35 | 36 | 37 | def _log_images(path, prefix=''): 38 | if live: 39 | name = path.name 40 | 41 | # Group images by batch to enable sliders in UI 42 | if m := re.search(r'_batch(\d+)', name): 43 | ni = m.group(1) 44 | new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem) 45 | name = (Path(new_stem) / ni).with_suffix(path.suffix) 46 | 47 | live.log_image(os.path.join(prefix, name), path) 48 | 49 | 50 | def _log_plots(plots, prefix=''): 51 | for name, params in plots.items(): 52 | timestamp = params['timestamp'] 53 | if _processed_plots.get(name) != timestamp: 54 | _log_images(name, prefix) 55 | _processed_plots[name] = timestamp 56 | 57 | 58 | def _log_confusion_matrix(validator): 59 | targets = [] 60 | preds = [] 61 | matrix = validator.confusion_matrix.matrix 62 | names = list(validator.names.values()) 63 | if validator.confusion_matrix.task == 'detect': 64 | names += ['background'] 65 | 66 | for ti, pred in enumerate(matrix.T.astype(int)): 67 | for pi, num in enumerate(pred): 68 | targets.extend([names[ti]] * num) 69 | preds.extend([names[pi]] * num) 70 | 71 | live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True) 72 | 73 | 74 | def on_pretrain_routine_start(trainer): 75 | try: 76 | global live 77 | live = dvclive.Live(save_dvc_exp=True, cache_images=True) 78 | LOGGER.info( 79 | f'DVCLive is detected and auto logging is enabled (can be disabled in the {SETTINGS.file} with `dvc: false`).' 80 | ) 81 | except Exception as e: 82 | LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}') 83 | 84 | 85 | def on_pretrain_routine_end(trainer): 86 | _log_plots(trainer.plots, 'train') 87 | 88 | 89 | def on_train_start(trainer): 90 | if live: 91 | live.log_params(trainer.args) 92 | 93 | 94 | def on_train_epoch_start(trainer): 95 | global _training_epoch 96 | _training_epoch = True 97 | 98 | 99 | def on_fit_epoch_end(trainer): 100 | global _training_epoch 101 | if live and _training_epoch: 102 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} 103 | for metric, value in all_metrics.items(): 104 | live.log_metric(metric, value) 105 | 106 | if trainer.epoch == 0: 107 | for metric, value in model_info_for_loggers(trainer).items(): 108 | live.log_metric(metric, value, plot=False) 109 | 110 | _log_plots(trainer.plots, 'train') 111 | _log_plots(trainer.validator.plots, 'val') 112 | 113 | live.next_step() 114 | _training_epoch = False 115 | 116 | 117 | def on_train_end(trainer): 118 | if live: 119 | # At the end log the best metrics. It runs validator on the best model internally. 
120 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} 121 | for metric, value in all_metrics.items(): 122 | live.log_metric(metric, value, plot=False) 123 | 124 | _log_plots(trainer.plots, 'val') 125 | _log_plots(trainer.validator.plots, 'val') 126 | _log_confusion_matrix(trainer.validator) 127 | 128 | if trainer.best.exists(): 129 | live.log_artifact(trainer.best, copy=True, type='model') 130 | 131 | live.end() 132 | 133 | 134 | callbacks = { 135 | 'on_pretrain_routine_start': on_pretrain_routine_start, 136 | 'on_pretrain_routine_end': on_pretrain_routine_end, 137 | 'on_train_start': on_train_start, 138 | 'on_train_epoch_start': on_train_epoch_start, 139 | 'on_fit_epoch_end': on_fit_epoch_end, 140 | 'on_train_end': on_train_end} if dvclive else {} 141 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/hub.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import json 4 | from time import time 5 | 6 | from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events 7 | from ultralytics.utils import LOGGER, SETTINGS 8 | from ultralytics.utils.torch_utils import model_info_for_loggers 9 | 10 | 11 | def on_pretrain_routine_end(trainer): 12 | """Logs info before starting timer for upload rate limit.""" 13 | session = getattr(trainer, 'hub_session', None) 14 | if session: 15 | # Start timer for upload rate limit 16 | LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀') 17 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limits 18 | 19 | 20 | def on_fit_epoch_end(trainer): 21 | """Uploads training progress metrics at the end of each epoch.""" 22 | session = getattr(trainer, 'hub_session', None) 23 | if session: 24 | # Upload metrics after val end 25 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} 26 | if trainer.epoch == 0: 27 | all_plots = {**all_plots, **model_info_for_loggers(trainer)} 28 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots) 29 | if time() - session.timers['metrics'] > session.rate_limits['metrics']: 30 | session.upload_metrics() 31 | session.timers['metrics'] = time() # reset timer 32 | session.metrics_queue = {} # reset queue 33 | 34 | 35 | def on_model_save(trainer): 36 | """Saves checkpoints to Ultralytics HUB with rate limiting.""" 37 | session = getattr(trainer, 'hub_session', None) 38 | if session: 39 | # Upload checkpoints with rate limiting 40 | is_best = trainer.best_fitness == trainer.fitness 41 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: 42 | LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}') 43 | session.upload_model(trainer.epoch, trainer.last, is_best) 44 | session.timers['ckpt'] = time() # reset timer 45 | 46 | 47 | def on_train_end(trainer):  48 | """Upload final model and metrics to Ultralytics HUB at the end of training.""" 49 | session = getattr(trainer, 'hub_session', None) 50 | if session: 51 | # Upload final model and metrics with exponential backoff 52 | LOGGER.info(f'{PREFIX}Syncing final model...') 53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) 54 | session.alive = False # stop heartbeats 55 | LOGGER.info(f'{PREFIX}Done ✅\n' 56 | f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀') 57 | 58 | 
59 | def on_train_start(trainer): 60 | """Run events on train start.""" 61 | events(trainer.args) 62 | 63 | 64 | def on_val_start(validator): 65 | """Runs events on validation start.""" 66 | events(validator.args) 67 | 68 | 69 | def on_predict_start(predictor): 70 | """Run events on predict start.""" 71 | events(predictor.args) 72 | 73 | 74 | def on_export_start(exporter): 75 | """Run events on export start.""" 76 | events(exporter.args) 77 | 78 | 79 | callbacks = { 80 | 'on_pretrain_routine_end': on_pretrain_routine_end, 81 | 'on_fit_epoch_end': on_fit_epoch_end, 82 | 'on_model_save': on_model_save, 83 | 'on_train_end': on_train_end, 84 | 'on_train_start': on_train_start, 85 | 'on_val_start': on_val_start, 86 | 'on_predict_start': on_predict_start, 87 | 'on_export_start': on_export_start} if SETTINGS['hub'] is True else {} # verify enabled 88 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/mlflow.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | from pathlib import Path 6 | 7 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr 8 | 9 | try: 10 | import mlflow 11 | 12 | assert not TESTS_RUNNING # do not log pytest 13 | assert hasattr(mlflow, '__version__') # verify package is not directory 14 | assert SETTINGS['mlflow'] is True # verify integration is enabled 15 | except (ImportError, AssertionError): 16 | mlflow = None 17 | 18 | 19 | def on_pretrain_routine_end(trainer): 20 | """Logs training parameters to MLflow.""" 21 | global mlflow, run, experiment_name 22 | 23 | if os.environ.get('MLFLOW_TRACKING_URI') is None: 24 | mlflow = None 25 | 26 | if mlflow: 27 | mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000" 28 | mlflow.set_tracking_uri(mlflow_location) 29 | 30 | experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8' 31 | run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name 32 | experiment = mlflow.get_experiment_by_name(experiment_name) 33 | if experiment is None: 34 | mlflow.create_experiment(experiment_name) 35 | mlflow.set_experiment(experiment_name) 36 | 37 | prefix = colorstr('MLFlow: ') 38 | try: 39 | run, active_run = mlflow, mlflow.active_run() 40 | if not active_run: 41 | active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name) 42 | LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}') 43 | run.log_params(vars(trainer.model.args)) 44 | except Exception as err: 45 | LOGGER.error(f'{prefix}Failing init - {repr(err)}') 46 | LOGGER.warning(f'{prefix}Continuing without Mlflow') 47 | 48 | 49 | def on_fit_epoch_end(trainer): 50 | """Logs training metrics to Mlflow.""" 51 | if mlflow: 52 | metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()} 53 | run.log_metrics(metrics=metrics_dict, step=trainer.epoch) 54 | 55 | 56 | def on_train_end(trainer): 57 | """Called at end of train loop to log model artifact info.""" 58 | if mlflow: 59 | root_dir = Path(__file__).resolve().parents[3] 60 | run.log_artifact(trainer.last) 61 | run.log_artifact(trainer.best) 62 | run.pyfunc.log_model(artifact_path=experiment_name, 63 | code_path=[str(root_dir)], 64 | artifacts={'model_path': str(trainer.save_dir)}, 65 | python_model=run.pyfunc.PythonModel()) 66 | 67 | 68 | callbacks = { 69 | 'on_pretrain_routine_end': 
on_pretrain_routine_end, 70 | 'on_fit_epoch_end': on_fit_epoch_end, 71 | 'on_train_end': on_train_end} if mlflow else {} 72 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/neptune.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import matplotlib.image as mpimg 4 | import matplotlib.pyplot as plt 5 | 6 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING 7 | from ultralytics.utils.torch_utils import model_info_for_loggers 8 | 9 | try: 10 | import neptune 11 | from neptune.types import File 12 | 13 | assert not TESTS_RUNNING # do not log pytest 14 | assert hasattr(neptune, '__version__') 15 | assert SETTINGS['neptune'] is True # verify integration is enabled 16 | except (ImportError, AssertionError): 17 | neptune = None 18 | 19 | run = None # NeptuneAI experiment logger instance 20 | 21 | 22 | def _log_scalars(scalars, step=0): 23 | """Log scalars to the NeptuneAI experiment logger.""" 24 | if run: 25 | for k, v in scalars.items(): 26 | run[k].append(value=v, step=step) 27 | 28 | 29 | def _log_images(imgs_dict, group=''): 30 | """Log images to the NeptuneAI experiment logger.""" 31 | if run: 32 | for k, v in imgs_dict.items(): 33 | run[f'{group}/{k}'].upload(File(v)) 34 | 35 | 36 | def _log_plot(title, plot_path): 37 | """ 38 | Log an image as a plot in the plot section of NeptuneAI. 39 | 40 | Args: 41 | title (str): Title of the plot. 42 | plot_path (PosixPath | str): Path to the saved image file. 43 | """ 44 | 45 | img = mpimg.imread(plot_path) 46 | fig = plt.figure() 47 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks 48 | ax.imshow(img) 49 | run[f'Plots/{title}'].upload(fig) 50 | 51 | 52 | def on_pretrain_routine_start(trainer): 53 | """Callback function called before the training routine starts.""" 54 | try: 55 | global run 56 | run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8']) 57 | run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()} 58 | except Exception as e: 59 | LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. 
{e}') 60 | 61 | 62 | def on_train_epoch_end(trainer): 63 | """Callback function called at end of each training epoch.""" 64 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) 65 | _log_scalars(trainer.lr, trainer.epoch + 1) 66 | if trainer.epoch == 1: 67 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic') 68 | 69 | 70 | def on_fit_epoch_end(trainer): 71 | """Callback function called at end of each fit (train+val) epoch.""" 72 | if run and trainer.epoch == 0: 73 | run['Configuration/Model'] = model_info_for_loggers(trainer) 74 | _log_scalars(trainer.metrics, trainer.epoch + 1) 75 | 76 | 77 | def on_val_end(validator): 78 | """Callback function called at end of each validation.""" 79 | if run: 80 | # Log val_labels and val_pred 81 | _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation') 82 | 83 | 84 | def on_train_end(trainer): 85 | """Callback function called at end of training.""" 86 | if run: 87 | # Log final results, confusion matrix + PR plots 88 | files = [ 89 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', 90 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] 91 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter 92 | for f in files: 93 | _log_plot(title=f.stem, plot_path=f) 94 | # Log the final model 95 | run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str( 96 | trainer.best))) 97 | 98 | 99 | callbacks = { 100 | 'on_pretrain_routine_start': on_pretrain_routine_start, 101 | 'on_train_epoch_end': on_train_epoch_end, 102 | 'on_fit_epoch_end': on_fit_epoch_end, 103 | 'on_val_end': on_val_end, 104 | 'on_train_end': on_train_end} if neptune else {} 105 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/raytune.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.utils import SETTINGS 4 | 5 | try: 6 | import ray 7 | from ray import tune 8 | from ray.air import session 9 | 10 | assert SETTINGS['raytune'] is True # verify integration is enabled 11 | except (ImportError, AssertionError): 12 | tune = None 13 | 14 | 15 | def on_fit_epoch_end(trainer): 16 | """Sends training metrics to Ray Tune at end of each epoch.""" 17 | if ray.tune.is_session_enabled(): 18 | metrics = trainer.metrics 19 | metrics['epoch'] = trainer.epoch 20 | session.report(metrics) 21 | 22 | 23 | callbacks = { 24 | 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} 25 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/tensorboard.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr 4 | 5 | try: 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | assert not TESTS_RUNNING # do not log pytest 9 | assert SETTINGS['tensorboard'] is True # verify integration is enabled 10 | 11 | # TypeError catches 'Descriptors cannot not be created directly.' protobuf errors on Windows 12 | except (ImportError, AssertionError, TypeError): 13 | SummaryWriter = None 14 | 15 | writer = None # TensorBoard SummaryWriter instance 16 | 17 | 18 | def _log_scalars(scalars, step=0): 19 | """Logs scalar values to TensorBoard.""" 20 | if writer: 21 | for k, v in scalars.items(): 22 | writer.add_scalar(k, v, step) 23 | 24 | 25 | def on_pretrain_routine_start(trainer): 26 | """Initialize TensorBoard logging with SummaryWriter.""" 27 | if SummaryWriter: 28 | try: 29 | global writer 30 | writer = SummaryWriter(str(trainer.save_dir)) 31 | prefix = colorstr('TensorBoard: ') 32 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") 33 | except Exception as e: 34 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}') 35 | 36 | 37 | def on_batch_end(trainer): 38 | """Logs scalar statistics at the end of a training batch.""" 39 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) 40 | 41 | 42 | def on_fit_epoch_end(trainer): 43 | """Logs epoch metrics at end of training epoch.""" 44 | _log_scalars(trainer.metrics, trainer.epoch + 1) 45 | 46 | 47 | callbacks = { 48 | 'on_pretrain_routine_start': on_pretrain_routine_start, 49 | 'on_fit_epoch_end': on_fit_epoch_end, 50 | 'on_batch_end': on_batch_end} 51 | -------------------------------------------------------------------------------- /ultralytics/utils/callbacks/wb.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.utils import SETTINGS, TESTS_RUNNING 4 | from ultralytics.utils.torch_utils import model_info_for_loggers 5 | 6 | try: 7 | import wandb as wb 8 | 9 | assert hasattr(wb, '__version__') 10 | assert not TESTS_RUNNING # do not log pytest 11 | assert SETTINGS['wandb'] is True # verify integration is enabled 12 | except (ImportError, AssertionError): 13 | wb = None 14 | 15 | _processed_plots = {} 16 | 17 | 18 | def _log_plots(plots, step): # log each plot once per timestamp to avoid duplicates 19 | for name, params in plots.items(): 20 | timestamp = params['timestamp'] 21 | if _processed_plots.get(name) != timestamp: 22 | wb.run.log({name.stem: wb.Image(str(name))}, step=step) 23 | _processed_plots[name] = timestamp 24 | 25 | 26 | def on_pretrain_routine_start(trainer): 27 | """Initiate and start a W&B run if the wandb module is present.""" 28 | wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args)) 29 | 30 | 31 | def on_fit_epoch_end(trainer): 32 | """Logs training metrics and model information at the end of an epoch.""" 33 | wb.run.log(trainer.metrics, step=trainer.epoch + 1) 34 | _log_plots(trainer.plots, step=trainer.epoch + 1) 35 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1) 36 | if trainer.epoch == 0: 37 | wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) 38 | 39 | 40 | def on_train_epoch_end(trainer): 41 | """Log metrics and save images at the end of each training epoch.""" 42 | wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) 43 | wb.run.log(trainer.lr, step=trainer.epoch + 1) 44 | if trainer.epoch == 1: 45 | _log_plots(trainer.plots, step=trainer.epoch + 1) 46 | 47 | 48 | def on_train_end(trainer): 49 | """Save the best model as an artifact at end of training.""" 50 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1) 51 | _log_plots(trainer.plots, step=trainer.epoch + 1) 52 | art =
wb.Artifact(type='model', name=f'run_{wb.run.id}_model') 53 | if trainer.best.exists(): 54 | art.add_file(trainer.best) 55 | wb.run.log_artifact(art, aliases=['best']) 56 | 57 | 58 | callbacks = { 59 | 'on_pretrain_routine_start': on_pretrain_routine_start, 60 | 'on_train_epoch_end': on_train_epoch_end, 61 | 'on_fit_epoch_end': on_fit_epoch_end, 62 | 'on_train_end': on_train_end} if wb else {} 63 | -------------------------------------------------------------------------------- /ultralytics/utils/dist.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | import shutil 6 | import socket 7 | import sys 8 | import tempfile 9 | from pathlib import Path 10 | 11 | from . import USER_CONFIG_DIR 12 | from .torch_utils import TORCH_1_9 13 | 14 | 15 | def find_free_network_port() -> int: 16 | """Finds a free port on localhost. 17 | 18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the 19 | `MASTER_PORT` environment variable. 20 | """ 21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 22 | s.bind(('127.0.0.1', 0)) 23 | return s.getsockname()[1] # port 24 | 25 | 26 | def generate_ddp_file(trainer): 27 | """Generates a DDP file and returns its file name.""" 28 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) 29 | 30 | content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__": 31 | from {module} import {name} 32 | from ultralytics.utils import DEFAULT_CFG_DICT 33 | 34 | cfg = DEFAULT_CFG_DICT.copy() 35 | cfg.update(save_dir='') # handle the extra key 'save_dir' 36 | trainer = {name}(cfg=cfg, overrides=overrides) 37 | trainer.train()''' 38 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) 39 | with tempfile.NamedTemporaryFile(prefix='_temp_', 40 | suffix=f'{id(trainer)}.py', 41 | mode='w+', 42 | encoding='utf-8', 43 | dir=USER_CONFIG_DIR / 'DDP', 44 | delete=False) as file: 45 | file.write(content) 46 | return file.name 47 | 48 | 49 | def generate_ddp_command(world_size, trainer): 50 | """Generates and returns command for distributed training.""" 51 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 52 | if not trainer.resume: 53 | shutil.rmtree(trainer.save_dir) # remove the save_dir 54 | file = str(Path(sys.argv[0]).resolve()) 55 | safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and a maximum length of 128 characters 56 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI 57 | file = generate_ddp_file(trainer) 58 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' 59 | port = find_free_network_port() 60 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] 61 | return cmd, file 62 | 63 | 64 | def ddp_cleanup(trainer, file): 65 | """Delete temp file if created.""" 66 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file 67 | os.remove(file) 68 | -------------------------------------------------------------------------------- /ultralytics/utils/errors.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.utils import emojis 4 | 5 | 6 | class HUBModelError(Exception): 7 | 8 | def __init__(self, message='Model not found.
Please check model URL and try again.'): 9 | """Create an exception for when a model is not found.""" 10 | super().__init__(emojis(message)) 11 | -------------------------------------------------------------------------------- /ultralytics/utils/files.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import contextlib 4 | import glob 5 | import os 6 | import shutil 7 | import tempfile 8 | from contextlib import contextmanager 9 | from datetime import datetime 10 | from pathlib import Path 11 | 12 | 13 | class WorkingDirectory(contextlib.ContextDecorator): 14 | """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.""" 15 | 16 | def __init__(self, new_dir): 17 | """Sets the working directory to 'new_dir' upon instantiation.""" 18 | self.dir = new_dir # new dir 19 | self.cwd = Path.cwd().resolve() # current dir 20 | 21 | def __enter__(self): 22 | """Changes the current directory to the specified directory.""" 23 | os.chdir(self.dir) 24 | 25 | def __exit__(self, exc_type, exc_val, exc_tb): # noqa 26 | """Restore the current working directory on context exit.""" 27 | os.chdir(self.cwd) 28 | 29 | 30 | @contextmanager 31 | def spaces_in_path(path): 32 | """ 33 | Context manager to handle paths with spaces in their names. 34 | If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, 35 | executes the context code block, then copies the file/directory back to its original location. 36 | 37 | Args: 38 | path (str | Path): The original path. 39 | 40 | Yields: 41 | (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path. 42 | 43 | Example: 44 | ```python 45 | with spaces_in_path('/path/with spaces') as new_path: 46 | # your code here 47 | ``` 48 | """ 49 | 50 | # If path has spaces, replace them with underscores 51 | if ' ' in str(path): 52 | string = isinstance(path, str) # input type 53 | path = Path(path) 54 | 55 | # Create a temporary directory and construct the new path 56 | with tempfile.TemporaryDirectory() as tmp_dir: 57 | tmp_path = Path(tmp_dir) / path.name.replace(' ', '_') 58 | 59 | # Copy file/directory 60 | if path.is_dir(): 61 | # tmp_path.mkdir(parents=True, exist_ok=True) 62 | shutil.copytree(path, tmp_path) 63 | elif path.is_file(): 64 | tmp_path.parent.mkdir(parents=True, exist_ok=True) 65 | shutil.copy2(path, tmp_path) 66 | 67 | try: 68 | # Yield the temporary path 69 | yield str(tmp_path) if string else tmp_path 70 | 71 | finally: 72 | # Copy file/directory back 73 | if tmp_path.is_dir(): 74 | shutil.copytree(tmp_path, path, dirs_exist_ok=True) 75 | elif tmp_path.is_file(): 76 | shutil.copy2(tmp_path, path) # Copy back the file 77 | 78 | else: 79 | # If there are no spaces, just yield the original path 80 | yield path 81 | 82 | 83 | def increment_path(path, exist_ok=False, sep='', mkdir=False): 84 | """ 85 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 86 | 87 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to 88 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the 89 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a 90 | directory if it does not already exist. 
91 | 92 | Args: 93 | path (str | pathlib.Path): Path to increment. 94 | exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False. 95 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''. 96 | mkdir (bool, optional): Create a directory if it does not exist. Defaults to False. 97 | 98 | Returns: 99 | (pathlib.Path): Incremented path. 100 | """ 101 | path = Path(path) # os-agnostic 102 | if path.exists() and not exist_ok: 103 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '') 104 | 105 | # Increment the path until an unused one is found 106 | for n in range(2, 9999): 107 | p = f'{path}{sep}{n}{suffix}' # increment path 108 | if not os.path.exists(p): 109 | break 110 | path = Path(p) 111 | 112 | if mkdir: 113 | path.mkdir(parents=True, exist_ok=True) # make directory 114 | 115 | return path 116 | 117 | 118 | def file_age(path=__file__): 119 | """Return days since last file update.""" 120 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta 121 | return dt.days # + dt.seconds / 86400 # fractional days 122 | 123 | 124 | def file_date(path=__file__): 125 | """Return human-readable file modification date, i.e. '2021-3-26'.""" 126 | t = datetime.fromtimestamp(Path(path).stat().st_mtime) 127 | return f'{t.year}-{t.month}-{t.day}' 128 | 129 | 130 | def file_size(path): 131 | """Return file/dir size (MiB).""" 132 | if isinstance(path, (str, Path)): 133 | mb = 1 << 20 # bytes to MiB (1024 ** 2) 134 | path = Path(path) 135 | if path.is_file(): 136 | return path.stat().st_size / mb 137 | elif path.is_dir(): 138 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb 139 | return 0.0 140 | 141 | 142 | def get_latest_run(search_dir='.'): 143 | """Return path to most recent 'last.pt' in /runs (i.e.
to --resume from).""" 144 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) 145 | return max(last_list, key=os.path.getctime) if last_list else '' 146 | 147 | 148 | def make_dirs(dir='new_dir/'): 149 | """Create a fresh directory with 'labels' and 'images' subdirectories, deleting any existing one first.""" 150 | dir = Path(dir) 151 | if dir.exists(): 152 | shutil.rmtree(dir) # delete dir 153 | for p in dir, dir / 'labels', dir / 'images': 154 | p.mkdir(parents=True, exist_ok=True) # make dir 155 | return dir 156 | -------------------------------------------------------------------------------- /ultralytics/utils/patches.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Monkey patches to update/extend functionality of existing functions. 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | 12 | # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------ 13 | _imshow = cv2.imshow # copy to avoid recursion errors 14 | 15 | 16 | def imread(filename, flags=cv2.IMREAD_COLOR): 17 | return cv2.imdecode(np.fromfile(filename, np.uint8), flags) # np.fromfile handles paths with non-ASCII characters 18 | 19 | 20 | def imwrite(filename, img): 21 | try: 22 | cv2.imencode(Path(filename).suffix, img)[1].tofile(filename) # encode then write to support non-ASCII paths 23 | return True 24 | except Exception: 25 | return False 26 | 27 | 28 | def imshow(path, im): 29 | _imshow(path.encode('unicode_escape').decode(), im) # escape non-ASCII characters in the window title 30 | 31 | 32 | # PyTorch functions ---------------------------------------------------------------------------------------------------- 33 | _torch_save = torch.save # copy to avoid recursion errors 34 | 35 | 36 | def torch_save(*args, **kwargs): 37 | """Use dill (if installed) to serialize lambda functions, which pickle cannot handle.""" 38 | try: 39 | import dill as pickle 40 | except ImportError: 41 | import pickle 42 | 43 | if 'pickle_module' not in kwargs: 44 | kwargs['pickle_module'] = pickle 45 | return _torch_save(*args, **kwargs) 46 | -------------------------------------------------------------------------------- /ultralytics/utils/tuner.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.cfg import TASK2DATA, TASK2METRIC 4 | from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS 5 | 6 | 7 | def run_ray_tune(model, 8 | space: dict = None, 9 | grace_period: int = 10, 10 | gpu_per_trial: int = None, 11 | max_samples: int = 10, 12 | **train_args): 13 | """ 14 | Runs hyperparameter tuning using Ray Tune. 15 | 16 | Args: 17 | model (YOLO): Model to run the tuner on. 18 | space (dict, optional): The hyperparameter search space. Defaults to None. 19 | grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10. 20 | gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None. 21 | max_samples (int, optional): The maximum number of trials to run. Defaults to 10. 22 | train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}. 23 | 24 | Returns: 25 | (dict): A dictionary containing the results of the hyperparameter search. 26 | 27 | Raises: 28 | ModuleNotFoundError: If Ray Tune is not installed.
29 | """ 30 | if train_args is None: 31 | train_args = {} 32 | 33 | try: 34 | from ray import tune 35 | from ray.air import RunConfig 36 | from ray.air.integrations.wandb import WandbLoggerCallback 37 | from ray.tune.schedulers import ASHAScheduler 38 | except ImportError: 39 | raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"') 40 | 41 | try: 42 | import wandb 43 | 44 | assert hasattr(wandb, '__version__') 45 | except (ImportError, AssertionError): 46 | wandb = False 47 | 48 | default_space = { 49 | # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), 50 | 'lr0': tune.uniform(1e-5, 1e-1), 51 | 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) 52 | 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 53 | 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 54 | 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) 55 | 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum 56 | 'box': tune.uniform(0.02, 0.2), # box loss gain 57 | 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) 58 | 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) 59 | 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) 60 | 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) 61 | 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg) 62 | 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction) 63 | 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain) 64 | 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg) 65 | 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 66 | 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability) 67 | 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability) 68 | 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability) 69 | 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability) 70 | 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability) 71 | 72 | def _tune(config): 73 | """ 74 | Trains the YOLO model with the specified hyperparameters and additional arguments. 75 | 76 | Args: 77 | config (dict): A dictionary of hyperparameters to use for training. 78 | 79 | Returns: 80 | None. 
81 | """ 82 | model._reset_callbacks() 83 | config.update(train_args) 84 | model.train(**config) 85 | 86 | # Get search space 87 | if not space: 88 | space = default_space 89 | LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.') 90 | 91 | # Get dataset 92 | data = train_args.get('data', TASK2DATA[model.task]) 93 | space['data'] = data 94 | if 'data' not in train_args: 95 | LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".') 96 | 97 | # Define the trainable function with allocated resources 98 | trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0}) 99 | 100 | # Define the ASHA scheduler for hyperparameter search 101 | asha_scheduler = ASHAScheduler(time_attr='epoch', 102 | metric=TASK2METRIC[model.task], 103 | mode='max', 104 | max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100, 105 | grace_period=grace_period, 106 | reduction_factor=3) 107 | 108 | # Define the callbacks for the hyperparameter search 109 | tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else [] 110 | 111 | # Create the Ray Tune hyperparameter search tuner 112 | tuner = tune.Tuner(trainable_with_resources, 113 | param_space=space, 114 | tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples), 115 | run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune')) 116 | 117 | # Run the hyperparameter search 118 | tuner.fit() 119 | 120 | # Return the results of the hyperparameter search 121 | return tuner.get_results() 122 | -------------------------------------------------------------------------------- /ultralytics/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from . import v8 4 | 5 | __all__ = 'v8', # tuple or list 6 | -------------------------------------------------------------------------------- /ultralytics/yolo/cfg/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | from ultralytics.utils import LOGGER 5 | 6 | # Set modules in sys.modules under their old name 7 | sys.modules['ultralytics.yolo.cfg'] = importlib.import_module('ultralytics.cfg') 8 | 9 | LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.cfg' is deprecated since '8.0.136' and will be removed in '8.1.0'. " 10 | "Please use 'ultralytics.cfg' instead.") 11 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | from ultralytics.utils import LOGGER 5 | 6 | # Set modules in sys.modules under their old name 7 | sys.modules['ultralytics.yolo.data'] = importlib.import_module('ultralytics.data') 8 | # This is for updating old cls models, or the way in following warning won't work. 9 | sys.modules['ultralytics.yolo.data.augment'] = importlib.import_module('ultralytics.data.augment') 10 | 11 | DATA_WARNING = """WARNING ⚠️ 'ultralytics.yolo.data' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.data' instead. 12 | Note this warning may be related to loading older models. 
You can update your model to current structure with: 13 | import torch 14 | ckpt = torch.load("model.pt") # applies to both official and custom models 15 | torch.save(ckpt, "updated-model.pt") 16 | """ 17 | LOGGER.warning(DATA_WARNING) 18 | -------------------------------------------------------------------------------- /ultralytics/yolo/engine/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | from ultralytics.utils import LOGGER 5 | 6 | # Set modules in sys.modules under their old name 7 | sys.modules['ultralytics.yolo.engine'] = importlib.import_module('ultralytics.engine') 8 | 9 | LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.engine' is deprecated since '8.0.136' and will be removed in '8.1.0'. " 10 | "Please use 'ultralytics.engine' instead.") 11 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | from ultralytics.utils import LOGGER 5 | 6 | # Set modules in sys.modules under their old name 7 | sys.modules['ultralytics.yolo.utils'] = importlib.import_module('ultralytics.utils') 8 | 9 | UTILS_WARNING = """WARNING ⚠️ 'ultralytics.yolo.utils' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.utils' instead. 10 | Note this warning may be related to loading older models. You can update your model to current structure with: 11 | import torch 12 | ckpt = torch.load("model.pt") # applies to both official and custom models 13 | torch.save(ckpt, "updated-model.pt") 14 | """ 15 | LOGGER.warning(UTILS_WARNING) 16 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | from ultralytics.utils import LOGGER 5 | 6 | # Set modules in sys.modules under their old name 7 | sys.modules['ultralytics.yolo.v8'] = importlib.import_module('ultralytics.models.yolo') 8 | 9 | LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.v8' is deprecated since '8.0.136' and will be removed in '8.1.0'. 
" 10 | "Please use 'ultralytics.models.yolo' instead.") 11 | -------------------------------------------------------------------------------- /y_prune/prune1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ultralytics import YOLO 4 | from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect 5 | import numpy as np 6 | 7 | def prune_conv(conv1,conv2,threshold): 8 | gamma=conv1.bn.weight.data.detach() 9 | beta=conv1.bn.bias.detach() 10 | keep_idxs=[] 11 | local_threshold = threshold 12 | while(len(keep_idxs)<8): 13 | keep_idx = torch.where(gamma.abs()>=local_threshold)[0] 14 | keep_idx = torch.ceil(torch.tensor(len(keep_idx)/8))*8 #为保证最后的通道式是8的倍数 15 | new_threashold = torch.sort(gamma.abs(),descending=True)[0][int(keep_idx-1)] 16 | keep_idxs = torch.where(gamma.abs()>=new_threashold)[0] #得到剪枝后通道的index 17 | local_threshold = new_threashold*0.5 18 | n = len(keep_idxs) 19 | print("prune rate for this layer: {%.2f} %",n/len(gamma) * 100) 20 | conv1.bn.weight.data = gamma[keep_idxs] #对conv1的bn层进行调整 21 | conv1.bn.bias.data = beta[keep_idxs] 22 | conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs] 23 | conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs] 24 | conv1.bn.num_features = n 25 | conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs] #对conv1的conv层进行调整 26 | conv1.conv.out_channels = n 27 | if conv1.conv.bias is not None: 28 | conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs] 29 | if not isinstance(conv2,list): #对conv2层进行处理,转换为list,此处统一为List是因为后续会传入list(1对多) 30 | conv2 = [conv2] 31 | for item in conv2: 32 | if item is not None: 33 | if isinstance(item,Conv): #找到conv2中的conv层并对其输入通道进行调整 34 | conv = item.conv 35 | else: 36 | conv = item 37 | conv.in_channels = n #conv2层的输入通道数量调整到与剪枝后conv1层的输出通道一致 38 | conv.weight.data = conv.weight.data[:,keep_idxs] 39 | 40 | def prune(m1,m2,threshold): 41 | if isinstance(m1,C2f): 42 | m1 = m1.cv2 43 | if not isinstance(m2,list): 44 | m2 = [m2] 45 | for i, item in enumerate(m2): 46 | if isinstance(item,C2f) or isinstance(item,SPPF): 47 | m2[i] = item.cv1 48 | prune_conv(m1,m2,threshold) 49 | 50 | if __name__ == "__main__": 51 | yolo = YOLO("weights/best.pt") # build a new model from scratch 52 | model = yolo.model 53 | ws = [] 54 | bs = [] 55 | for name, m in model.named_modules(): 56 | if isinstance(m, nn.BatchNorm2d): 57 | w = m.weight.abs().detach() 58 | b = m.bias.abs().detach() 59 | ws.append(w) 60 | bs.append(b) 61 | print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item()) 62 | factor = 0.8 #剪枝率暂时设置为0.8,推荐选用小的剪枝率进行多次剪枝操作 63 | ws = torch.cat(ws) 64 | threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)] 65 | 66 | seq = model.model 67 | #针对head部分,low level的先不要减,可能包含有重要的信息 68 | for i in range(3, 9): 69 | if i in [6, 4, 9]: 70 | continue 71 | prune(seq[i], seq[i+1],threshold) 72 | #对15,18,21层处理,15和18不仅和detect层相连,同时下个layer中也有conv层,但是21层是和Detect层直连 73 | for n,i in enumerate([15,18,21]): 74 | if(i!=21): 75 | prune(seq[i],[seq[i+1],seq[-1].cv2[n][0],seq[-1].cv3[n][0]],threshold) #C2f-conv,C2f-Detect 76 | prune_conv(seq[-1].cv2[n][0],seq[-1].cv2[n][1],threshold) #Detect.cv2 77 | prune_conv(seq[-1].cv2[n][1],seq[-1].cv2[n][2],threshold) #Detect.cv2 78 | prune_conv(seq[-1].cv3[n][0],seq[-1].cv3[n][1],threshold) #Detect.cv3 79 | prune_conv(seq[-1].cv3[n][1],seq[-1].cv3[n][2],threshold) #Detect.cv3 80 | #遍历模型并对bottleneck内的剪枝 81 | for name, m in model.named_modules(): 82 | if isinstance(m, 
Bottleneck): 83 | prune_conv(m.cv1, m.cv2,threshold) 84 | for name, p in yolo.model.named_parameters(): 85 | p.requires_grad = True 86 | yolo.train(data="ultralytics/cfg/datasets/VOC.yaml", epochs=300,workers=8,batch=32) 87 | print("done") 88 | -------------------------------------------------------------------------------- /y_prune/prune2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YellenYeason/yolov8-prune/b5bf7381ae3b88b848f2b525507b919c8a97f004/y_prune/prune2.py -------------------------------------------------------------------------------- /y_prune/y_train.py: -------------------------------------------------------------------------------- 1 | from ultralytics import YOLO # build a new model from scratch 2 | model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training) 3 | model.train(data="ultralytics/cfg/datasets/VOC.yaml", epochs=3,batch=32,workers=8) --------------------------------------------------------------------------------
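
A minimal, self-contained sketch of the global BN-gamma thresholding that prune1.py performs before any weights are touched. The two-layer toy network, the randomized gammas and the 0.8 factor below are illustrative assumptions for the demo, not part of this repository's pipeline:

```python
import torch
import torch.nn as nn

# Toy stand-in for a detector backbone: any stack of Conv+BN blocks works here
toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.SiLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.SiLU(),
)
for m in toy.modules():
    if isinstance(m, nn.BatchNorm2d):
        nn.init.uniform_(m.weight, 0.0, 1.0)  # mimic trained gammas; fresh BN layers start at 1.0

# Gather |gamma| from every BatchNorm2d, as prune1.py does in its __main__ block
gammas = torch.cat([m.weight.abs().detach() for m in toy.modules() if isinstance(m, nn.BatchNorm2d)])

# Global threshold: the value at position len(gammas) * factor of the descending ranking;
# channels whose |gamma| falls below it are the pruning candidates
factor = 0.8
threshold = torch.sort(gammas, descending=True)[0][int(len(gammas) * factor)].item()

# Per-layer survival count at this threshold (no weights are modified here)
for name, m in toy.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        keep = int((m.weight.abs().detach() >= threshold).sum())
        print(f'{name}: would keep {keep}/{m.num_features} channels (threshold={threshold:.4f})')
```

On a real checkpoint the same ranking runs over yolo.model, after which prune_conv rewires each affected conv pair (keeping a multiple of 8 channels per layer) and training resumes to recover accuracy.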