├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── README.zh-CN.md
├── mkdocs.yml
├── requirements.txt
├── setup.cfg
├── setup.py
├── ultralytics
│ ├── __init__.py
│ ├── cfg
│ │ ├── __init__.py
│ │ ├── default.yaml
│ │ ├── models
│ │ │ ├── README.md
│ │ │ ├── rt-detr
│ │ │ │ ├── rtdetr-l.yaml
│ │ │ │ └── rtdetr-x.yaml
│ │ │ ├── v3
│ │ │ │ ├── yolov3-spp.yaml
│ │ │ │ ├── yolov3-tiny.yaml
│ │ │ │ └── yolov3.yaml
│ │ │ ├── v5
│ │ │ │ ├── yolov5-p6.yaml
│ │ │ │ └── yolov5.yaml
│ │ │ ├── v6
│ │ │ │ └── yolov6.yaml
│ │ │ └── v8
│ │ │   ├── yolov8-cls.yaml
│ │ │   ├── yolov8-p2.yaml
│ │ │   ├── yolov8-p6.yaml
│ │ │   ├── yolov8-pose-p6.yaml
│ │ │   ├── yolov8-pose.yaml
│ │ │   ├── yolov8-rtdetr.yaml
│ │ │   ├── yolov8-seg.yaml
│ │ │   └── yolov8.yaml
│ │ └── trackers
│ │   ├── botsort.yaml
│ │   └── bytetrack.yaml
│ ├── data
│ │ ├── __init__.py
│ │ ├── annotator.py
│ │ ├── augment.py
│ │ ├── base.py
│ │ ├── build.py
│ │ ├── converter.py
│ │ ├── dataset.py
│ │ ├── loaders.py
│ │ ├── scripts
│ │ │ ├── download_weights.sh
│ │ │ ├── get_coco.sh
│ │ │ ├── get_coco128.sh
│ │ │ └── get_imagenet.sh
│ │ └── utils.py
│ ├── engine
│ │ ├── __init__.py
│ │ ├── exporter.py
│ │ ├── model.py
│ │ ├── predictor.py
│ │ ├── results.py
│ │ ├── trainer.py
│ │ └── validator.py
│ ├── hub
│ │ ├── __init__.py
│ │ ├── auth.py
│ │ ├── session.py
│ │ └── utils.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── fastsam
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ ├── predict.py
│ │ │ ├── prompt.py
│ │ │ ├── utils.py
│ │ │ └── val.py
│ │ ├── nas
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ ├── predict.py
│ │ │ └── val.py
│ │ ├── rtdetr
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ ├── predict.py
│ │ │ ├── train.py
│ │ │ └── val.py
│ │ ├── sam
│ │ │ ├── __init__.py
│ │ │ ├── amg.py
│ │ │ ├── build.py
│ │ │ ├── model.py
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── decoders.py
│ │ │ │ ├── encoders.py
│ │ │ │ ├── sam.py
│ │ │ │ ├── tiny_encoder.py
│ │ │ │ └── transformer.py
│ │ │ └── predict.py
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ ├── loss.py
│ │ │ └── ops.py
│ │ └── yolo
│ │   ├── __init__.py
│ │   ├── classify
│ │   │ ├── __init__.py
│ │   │ ├── predict.py
│ │   │ ├── train.py
│ │   │ └── val.py
│ │   ├── detect
│ │   │ ├── __init__.py
│ │   │ ├── predict.py
│ │   │ ├── train.py
│ │   │ └── val.py
│ │   ├── model.py
│ │   ├── pose
│ │   │ ├── __init__.py
│ │   │ ├── predict.py
│ │   │ ├── train.py
│ │   │ └── val.py
│ │   └── segment
│ │     ├── __init__.py
│ │     ├── predict.py
│ │     ├── train.py
│ │     └── val.py
│ ├── nn
│ │ ├── __init__.py
│ │ ├── autobackend.py
│ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── conv.py
│ │ │ ├── head.py
│ │ │ ├── transformer.py
│ │ │ └── utils.py
│ │ └── tasks.py
│ ├── trackers
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── basetrack.py
│ │ ├── bot_sort.py
│ │ ├── byte_tracker.py
│ │ ├── track.py
│ │ └── utils
│ │   ├── __init__.py
│ │   ├── gmc.py
│ │   ├── kalman_filter.py
│ │   └── matching.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── autobatch.py
│ │ ├── benchmarks.py
│ │ ├── callbacks
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── clearml.py
│ │ │ ├── comet.py
│ │ │ ├── dvc.py
│ │ │ ├── hub.py
│ │ │ ├── mlflow.py
│ │ │ ├── neptune.py
│ │ │ ├── raytune.py
│ │ │ ├── tensorboard.py
│ │ │ └── wb.py
│ │ ├── checks.py
│ │ ├── dist.py
│ │ ├── downloads.py
│ │ ├── errors.py
│ │ ├── files.py
│ │ ├── instance.py
│ │ ├── loss.py
│ │ ├── metrics.py
│ │ ├── ops.py
│ │ ├── patches.py
│ │ ├── plotting.py
│ │ ├── tal.py
│ │ ├── torch_utils.py
│ │ └── tuner.py
│ └── yolo
│   ├── __init__.py
│   ├── cfg
│   │ └── __init__.py
│   ├── data
│   │ └── __init__.py
│   ├── engine
│   │ └── __init__.py
│   ├── utils
│   │ └── __init__.py
│   └── v8
│     └── __init__.py
└── y_prune
  ├── prune1.py
  ├── prune2.py
  └── y_train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # Profiling
85 | *.pclprof
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | .idea
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 | mkdocs_github_authors.yaml
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | # datasets and projects
137 | datasets/
138 | runs/
139 | wandb/
140 | .DS_Store
141 |
142 | # Neural Network weights -----------------------------------------------------------------------------------------------
143 | weights/
144 | *.weights
145 | *.pt
146 | *.pb
147 | *.onnx
148 | *.engine
149 | *.mlmodel
150 | *.mlpackage
151 | *.torchscript
152 | *.tflite
153 | *.h5
154 | *_saved_model/
155 | *_web_model/
156 | *_openvino_model/
157 | *_paddle_model/
158 |
159 | # Autogenerated files for tests
160 | /ultralytics/assets/
161 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
3 |
4 | exclude: 'docs/'
5 | # Define bot property if installed via https://github.com/marketplace/pre-commit-ci
6 | ci:
7 | autofix_prs: true
8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
9 | autoupdate_schedule: monthly
10 | # submodules: true
11 |
12 | repos:
13 | - repo: https://github.com/pre-commit/pre-commit-hooks
14 | rev: v4.4.0
15 | hooks:
16 | - id: end-of-file-fixer
17 | - id: trailing-whitespace
18 | - id: check-case-conflict
19 | # - id: check-yaml
20 | - id: check-docstring-first
21 | - id: double-quote-string-fixer
22 | - id: detect-private-key
23 |
24 | - repo: https://github.com/asottile/pyupgrade
25 | rev: v3.10.1
26 | hooks:
27 | - id: pyupgrade
28 | name: Upgrade code
29 |
30 | - repo: https://github.com/PyCQA/isort
31 | rev: 5.12.0
32 | hooks:
33 | - id: isort
34 | name: Sort imports
35 |
36 | - repo: https://github.com/google/yapf
37 | rev: v0.40.0
38 | hooks:
39 | - id: yapf
40 | name: YAPF formatting
41 |
42 | - repo: https://github.com/executablebooks/mdformat
43 | rev: 0.7.16
44 | hooks:
45 | - id: mdformat
46 | name: MD formatting
47 | additional_dependencies:
48 | - mdformat-gfm
49 | - mdformat-black
50 | # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
51 |
52 | - repo: https://github.com/PyCQA/flake8
53 | rev: 6.1.0
54 | hooks:
55 | - id: flake8
56 | name: PEP8
57 |
58 | - repo: https://github.com/codespell-project/codespell
59 | rev: v2.2.5
60 | hooks:
61 | - id: codespell
62 | args:
63 | - --ignore-words-list=crate,nd,strack,dota,ane,segway,fo
64 |
65 | # - repo: https://github.com/asottile/yesqa
66 | # rev: v1.4.0
67 | # hooks:
68 | # - id: yesqa
69 |
70 | # - repo: https://github.com/asottile/dead
71 | # rev: v1.5.0
72 | # hooks:
73 | # - id: dead
74 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | preferred-citation:
3 | type: software
4 | message: If you use this software, please cite it as below.
5 | authors:
6 | - family-names: Jocher
7 | given-names: Glenn
8 | orcid: "https://orcid.org/0000-0001-5950-6979"
9 | - family-names: Chaurasia
10 | given-names: Ayush
11 | orcid: "https://orcid.org/0000-0002-7603-6750"
12 | - family-names: Qiu
13 | given-names: Jing
14 | orcid: "https://orcid.org/0000-0003-3783-7069"
15 | title: "YOLO by Ultralytics"
16 | version: 8.0.0
17 | # doi: 10.5281/zenodo.3908559 # TODO
18 | date-released: 2023-01-10
19 | license: AGPL-3.0
20 | url: "https://github.com/ultralytics/ultralytics"
21 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing to YOLOv8 🚀
2 |
3 | We love your input! We want to make contributing to YOLOv8 as easy and transparent as possible, whether it's:
4 |
5 | - Reporting a bug
6 | - Discussing the current state of the code
7 | - Submitting a fix
8 | - Proposing a new feature
9 | - Becoming a maintainer
10 |
11 | YOLOv8 works so well due to our combined community effort, and for every small improvement you contribute you will be
12 | helping push the frontiers of what's possible in AI 😃!
13 |
14 | ## Submitting a Pull Request (PR) 🛠️
15 |
16 | Submitting a PR is easy! This example shows how to submit a PR for updating `requirements.txt` in 4 steps:
17 |
18 | ### 1. Select File to Update
19 |
20 | Select `requirements.txt` to update by clicking on it in GitHub.
21 |
22 |
23 |
24 | ### 2. Click 'Edit this file'
25 |
26 | Button is in top-right corner.
27 |
28 | 
29 |
30 | ### 3. Make Changes
31 |
32 | Change `matplotlib` version from `3.2.2` to `3.3`.
33 |
34 | 
35 |
36 | ### 4. Preview Changes and Submit PR
37 |
38 | Click on the **Preview changes** tab to verify your updates. At the bottom of the screen select 'Create a **new branch**
39 | for this commit', assign your branch a descriptive name such as `fix/matplotlib_version` and click the green **Propose
40 | changes** button. All done, your PR is now submitted to YOLOv8 for review and approval 😃!
41 |
42 | 
43 |
44 | ### PR recommendations
45 |
46 | To allow your work to be integrated as seamlessly as possible, we advise you to:
47 |
48 | - ✅ Verify your PR is **up-to-date** with `ultralytics/ultralytics` `main` branch. If your PR is behind, you can update
49 | your code by clicking the 'Update branch' button or by running `git pull` and `git merge main` locally.
50 |
51 | 
52 |
53 | - ✅ Verify all YOLOv8 Continuous Integration (CI) **checks are passing**.
54 |
55 | 
56 |
57 | - ✅ Reduce changes to the absolute **minimum** required for your bug fix or feature addition. _"It is not daily increase
58 | but daily decrease, hack away the unessential. The closer to the source, the less wastage there is."_ — Bruce Lee
59 |
60 | ### Docstrings
61 |
62 | Not all functions or classes require docstrings, but when they do, we
63 | follow [google-style docstrings format](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).
64 | Here is an example:
65 |
66 | ```python
67 | """
68 | What the function does. Performs NMS on given detection predictions.
69 |
70 | Args:
71 | arg1: The description of the 1st argument
72 | arg2: The description of the 2nd argument
73 |
74 | Returns:
75 | What the function returns. Empty if nothing is returned.
76 |
77 | Raises:
78 | Exception Class: When and why this exception can be raised by the function.
79 | """
80 | ```
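
For instance, a complete function documented in this style might look like the following (the function itself is hypothetical and exists only to illustrate the format):

```python
def box_area(boxes):
    """
    Compute the area of each bounding box. Hypothetical example shown only to illustrate docstring style.

    Args:
        boxes (list[tuple[float, float, float, float]]): Boxes given as (x1, y1, x2, y2) corner coordinates.

    Returns:
        (list[float]): Area of each input box. Empty if `boxes` is empty.

    Raises:
        ValueError: If any box has x2 < x1 or y2 < y1.
    """
    areas = []
    for x1, y1, x2, y2 in boxes:
        if x2 < x1 or y2 < y1:
            raise ValueError(f'invalid box: {(x1, y1, x2, y2)}')
        areas.append((x2 - x1) * (y2 - y1))
    return areas
```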
81 |
82 | ## Submitting a Bug Report 🐛
83 |
84 | If you spot a problem with YOLOv8 please submit a Bug Report!
85 |
86 | For us to start investigating a possible problem we need to be able to reproduce it ourselves first. We've created a few
87 | short guidelines below to help users provide what we need in order to get started.
88 |
89 | When asking a question, people will be better able to provide help if you provide **code** that they can easily
90 | understand and use to **reproduce** the problem. This is referred to by community members as creating
91 | a [minimum reproducible example](https://docs.ultralytics.com/help/minimum_reproducible_example/). Your code that reproduces
92 | the problem should be:
93 |
94 | - ✅ **Minimal** – Use as little code as possible that still produces the same problem
95 | - ✅ **Complete** – Provide **all** parts someone else needs to reproduce your problem in the question itself
96 | - ✅ **Reproducible** – Test the code you're about to provide to make sure it reproduces the problem
97 |
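As an illustration, a reproduction script meeting all three criteria can be very small. A sketch (the model name and image URL below are placeholders; substitute whatever triggers your problem):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # placeholder weights; use the model that shows the problem
results = model('https://ultralytics.com/images/bus.jpg')  # placeholder input
print(results[0].boxes)  # include the full console output/traceback in your report
```
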
98 | In addition to the above requirements, for [Ultralytics](https://ultralytics.com/) to provide assistance your code
99 | should be:
100 |
101 | - ✅ **Current** – Verify that your code is up-to-date with current
102 | GitHub [main](https://github.com/ultralytics/ultralytics/tree/main) branch, and if necessary `git pull` or `git clone`
103 | a new copy to ensure your problem has not already been resolved by previous commits.
104 | - ✅ **Unmodified** – Your problem must be reproducible without any modifications to the codebase in this
105 | repository. [Ultralytics](https://ultralytics.com/) does not provide support for custom code ⚠️.
106 |
107 | If you believe your problem meets all of the above criteria, please raise a new issue using the 🐛
108 | **Bug Report** [template](https://github.com/ultralytics/ultralytics/issues/new/choose) and providing
109 | a [minimum reproducible example](https://docs.ultralytics.com/help/minimum_reproducible_example/) to help us better
110 | understand and diagnose your problem.
111 |
112 | ## License
113 |
114 | By contributing, you agree that your contributions will be licensed under
115 | the [AGPL-3.0 license](https://choosealicense.com/licenses/agpl-3.0/)
116 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.md
2 | include requirements.txt
3 | include LICENSE
4 | include setup.py
5 | include ultralytics/assets/bus.jpg
6 | include ultralytics/assets/zidane.jpg
7 | recursive-include ultralytics *.yaml
8 | recursive-exclude __pycache__ *
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Ultralytics requirements
2 | # Usage: pip install -r requirements.txt
3 |
4 | # Base ----------------------------------------
5 | matplotlib>=3.2.2
6 | numpy>=1.22.2 # pinned by Snyk to avoid a vulnerability
7 | opencv-python>=4.6.0
8 | pillow>=7.1.2
9 | pyyaml>=5.3.1
10 | requests>=2.23.0
11 | scipy>=1.4.1
12 | torch>=1.8.0
13 | torchvision>=0.9.0
14 | tqdm>=4.64.0
15 |
16 | # Logging -------------------------------------
17 | # tensorboard>=2.13.0
18 | # dvclive>=2.12.0
19 | # clearml
20 | # comet
21 |
22 | # Plotting ------------------------------------
23 | pandas>=1.1.4
24 | seaborn>=0.11.0
25 |
26 | # Export --------------------------------------
27 | # coremltools>=7.0.b1 # CoreML export
28 | # onnx>=1.12.0 # ONNX export
29 | # onnxsim>=0.4.1 # ONNX simplifier
30 | # nvidia-pyindex # TensorRT export
31 | # nvidia-tensorrt # TensorRT export
32 | # scikit-learn==0.19.2 # CoreML quantization
33 | # tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
34 | # tflite-support
35 | # tensorflowjs>=3.9.0 # TF.js export
36 | # openvino-dev>=2023.0 # OpenVINO export
37 |
38 | # Extras --------------------------------------
39 | psutil # system utilization
40 | py-cpuinfo # display CPU info
41 | # thop>=0.1.1 # FLOPs computation
42 | # ipython # interactive notebook
43 | # albumentations>=1.0.3 # training augmentations
44 | # pycocotools>=2.0.6 # COCO mAP
45 | # roboflow
46 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | # Project-wide configuration file, can be used for package metadata and other tool configurations
2 | # Example usage: global configuration for PEP8 (via flake8) settings or default pytest arguments
3 | # Local usage: pip install pre-commit, pre-commit run --all-files
4 |
5 | [metadata]
6 | license_files = LICENSE
7 | description_file = README.md
8 |
9 | [tool:pytest]
10 | norecursedirs =
11 | .git
12 | dist
13 | build
14 | addopts =
15 | --doctest-modules
16 | --durations=25
17 | --color=yes
18 |
19 | [coverage:run]
20 | source = ultralytics/
21 | data_file = tests/.coverage
22 | omit =
23 | ultralytics/utils/callbacks/*
24 |
25 | [flake8]
26 | max-line-length = 120
27 | exclude = .tox,*.egg,build,temp
28 | select = E,W,F
29 | doctests = True
30 | verbose = 2
31 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
32 | format = pylint
33 | # see: https://www.flake8rules.com/
34 | ignore = E731,F405,E402,W504,E501
35 | # E731: Do not assign a lambda expression, use a def
36 | # F405: name may be undefined, or defined from star imports: module
37 | # E402: module level import not at top of file
38 | # W504: line break after binary operator
39 | # E501: line too long
40 | # removed:
41 | # F401: module imported but unused
42 | # E231: missing whitespace after ‘,’, ‘;’, or ‘:’
43 | # E127: continuation line over-indented for visual indent
44 | # F403: ‘from module import *’ used; unable to detect undefined names
45 |
46 |
47 | [isort]
48 | # https://pycqa.github.io/isort/docs/configuration/options.html
49 | line_length = 120
50 | # see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html
51 | multi_line_output = 0
52 |
53 | [yapf]
54 | based_on_style = pep8
55 | spaces_before_comment = 2
56 | COLUMN_LIMIT = 120
57 | COALESCE_BRACKETS = True
58 | SPACES_AROUND_POWER_OPERATOR = True
59 | SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True
60 | SPLIT_BEFORE_CLOSING_BRACKET = False
61 | SPLIT_BEFORE_FIRST_ARGUMENT = False
62 | # EACH_DICT_ENTRY_ON_SEPARATE_LINE = False
63 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import re
4 | from pathlib import Path
5 |
6 | import pkg_resources as pkg
7 | from setuptools import find_packages, setup
8 |
9 | # Settings
10 | FILE = Path(__file__).resolve()
11 | PARENT = FILE.parent # root directory
12 | README = (PARENT / 'README.md').read_text(encoding='utf-8')
13 | REQUIREMENTS = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements((PARENT / 'requirements.txt').read_text())]
14 |
15 |
16 | def get_version():
17 | file = PARENT / 'ultralytics/__init__.py'
18 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding='utf-8'), re.M)[1]
19 |
20 |
21 | setup(
22 | name='ultralytics', # name of pypi package
23 | version=get_version(), # version of pypi package
24 | python_requires='>=3.8',
25 | license='AGPL-3.0',
26 | description=('Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, '
27 | 'pose estimation and image classification.'),
28 | long_description=README,
29 | long_description_content_type='text/markdown',
30 | url='https://github.com/ultralytics/ultralytics',
31 | project_urls={
32 | 'Bug Reports': 'https://github.com/ultralytics/ultralytics/issues',
33 | 'Funding': 'https://ultralytics.com',
34 | 'Source': 'https://github.com/ultralytics/ultralytics'},
35 | author='Ultralytics',
36 | author_email='hello@ultralytics.com',
37 | packages=find_packages(), # required
38 | include_package_data=True,
39 | install_requires=REQUIREMENTS,
40 | extras_require={
41 | 'dev': [
42 | 'ipython',
43 | 'check-manifest',
44 | 'pytest',
45 | 'pytest-cov',
46 | 'coverage',
47 | 'mkdocs-material',
48 | 'mkdocstrings[python]',
49 | 'mkdocs-redirects', # for 301 redirects
50 | 'mkdocs-ultralytics-plugin>=0.0.25', # for meta descriptions and images, dates and authors
51 | ],
52 | 'export': [
53 | 'coremltools>=7.0.b1',
54 | 'openvino-dev>=2023.0',
55 | 'tensorflowjs', # automatically installs tensorflow
56 | ], },
57 | classifiers=[
58 | 'Development Status :: 4 - Beta',
59 | 'Intended Audience :: Developers',
60 | 'Intended Audience :: Education',
61 | 'Intended Audience :: Science/Research',
62 | 'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)',
63 | 'Programming Language :: Python :: 3',
64 | 'Programming Language :: Python :: 3.8',
65 | 'Programming Language :: Python :: 3.9',
66 | 'Programming Language :: Python :: 3.10',
67 | 'Programming Language :: Python :: 3.11',
68 | 'Topic :: Software Development',
69 | 'Topic :: Scientific/Engineering',
70 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
71 | 'Topic :: Scientific/Engineering :: Image Recognition',
72 | 'Operating System :: POSIX :: Linux',
73 | 'Operating System :: MacOS',
74 | 'Operating System :: Microsoft :: Windows', ],
75 | keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics',
76 | entry_points={'console_scripts': ['yolo = ultralytics.cfg:entrypoint', 'ultralytics = ultralytics.cfg:entrypoint']})
77 |
--------------------------------------------------------------------------------
/ultralytics/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | __version__ = '8.0.155'
4 |
5 | from ultralytics.hub import start
6 | from ultralytics.models import RTDETR, SAM, YOLO
7 | from ultralytics.models.fastsam import FastSAM
8 | from ultralytics.models.nas import NAS
9 | from ultralytics.utils import SETTINGS as settings
10 | from ultralytics.utils.checks import check_yolo as checks
11 | from ultralytics.utils.downloads import download
12 |
13 | __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import
14 |
--------------------------------------------------------------------------------
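The flat `__all__` above is what the trailing `# allow simpler import` comment refers to: every top-level name is importable directly from the package. A minimal sketch of that surface (assuming `yolov8n.pt` is available locally or can be downloaded):

```python
from ultralytics import YOLO, checks

checks()  # environment report; aliases ultralytics.utils.checks.check_yolo
model = YOLO('yolov8n.pt')  # the model classes (YOLO, SAM, RTDETR, ...) share this constructor pattern
model.info()  # print the model summary
```
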
/ultralytics/cfg/models/README.md:
--------------------------------------------------------------------------------
1 | ## Models
2 |
3 | Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
4 | files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
5 | and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
6 | segmentation tasks.
7 |
8 | These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
9 | instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
10 | from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
11 | directory provides a great starting point for your custom model development needs.
12 |
13 | To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
14 | selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
15 | details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
16 | to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
17 |
18 | ### Usage
19 |
20 | Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
21 |
22 | ```bash
23 | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
24 | ```
25 |
26 | They may also be used directly in a Python environment, and accept the same
27 | [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
28 |
29 | ```python
30 | from ultralytics import YOLO
31 |
32 | model = YOLO("model.yaml")  # build a model from scratch using a YAML config
33 | # model = YOLO("model.pt")  # or load a pre-trained model if available
34 | model.info() # display model information
35 | model.train(data="coco128.yaml", epochs=100) # train the model
36 | ```
37 |
38 | ## Pre-trained Model Architectures
39 |
40 | Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information
41 | and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
42 |
43 | ## Contributing New Models
44 |
45 | If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
46 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | l: [1.00, 1.00, 1024]
9 |
10 | backbone:
11 | # [from, repeats, module, args]
12 | - [-1, 1, HGStem, [32, 48]] # 0-P2/4
13 | - [-1, 6, HGBlock, [48, 128, 3]] # stage 1
14 |
15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
16 | - [-1, 6, HGBlock, [96, 512, 3]] # stage 2
17 |
18 |   - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P4/16
19 | - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
20 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
21 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
22 |
23 |   - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P5/32
24 | - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
25 |
26 | head:
27 | - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
28 | - [-1, 1, AIFI, [1024, 8]]
29 | - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
30 |
31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
32 | - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
33 | - [[-2, -1], 1, Concat, [1]]
34 | - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
35 | - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
36 |
37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
38 | - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
39 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4
40 | - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
41 |
42 | - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
43 | - [[-1, 17], 1, Concat, [1]] # cat Y4
44 | - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
45 |
46 | - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
47 | - [[-1, 12], 1, Concat, [1]] # cat Y5
48 | - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
49 |
50 | - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
51 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | x: [1.00, 1.00, 2048]
9 |
10 | backbone:
11 | # [from, repeats, module, args]
12 | - [-1, 1, HGStem, [32, 64]] # 0-P2/4
13 | - [-1, 6, HGBlock, [64, 128, 3]] # stage 1
14 |
15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
16 | - [-1, 6, HGBlock, [128, 512, 3]]
17 | - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2
18 |
19 |   - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P4/16
20 | - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut
21 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
22 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
23 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
24 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3
25 |
26 |   - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P5/32
27 | - [-1, 6, HGBlock, [512, 2048, 5, True, False]]
28 | - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4
29 |
30 | head:
31 | - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2
32 | - [-1, 1, AIFI, [2048, 8]]
33 | - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0
34 |
35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36 | - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1
37 | - [[-2, -1], 1, Concat, [1]]
38 | - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0
39 | - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1
40 |
41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
42 | - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0
43 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4
44 | - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1
45 |
46 | - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0
47 | - [[-1, 21], 1, Concat, [1]] # cat Y4
48 | - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0
49 |
50 | - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1
51 | - [[-1, 16], 1, Concat, [1]] # cat Y5
52 | - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1
53 |
54 | - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
55 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v3/yolov3-spp.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | depth_multiple: 1.0 # model depth multiple
7 | width_multiple: 1.0 # layer channel multiple
8 |
9 | # darknet53 backbone
10 | backbone:
11 | # [from, number, module, args]
12 | [[-1, 1, Conv, [32, 3, 1]], # 0
13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14 | [-1, 1, Bottleneck, [64]],
15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
16 | [-1, 2, Bottleneck, [128]],
17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
18 | [-1, 8, Bottleneck, [256]],
19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
20 | [-1, 8, Bottleneck, [512]],
21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
22 | [-1, 4, Bottleneck, [1024]], # 10
23 | ]
24 |
25 | # YOLOv3-SPP head
26 | head:
27 | [[-1, 1, Bottleneck, [1024, False]],
28 | [-1, 1, SPP, [512, [5, 9, 13]]],
29 | [-1, 1, Conv, [1024, 3, 1]],
30 | [-1, 1, Conv, [512, 1, 1]],
31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
32 |
33 | [-2, 1, Conv, [256, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4
36 | [-1, 1, Bottleneck, [512, False]],
37 | [-1, 1, Bottleneck, [512, False]],
38 | [-1, 1, Conv, [256, 1, 1]],
39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
40 |
41 | [-2, 1, Conv, [128, 1, 1]],
42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3
44 | [-1, 1, Bottleneck, [256, False]],
45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
46 |
47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
48 | ]
49 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v3/yolov3-tiny.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | depth_multiple: 1.0 # model depth multiple
7 | width_multiple: 1.0 # layer channel multiple
8 |
9 | # YOLOv3-tiny backbone
10 | backbone:
11 | # [from, number, module, args]
12 | [[-1, 1, Conv, [16, 3, 1]], # 0
13 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
14 | [-1, 1, Conv, [32, 3, 1]],
15 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
16 | [-1, 1, Conv, [64, 3, 1]],
17 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
18 | [-1, 1, Conv, [128, 3, 1]],
19 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
20 | [-1, 1, Conv, [256, 3, 1]],
21 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
22 | [-1, 1, Conv, [512, 3, 1]],
23 | [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
24 | [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
25 | ]
26 |
27 | # YOLOv3-tiny head
28 | head:
29 | [[-1, 1, Conv, [1024, 3, 1]],
30 | [-1, 1, Conv, [256, 1, 1]],
31 | [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
32 |
33 | [-2, 1, Conv, [128, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4
36 | [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
37 |
38 | [[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
39 | ]
40 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v3/yolov3.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | depth_multiple: 1.0 # model depth multiple
7 | width_multiple: 1.0 # layer channel multiple
8 |
9 | # darknet53 backbone
10 | backbone:
11 | # [from, number, module, args]
12 | [[-1, 1, Conv, [32, 3, 1]], # 0
13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14 | [-1, 1, Bottleneck, [64]],
15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
16 | [-1, 2, Bottleneck, [128]],
17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
18 | [-1, 8, Bottleneck, [256]],
19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
20 | [-1, 8, Bottleneck, [512]],
21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
22 | [-1, 4, Bottleneck, [1024]], # 10
23 | ]
24 |
25 | # YOLOv3 head
26 | head:
27 | [[-1, 1, Bottleneck, [1024, False]],
28 | [-1, 1, Conv, [512, 1, 1]],
29 | [-1, 1, Conv, [1024, 3, 1]],
30 | [-1, 1, Conv, [512, 1, 1]],
31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
32 |
33 | [-2, 1, Conv, [256, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4
36 | [-1, 1, Bottleneck, [512, False]],
37 | [-1, 1, Bottleneck, [512, False]],
38 | [-1, 1, Conv, [256, 1, 1]],
39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
40 |
41 | [-2, 1, Conv, [128, 1, 1]],
42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3
44 | [-1, 1, Bottleneck, [256, False]],
45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
46 |
47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
48 | ]
49 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v5/yolov5-p6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 1024]
11 | l: [1.00, 1.00, 1024]
12 | x: [1.33, 1.25, 1024]
13 |
14 | # YOLOv5 v6.0 backbone
15 | backbone:
16 | # [from, number, module, args]
17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
19 | [-1, 3, C3, [128]],
20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
21 | [-1, 6, C3, [256]],
22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
23 | [-1, 9, C3, [512]],
24 | [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
25 | [-1, 3, C3, [768]],
26 | [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
27 | [-1, 3, C3, [1024]],
28 | [-1, 1, SPPF, [1024, 5]], # 11
29 | ]
30 |
31 | # YOLOv5 v6.0 head
32 | head:
33 | [[-1, 1, Conv, [768, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P5
36 | [-1, 3, C3, [768, False]], # 15
37 |
38 | [-1, 1, Conv, [512, 1, 1]],
39 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
40 | [[-1, 6], 1, Concat, [1]], # cat backbone P4
41 | [-1, 3, C3, [512, False]], # 19
42 |
43 | [-1, 1, Conv, [256, 1, 1]],
44 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
45 | [[-1, 4], 1, Concat, [1]], # cat backbone P3
46 | [-1, 3, C3, [256, False]], # 23 (P3/8-small)
47 |
48 | [-1, 1, Conv, [256, 3, 2]],
49 | [[-1, 20], 1, Concat, [1]], # cat head P4
50 | [-1, 3, C3, [512, False]], # 26 (P4/16-medium)
51 |
52 | [-1, 1, Conv, [512, 3, 2]],
53 | [[-1, 16], 1, Concat, [1]], # cat head P5
54 | [-1, 3, C3, [768, False]], # 29 (P5/32-large)
55 |
56 | [-1, 1, Conv, [768, 3, 2]],
57 | [[-1, 12], 1, Concat, [1]], # cat head P6
58 | [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)
59 |
60 | [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6)
61 | ]
62 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v5/yolov5.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 1024]
11 | l: [1.00, 1.00, 1024]
12 | x: [1.33, 1.25, 1024]
13 |
14 | # YOLOv5 v6.0 backbone
15 | backbone:
16 | # [from, number, module, args]
17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
19 | [-1, 3, C3, [128]],
20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
21 | [-1, 6, C3, [256]],
22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
23 | [-1, 9, C3, [512]],
24 | [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
25 | [-1, 3, C3, [1024]],
26 | [-1, 1, SPPF, [1024, 5]], # 9
27 | ]
28 |
29 | # YOLOv5 v6.0 head
30 | head:
31 | [[-1, 1, Conv, [512, 1, 1]],
32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
33 | [[-1, 6], 1, Concat, [1]], # cat backbone P4
34 | [-1, 3, C3, [512, False]], # 13
35 |
36 | [-1, 1, Conv, [256, 1, 1]],
37 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
38 | [[-1, 4], 1, Concat, [1]], # cat backbone P3
39 | [-1, 3, C3, [256, False]], # 17 (P3/8-small)
40 |
41 | [-1, 1, Conv, [256, 3, 2]],
42 | [[-1, 14], 1, Concat, [1]], # cat head P4
43 | [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
44 |
45 | [-1, 1, Conv, [512, 3, 2]],
46 | [[-1, 10], 1, Concat, [1]], # cat head P5
47 | [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
48 |
49 | [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
50 | ]
51 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v6/yolov6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | activation: nn.ReLU() # (optional) model default activation function
7 | scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov6.yaml with scale 'n'
8 | # [depth, width, max_channels]
9 | n: [0.33, 0.25, 1024]
10 | s: [0.33, 0.50, 1024]
11 | m: [0.67, 0.75, 768]
12 | l: [1.00, 1.00, 512]
13 | x: [1.00, 1.25, 512]
14 |
15 | # YOLOv6-3.0s backbone
16 | backbone:
17 | # [from, repeats, module, args]
18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20 | - [-1, 6, Conv, [128, 3, 1]]
21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22 | - [-1, 12, Conv, [256, 3, 1]]
23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24 | - [-1, 18, Conv, [512, 3, 1]]
25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
26 | - [-1, 6, Conv, [1024, 3, 1]]
27 | - [-1, 1, SPPF, [1024, 5]] # 9
28 |
29 | # YOLOv6-3.0s head
30 | head:
31 | - [-1, 1, Conv, [256, 1, 1]]
32 | - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
33 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
34 | - [-1, 1, Conv, [256, 3, 1]]
35 | - [-1, 9, Conv, [256, 3, 1]] # 14
36 |
37 | - [-1, 1, Conv, [128, 1, 1]]
38 | - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
39 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
40 | - [-1, 1, Conv, [128, 3, 1]]
41 | - [-1, 9, Conv, [128, 3, 1]] # 19
42 |
43 | - [-1, 1, Conv, [128, 3, 2]]
44 | - [[-1, 15], 1, Concat, [1]] # cat head P4
45 | - [-1, 1, Conv, [256, 3, 1]]
46 | - [-1, 9, Conv, [256, 3, 1]] # 23
47 |
48 | - [-1, 1, Conv, [256, 3, 2]]
49 | - [[-1, 10], 1, Concat, [1]] # cat head P5
50 | - [-1, 1, Conv, [512, 3, 1]]
51 | - [-1, 9, Conv, [512, 3, 1]] # 27
52 |
53 | - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
54 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8-cls.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
3 |
4 | # Parameters
5 | nc: 1000 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 1024]
11 | l: [1.00, 1.00, 1024]
12 | x: [1.00, 1.25, 1024]
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 |
27 | # YOLOv8.0n head
28 | head:
29 | - [-1, 1, Classify, [nc]] # Classify
30 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8-p2.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 768]
11 | l: [1.00, 1.00, 512]
12 | x: [1.00, 1.25, 512]
13 |
14 | # YOLOv8.0 backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0-p2 head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
39 | - [[-1, 2], 1, Concat, [1]] # cat backbone P2
40 | - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall)
41 |
42 | - [-1, 1, Conv, [128, 3, 2]]
43 | - [[-1, 15], 1, Concat, [1]] # cat head P3
44 | - [-1, 3, C2f, [256]] # 21 (P3/8-small)
45 |
46 | - [-1, 1, Conv, [256, 3, 2]]
47 | - [[-1, 12], 1, Concat, [1]] # cat head P4
48 | - [-1, 3, C2f, [512]] # 24 (P4/16-medium)
49 |
50 | - [-1, 1, Conv, [512, 3, 2]]
51 | - [[-1, 9], 1, Concat, [1]] # cat head P5
52 | - [-1, 3, C2f, [1024]] # 27 (P5/32-large)
53 |
54 | - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)
55 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8-p6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 768]
11 | l: [1.00, 1.00, 512]
12 | x: [1.00, 1.25, 512]
13 |
14 | # YOLOv8.0x6 backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [768, True]]
26 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
27 | - [-1, 3, C2f, [1024, True]]
28 | - [-1, 1, SPPF, [1024, 5]] # 11
29 |
30 | # YOLOv8.0x6 head
31 | head:
32 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
33 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5
34 | - [-1, 3, C2, [768, False]] # 14
35 |
36 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
37 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
38 | - [-1, 3, C2, [512, False]] # 17
39 |
40 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
41 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
42 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
43 |
44 | - [-1, 1, Conv, [256, 3, 2]]
45 | - [[-1, 17], 1, Concat, [1]] # cat head P4
46 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
47 |
48 | - [-1, 1, Conv, [512, 3, 2]]
49 | - [[-1, 14], 1, Concat, [1]] # cat head P5
50 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large)
51 |
52 | - [-1, 1, Conv, [768, 3, 2]]
53 | - [[-1, 11], 1, Concat, [1]] # cat head P6
54 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
55 |
56 | - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)
57 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
3 |
4 | # Parameters
5 | nc: 1 # number of classes
6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
7 | scales: # model compound scaling constants, i.e. 'model=yolov8n-pose-p6.yaml' will call yolov8-pose-p6.yaml with scale 'n'
8 | # [depth, width, max_channels]
9 | n: [0.33, 0.25, 1024]
10 | s: [0.33, 0.50, 1024]
11 | m: [0.67, 0.75, 768]
12 | l: [1.00, 1.00, 512]
13 | x: [1.00, 1.25, 512]
14 |
15 | # YOLOv8.0x6 backbone
16 | backbone:
17 | # [from, repeats, module, args]
18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20 | - [-1, 3, C2f, [128, True]]
21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22 | - [-1, 6, C2f, [256, True]]
23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24 | - [-1, 6, C2f, [512, True]]
25 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
26 | - [-1, 3, C2f, [768, True]]
27 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
28 | - [-1, 3, C2f, [1024, True]]
29 | - [-1, 1, SPPF, [1024, 5]] # 11
30 |
31 | # YOLOv8.0x6 head
32 | head:
33 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
34 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5
35 | - [-1, 3, C2, [768, False]] # 14
36 |
37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
38 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
39 | - [-1, 3, C2, [512, False]] # 17
40 |
41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
42 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
43 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
44 |
45 | - [-1, 1, Conv, [256, 3, 2]]
46 | - [[-1, 17], 1, Concat, [1]] # cat head P4
47 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
48 |
49 | - [-1, 1, Conv, [512, 3, 2]]
50 | - [[-1, 14], 1, Concat, [1]] # cat head P5
51 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large)
52 |
53 | - [-1, 1, Conv, [768, 3, 2]]
54 | - [[-1, 11], 1, Concat, [1]] # cat head P6
55 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
56 |
57 | - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6)
58 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8-pose.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
3 |
4 | # Parameters
5 | nc: 1 # number of classes
6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
7 | scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
8 | # [depth, width, max_channels]
9 | n: [0.33, 0.25, 1024]
10 | s: [0.33, 0.50, 1024]
11 | m: [0.67, 0.75, 768]
12 | l: [1.00, 1.00, 512]
13 | x: [1.00, 1.25, 512]
14 |
15 | # YOLOv8.0n backbone
16 | backbone:
17 | # [from, repeats, module, args]
18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20 | - [-1, 3, C2f, [128, True]]
21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22 | - [-1, 6, C2f, [256, True]]
23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24 | - [-1, 6, C2f, [512, True]]
25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
26 | - [-1, 3, C2f, [1024, True]]
27 | - [-1, 1, SPPF, [1024, 5]] # 9
28 |
29 | # YOLOv8.0n head
30 | head:
31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
32 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
33 | - [-1, 3, C2f, [512]] # 12
34 |
35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
38 |
39 | - [-1, 1, Conv, [256, 3, 2]]
40 | - [[-1, 12], 1, Concat, [1]] # cat head P4
41 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
42 |
43 | - [-1, 1, Conv, [512, 3, 2]]
44 | - [[-1, 9], 1, Concat, [1]] # cat head P5
45 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
46 |
47 | - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
48 |
--------------------------------------------------------------------------------
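The `kpt_shape: [17, 3]` parameter in the pose YAMLs above is forwarded straight to the `Pose` head (`Pose, [nc, kpt_shape]`). A short sketch of building the scaled model from this config (standard `ultralytics` API):

```python
from ultralytics import YOLO

model = YOLO('yolov8n-pose.yaml')  # resolves to yolov8-pose.yaml with scale 'n' applied
model.info()  # the Pose head predicts 17 keypoints, each as (x, y, visibility)
```
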
/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with an RT-DETR decoder head and P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0n head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, Conv, [256, 3, 2]]
39 | - [[-1, 12], 1, Concat, [1]] # cat head P4
40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41 |
42 | - [-1, 1, Conv, [512, 3, 2]]
43 | - [[-1, 9], 1, Concat, [1]] # cat head P5
44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45 |
46 | - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
47 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8-seg.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 768]
11 | l: [1.00, 1.00, 512]
12 | x: [1.00, 1.25, 512]
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0n head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, Conv, [256, 3, 2]]
39 | - [[-1, 12], 1, Concat, [1]] # cat head P4
40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41 |
42 | - [-1, 1, Conv, [512, 3, 2]]
43 | - [[-1, 9], 1, Concat, [1]] # cat head P5
44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45 |
46 | - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
47 |
--------------------------------------------------------------------------------
/ultralytics/cfg/models/v8/yolov8.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0n head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, Conv, [256, 3, 2]]
39 | - [[-1, 12], 1, Concat, [1]] # cat head P4
40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41 |
42 | - [-1, 1, Conv, [512, 3, 2]]
43 | - [[-1, 9], 1, Concat, [1]] # cat head P5
44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45 |
46 | - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
47 |
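The `scales` entries drive compound scaling: each module's repeat count is multiplied by `depth`, and each channel count by `width`, capped at `max_channels`. A minimal sketch of that arithmetic, mirroring (as an assumption) the rounding rules applied when `ultralytics/nn/tasks.py` parses this YAML; `scale_module` is an illustrative helper, not a library function:

```python
import math

def scale_module(repeats, channels, depth, width, max_channels):
    """Illustrative YOLOv8 compound scaling: repeats scale by depth, channels by width (snapped to multiples of 8)."""
    r = max(round(repeats * depth), 1) if repeats > 1 else repeats
    c = math.ceil(min(channels, max_channels) * width / 8) * 8
    return r, c

# 'n' scale [0.33, 0.25, 1024] applied to the backbone entry `- [-1, 6, C2f, [512, True]]`
print(scale_module(6, 512, depth=0.33, width=0.25, max_channels=1024))  # -> (2, 128)
```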
--------------------------------------------------------------------------------
/ultralytics/cfg/trackers/botsort.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT
3 |
4 | tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
5 | track_high_thresh: 0.5 # threshold for the first association
6 | track_low_thresh: 0.1 # threshold for the second association
7 | new_track_thresh: 0.6 # threshold to initialize a new track if the detection does not match any existing tracks
8 | track_buffer: 30 # number of frames to keep lost tracks before removal
9 | match_thresh: 0.8 # threshold for matching tracks
10 | # min_box_area: 10 # threshold for min box areas (for tracker evaluation, not used for now)
11 | # mot20: False # for tracker evaluation (not used for now)
12 |
13 | # BoT-SORT settings
14 | cmc_method: sparseOptFlow # method of global motion compensation
15 | # ReID model related thresh (not supported yet)
16 | proximity_thresh: 0.5
17 | appearance_thresh: 0.25
18 | with_reid: False
19 |
--------------------------------------------------------------------------------
/ultralytics/cfg/trackers/bytetrack.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
3 |
4 | tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
5 | track_high_thresh: 0.5 # threshold for the first association
6 | track_low_thresh: 0.1 # threshold for the second association
7 | new_track_thresh: 0.6 # threshold to initialize a new track if the detection does not match any existing tracks
8 | track_buffer: 30 # number of frames to keep lost tracks before removal
9 | match_thresh: 0.8 # threshold for matching tracks
10 | # min_box_area: 10 # threshold for min box areas (for tracker evaluation, not used for now)
11 | # mot20: False # for tracker evaluation (not used for now)
12 |
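Both tracker configs are consumed through the `tracker` argument of `model.track()`; pass a custom copy of either YAML to override the thresholds above. A usage sketch (the video path is a placeholder):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
# Track with ByteTrack settings; a custom YAML path can be given instead of 'bytetrack.yaml'
results = model.track(source='path/to/video.mp4', tracker='bytetrack.yaml', conf=0.3)
```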
--------------------------------------------------------------------------------
/ultralytics/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .base import BaseDataset
4 | from .build import build_dataloader, build_yolo_dataset, load_inference_source
5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
6 |
7 | __all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
8 | 'build_dataloader', 'load_inference_source')
9 |
--------------------------------------------------------------------------------
/ultralytics/data/annotator.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from pathlib import Path
4 |
5 | from ultralytics import SAM, YOLO
6 |
7 |
8 | def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
9 | """
10 | Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
11 | Args:
12 | data (str): Path to a folder containing images to be annotated.
13 | det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
14 | sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
15 | device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
16 |         output_dir (str | None, optional): Directory to save the annotated results.
17 | Defaults to a 'labels' folder in the same directory as 'data'.
18 | """
19 | det_model = YOLO(det_model)
20 | sam_model = SAM(sam_model)
21 |
22 | if not output_dir:
23 | output_dir = Path(str(data)).parent / 'labels'
24 | Path(output_dir).mkdir(exist_ok=True, parents=True)
25 |
26 | det_results = det_model(data, stream=True, device=device)
27 |
28 | for result in det_results:
29 | class_ids = result.boxes.cls.int().tolist() # noqa
30 | if len(class_ids):
31 | boxes = result.boxes.xyxy # Boxes object for bbox outputs
32 | sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
33 | segments = sam_results[0].masks.xyn # noqa
34 |
35 | with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f:
36 | for i in range(len(segments)):
37 | s = segments[i]
38 | if len(s) == 0:
39 | continue
40 | segment = map(str, segments[i].reshape(-1).tolist())
41 | f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
42 |
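A usage sketch for `auto_annotate` as defined above; it runs YOLOv8 detection, prompts SAM with the resulting boxes, and writes one YOLO-format polygon label file per image (paths are placeholders):

```python
from ultralytics.data.annotator import auto_annotate

auto_annotate(data='path/to/images', det_model='yolov8x.pt', sam_model='sam_b.pt', output_dir='path/to/labels')
```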
--------------------------------------------------------------------------------
/ultralytics/data/scripts/download_weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Ultralytics YOLO 🚀, AGPL-3.0 license
3 | # Download latest models from https://github.com/ultralytics/assets/releases
4 | # Example usage: bash ultralytics/data/scripts/download_weights.sh
5 | # parent
6 | # └── weights
7 | # ├── yolov8n.pt ← downloads here
8 | # ├── yolov8s.pt
9 | # └── ...
10 |
11 | python - <<EOF
12 | from ultralytics.utils.downloads import attempt_download_asset
13 |
14 | assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
15 | for x in assets:
16 |     attempt_download_asset(f'weights/{x}')
17 | EOF
18 |
--------------------------------------------------------------------------------
/ultralytics/hub/auth.py:
--------------------------------------------------------------------------------
68 |     def authenticate(self) -> bool:
69 |         """
70 | Attempt to authenticate with the server using either id_token or API key.
71 |
72 | Returns:
73 | bool: True if authentication is successful, False otherwise.
74 | """
75 | try:
76 | header = self.get_auth_header()
77 | if header:
78 | r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
79 | if not r.json().get('success', False):
80 | raise ConnectionError('Unable to authenticate.')
81 | return True
82 | raise ConnectionError('User has not authenticated locally.')
83 | except ConnectionError:
84 | self.id_token = self.api_key = False # reset invalid
85 | LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
86 | return False
87 |
88 | def auth_with_cookies(self) -> bool:
89 | """
90 | Attempt to fetch authentication via cookies and set id_token.
91 | User must be logged in to HUB and running in a supported browser.
92 |
93 | Returns:
94 | bool: True if authentication is successful, False otherwise.
95 | """
96 | if not is_colab():
97 | return False # Currently only works with Colab
98 | try:
99 | authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
100 | if authn.get('success', False):
101 | self.id_token = authn.get('data', {}).get('idToken', None)
102 | self.authenticate()
103 | return True
104 | raise ConnectionError('Unable to fetch browser authentication details.')
105 | except ConnectionError:
106 | self.id_token = False # reset invalid
107 | return False
108 |
109 | def get_auth_header(self):
110 | """
111 | Get the authentication header for making API requests.
112 |
113 | Returns:
114 | (dict): The authentication header if id_token or API key is set, None otherwise.
115 | """
116 | if self.id_token:
117 | return {'authorization': f'Bearer {self.id_token}'}
118 | elif self.api_key:
119 | return {'x-api-key': self.api_key}
120 | else:
121 | return None
122 |
123 | def get_state(self) -> bool:
124 | """
125 | Get the authentication state.
126 |
127 | Returns:
128 | bool: True if either id_token or API key is set, False otherwise.
129 | """
130 | return self.id_token or self.api_key
131 |
132 | def set_api_key(self, key: str):
133 | """
134 | Set the API key for authentication.
135 |
136 | Args:
137 | key (str): The API key string.
138 | """
139 | self.api_key = key
140 |
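`get_auth_header()` returns a dict that can be passed straight to `requests`, exactly as `authenticate()` does above. A minimal sketch, assuming an `Auth` instance already holds a valid API key and using an illustrative HUB endpoint:

```python
import requests

from ultralytics.hub.auth import Auth

auth = Auth()  # assumes an API key is already configured in settings
header = auth.get_auth_header()  # {'x-api-key': ...} or {'authorization': 'Bearer ...'}
if header:
    r = requests.post('https://api.ultralytics.com/v1/auth', headers=header)  # endpoint shown for illustration
    print(r.json().get('success', False))
```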
--------------------------------------------------------------------------------
/ultralytics/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .rtdetr import RTDETR
4 | from .sam import SAM
5 | from .yolo import YOLO
6 |
7 | __all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
8 |
--------------------------------------------------------------------------------
/ultralytics/models/fastsam/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import FastSAM
4 | from .predict import FastSAMPredictor
5 | from .prompt import FastSAMPrompt
6 | from .val import FastSAMValidator
7 |
8 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
9 |
--------------------------------------------------------------------------------
/ultralytics/models/fastsam/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from pathlib import Path
4 |
5 | from ultralytics.engine.model import Model
6 |
7 | from .predict import FastSAMPredictor
8 | from .val import FastSAMValidator
9 |
10 |
11 | class FastSAM(Model):
12 | """
13 | FastSAM model interface.
14 |
15 | Example:
16 | ```python
17 | from ultralytics import FastSAM
18 |
19 | model = FastSAM('last.pt')
20 | results = model.predict('ultralytics/assets/bus.jpg')
21 | ```
22 | """
23 |
24 | def __init__(self, model='FastSAM-x.pt'):
25 | """Call the __init__ method of the parent class (YOLO) with the updated default model"""
26 | if str(model) == 'FastSAM.pt':
27 | model = 'FastSAM-x.pt'
28 | assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
29 | super().__init__(model=model, task='segment')
30 |
31 | @property
32 | def task_map(self):
33 | return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
34 |
--------------------------------------------------------------------------------
/ultralytics/models/fastsam/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.engine.results import Results
6 | from ultralytics.models.fastsam.utils import bbox_iou
7 | from ultralytics.models.yolo.detect.predict import DetectionPredictor
8 | from ultralytics.utils import DEFAULT_CFG, ops
9 |
10 |
11 | class FastSAMPredictor(DetectionPredictor):
12 |
13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
14 | super().__init__(cfg, overrides, _callbacks)
15 | self.args.task = 'segment'
16 |
17 | def postprocess(self, preds, img, orig_imgs):
18 | """TODO: filter by classes."""
19 | p = ops.non_max_suppression(preds[0],
20 | self.args.conf,
21 | self.args.iou,
22 | agnostic=self.args.agnostic_nms,
23 | max_det=self.args.max_det,
24 | nc=len(self.model.names),
25 | classes=self.args.classes)
26 | full_box = torch.zeros_like(p[0][0])
27 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
28 | full_box = full_box.view(1, -1)
29 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
30 | if critical_iou_index.numel() != 0:
31 | full_box[0][4] = p[0][critical_iou_index][:, 4]
32 | full_box[0][6:] = p[0][critical_iou_index][:, 6:]
33 | p[0][critical_iou_index] = full_box
34 | results = []
35 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
36 | for i, pred in enumerate(p):
37 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
38 | path = self.batch[0]
39 | img_path = path[i] if isinstance(path, list) else path
40 | if not len(pred): # save empty boxes
41 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
42 | continue
43 | if self.args.retina_masks:
44 | if not isinstance(orig_imgs, torch.Tensor):
45 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
46 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
47 | else:
48 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
49 | if not isinstance(orig_imgs, torch.Tensor):
50 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
51 | results.append(
52 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
53 | return results
54 |
--------------------------------------------------------------------------------
/ultralytics/models/fastsam/utils.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 |
6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
7 | """
8 | Adjust bounding boxes to stick to image border if they are within a certain threshold.
9 |
10 | Args:
11 | boxes (torch.Tensor): (n, 4)
12 | image_shape (tuple): (height, width)
13 | threshold (int): pixel threshold
14 |
15 | Returns:
16 | adjusted_boxes (torch.Tensor): adjusted bounding boxes
17 | """
18 |
19 | # Image dimensions
20 | h, w = image_shape
21 |
22 | # Adjust boxes
23 | boxes[boxes[:, 0] < threshold, 0] = 0 # x1
24 | boxes[boxes[:, 1] < threshold, 1] = 0 # y1
25 | boxes[boxes[:, 2] > w - threshold, 2] = w # x2
26 | boxes[boxes[:, 3] > h - threshold, 3] = h # y2
27 | return boxes
28 |
29 |
30 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
31 | """
32 | Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
33 |
34 | Args:
35 | box1 (torch.Tensor): (4, )
36 | boxes (torch.Tensor): (n, 4)
37 | iou_thres (float): IoU threshold
38 | image_shape (tuple): (height, width)
39 | raw_output (bool): If True, return the raw IoU values instead of the indices
40 |
41 | Returns:
42 | high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
43 | """
44 | boxes = adjust_bboxes_to_image_border(boxes, image_shape)
45 | # obtain coordinates for intersections
46 | x1 = torch.max(box1[0], boxes[:, 0])
47 | y1 = torch.max(box1[1], boxes[:, 1])
48 | x2 = torch.min(box1[2], boxes[:, 2])
49 | y2 = torch.min(box1[3], boxes[:, 3])
50 |
51 | # compute the area of intersection
52 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
53 |
54 | # compute the area of both individual boxes
55 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
56 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
57 |
58 | # compute the area of union
59 | union = box1_area + box2_area - intersection
60 |
61 | # compute the IoU
62 | iou = intersection / union # Should be shape (n, )
63 | if raw_output:
64 | return 0 if iou.numel() == 0 else iou
65 |
66 | # return indices of boxes with IoU > thres
67 | return torch.nonzero(iou > iou_thres).flatten()
68 |
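A worked toy example of `bbox_iou` above. With a 640x640 image neither box touches the 20 px border, the first candidate overlaps `box1` exactly (IoU 1.0), and the second only partially (IoU = 2500 / 17500 ≈ 0.14), so only index 0 clears the 0.9 threshold:

```python
import torch

from ultralytics.models.fastsam.utils import bbox_iou

box1 = torch.tensor([100., 100., 200., 200.])
boxes = torch.tensor([[100., 100., 200., 200.],
                      [150., 150., 250., 250.]])
print(bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640)))  # tensor([0])
```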
--------------------------------------------------------------------------------
/ultralytics/models/fastsam/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.models.yolo.segment import SegmentationValidator
4 | from ultralytics.utils.metrics import SegmentMetrics
5 |
6 |
7 | class FastSAMValidator(SegmentationValidator):
8 |
9 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
10 | """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
11 | super().__init__(dataloader, save_dir, pbar, args, _callbacks)
12 | self.args.task = 'segment'
13 | self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
14 | self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
15 |
--------------------------------------------------------------------------------
/ultralytics/models/nas/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import NAS
4 | from .predict import NASPredictor
5 | from .val import NASValidator
6 |
7 | __all__ = 'NASPredictor', 'NASValidator', 'NAS'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/nas/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | YOLO-NAS model interface.
4 |
5 | Example:
6 | ```python
7 | from ultralytics import NAS
8 |
9 | model = NAS('yolo_nas_s')
10 | results = model.predict('ultralytics/assets/bus.jpg')
11 | ```
12 | """
13 |
14 | from pathlib import Path
15 |
16 | import torch
17 |
18 | from ultralytics.engine.model import Model
19 | from ultralytics.utils.torch_utils import model_info, smart_inference_mode
20 |
21 | from .predict import NASPredictor
22 | from .val import NASValidator
23 |
24 |
25 | class NAS(Model):
26 |
27 | def __init__(self, model='yolo_nas_s.pt') -> None:
28 | assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
29 | super().__init__(model, task='detect')
30 |
31 | @smart_inference_mode()
32 | def _load(self, weights: str, task: str):
33 | # Load or create new NAS model
34 | import super_gradients
35 | suffix = Path(weights).suffix
36 | if suffix == '.pt':
37 | self.model = torch.load(weights)
38 | elif suffix == '':
39 | self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
40 | # Standardize model
41 | self.model.fuse = lambda verbose=True: self.model
42 | self.model.stride = torch.tensor([32])
43 | self.model.names = dict(enumerate(self.model._class_names))
44 | self.model.is_fused = lambda: False # for info()
45 | self.model.yaml = {} # for info()
46 | self.model.pt_path = weights # for export()
47 | self.model.task = 'detect' # for export()
48 |
49 | def info(self, detailed=False, verbose=True):
50 | """
51 | Logs model info.
52 |
53 | Args:
54 | detailed (bool): Show detailed information about model.
55 | verbose (bool): Controls verbosity.
56 | """
57 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
58 |
59 | @property
60 | def task_map(self):
61 | return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
62 |
--------------------------------------------------------------------------------
/ultralytics/models/nas/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.engine.predictor import BasePredictor
6 | from ultralytics.engine.results import Results
7 | from ultralytics.utils import ops
8 | from ultralytics.utils.ops import xyxy2xywh
9 |
10 |
11 | class NASPredictor(BasePredictor):
12 |
13 | def postprocess(self, preds_in, img, orig_imgs):
14 | """Postprocess predictions and returns a list of Results objects."""
15 |
16 | # Cat boxes and class scores
17 | boxes = xyxy2xywh(preds_in[0][0])
18 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
19 |
20 | preds = ops.non_max_suppression(preds,
21 | self.args.conf,
22 | self.args.iou,
23 | agnostic=self.args.agnostic_nms,
24 | max_det=self.args.max_det,
25 | classes=self.args.classes)
26 |
27 | results = []
28 | for i, pred in enumerate(preds):
29 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
30 | if not isinstance(orig_imgs, torch.Tensor):
31 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
32 | path = self.batch[0]
33 | img_path = path[i] if isinstance(path, list) else path
34 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
35 | return results
36 |
--------------------------------------------------------------------------------
/ultralytics/models/nas/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.models.yolo.detect import DetectionValidator
6 | from ultralytics.utils import ops
7 | from ultralytics.utils.ops import xyxy2xywh
8 |
9 | __all__ = ['NASValidator']
10 |
11 |
12 | class NASValidator(DetectionValidator):
13 |
14 | def postprocess(self, preds_in):
15 | """Apply Non-maximum suppression to prediction outputs."""
16 | boxes = xyxy2xywh(preds_in[0][0])
17 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
18 | return ops.non_max_suppression(preds,
19 | self.args.conf,
20 | self.args.iou,
21 | labels=self.lb,
22 | multi_label=False,
23 | agnostic=self.args.single_cls,
24 | max_det=self.args.max_det,
25 | max_time_img=0.5)
26 |
--------------------------------------------------------------------------------
/ultralytics/models/rtdetr/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import RTDETR
4 | from .predict import RTDETRPredictor
5 | from .val import RTDETRValidator
6 |
7 | __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/rtdetr/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | RT-DETR model interface
4 | """
5 | from ultralytics.engine.model import Model
6 | from ultralytics.nn.tasks import RTDETRDetectionModel
7 |
8 | from .predict import RTDETRPredictor
9 | from .train import RTDETRTrainer
10 | from .val import RTDETRValidator
11 |
12 |
13 | class RTDETR(Model):
14 | """
15 | RTDETR model interface.
16 | """
17 |
18 | def __init__(self, model='rtdetr-l.pt') -> None:
19 | if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
20 | raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
21 | super().__init__(model=model, task='detect')
22 |
23 | @property
24 | def task_map(self):
25 | return {
26 | 'detect': {
27 | 'predictor': RTDETRPredictor,
28 | 'validator': RTDETRValidator,
29 | 'trainer': RTDETRTrainer,
30 | 'model': RTDETRDetectionModel}}
31 |
--------------------------------------------------------------------------------
/ultralytics/models/rtdetr/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.data.augment import LetterBox
6 | from ultralytics.engine.predictor import BasePredictor
7 | from ultralytics.engine.results import Results
8 | from ultralytics.utils import ops
9 |
10 |
11 | class RTDETRPredictor(BasePredictor):
12 |
13 | def postprocess(self, preds, img, orig_imgs):
14 | """Postprocess predictions and returns a list of Results objects."""
15 | nd = preds[0].shape[-1]
16 | bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
17 | results = []
18 | for i, bbox in enumerate(bboxes): # (300, 4)
19 | bbox = ops.xywh2xyxy(bbox)
20 | score, cls = scores[i].max(-1, keepdim=True) # (300, 1)
21 | idx = score.squeeze(-1) > self.args.conf # (300, )
22 | if self.args.classes is not None:
23 | idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
24 | pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter
25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
26 | oh, ow = orig_img.shape[:2]
27 | if not isinstance(orig_imgs, torch.Tensor):
28 | pred[..., [0, 2]] *= ow
29 | pred[..., [1, 3]] *= oh
30 | path = self.batch[0]
31 | img_path = path[i] if isinstance(path, list) else path
32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
33 | return results
34 |
35 | def pre_transform(self, im):
36 | """Pre-transform input image before inference.
37 |
38 | Args:
39 |             im (List[np.ndarray]): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
40 |
41 |         Returns: A list of transformed images.
42 | """
43 |         # The input size must be square (640) and stretched with scaleFill (no letterbox padding).
44 | return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]
45 |
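Note that the decoder emits normalized `xywh` boxes in [0, 1], which is why `postprocess` multiplies by the original width/height rather than calling `scale_boxes`. A small sketch of that conversion with made-up values:

```python
import torch

from ultralytics.utils import ops

bbox = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized (cx, cy, w, h) from the decoder
xyxy = ops.xywh2xyxy(bbox)                   # still normalized: [[0.4, 0.3, 0.6, 0.7]]
oh, ow = 480, 640                            # original image height, width
xyxy[..., [0, 2]] *= ow
xyxy[..., [1, 3]] *= oh
print(xyxy)  # tensor([[256., 144., 384., 336.]])
```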
--------------------------------------------------------------------------------
/ultralytics/models/rtdetr/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from copy import copy
4 |
5 | import torch
6 |
7 | from ultralytics.models.yolo.detect import DetectionTrainer
8 | from ultralytics.nn.tasks import RTDETRDetectionModel
9 | from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
10 |
11 | from .val import RTDETRDataset, RTDETRValidator
12 |
13 |
14 | class RTDETRTrainer(DetectionTrainer):
15 |
16 | def get_model(self, cfg=None, weights=None, verbose=True):
17 | """Return a YOLO detection model."""
18 | model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
19 | if weights:
20 | model.load(weights)
21 | return model
22 |
23 | def build_dataset(self, img_path, mode='val', batch=None):
24 | """Build RTDETR Dataset
25 |
26 | Args:
27 | img_path (str): Path to the folder containing images.
28 |             mode (str): `train` or `val` mode; users can customize different augmentations for each mode.
29 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
30 | """
31 | return RTDETRDataset(
32 | img_path=img_path,
33 | imgsz=self.args.imgsz,
34 | batch_size=batch,
35 |             augment=mode == 'train',  # augment only in 'train' mode
36 | hyp=self.args,
37 | rect=False, # no rect
38 | cache=self.args.cache or None,
39 | prefix=colorstr(f'{mode}: '),
40 | data=self.data)
41 |
42 | def get_validator(self):
43 | """Returns a DetectionValidator for RTDETR model validation."""
44 | self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
45 | return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
46 |
47 | def preprocess_batch(self, batch):
48 | """Preprocesses a batch of images by scaling and converting to float."""
49 | batch = super().preprocess_batch(batch)
50 | bs = len(batch['img'])
51 | batch_idx = batch['batch_idx']
52 | gt_bbox, gt_class = [], []
53 | for i in range(bs):
54 | gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
55 | gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
56 | return batch
57 |
58 |
59 | def train(cfg=DEFAULT_CFG, use_python=False):
60 | """Train and optimize RTDETR model given training data and device."""
61 | model = 'rtdetr-l.yaml'
62 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
63 | device = cfg.device if cfg.device is not None else ''
64 |
65 | # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
66 | # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
67 | args = dict(model=model,
68 | data=data,
69 | device=device,
70 | imgsz=640,
71 | exist_ok=True,
72 | batch=4,
73 | deterministic=False,
74 | amp=False)
75 | trainer = RTDETRTrainer(overrides=args)
76 | trainer.train()
77 |
78 |
79 | if __name__ == '__main__':
80 | train()
81 |
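The caveats in the NOTE comments (no `deterministic=True`, no AMP) also apply when training through the high-level API. A hedged equivalent of the `train()` helper above using the `RTDETR` class:

```python
from ultralytics import RTDETR

model = RTDETR('rtdetr-l.yaml')
# F.grid_sample used by RT-DETR does not support deterministic mode, and AMP can yield NaNs here
model.train(data='coco128.yaml', imgsz=640, batch=4, deterministic=False, amp=False)
```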
--------------------------------------------------------------------------------
/ultralytics/models/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import SAM
4 | from .predict import Predictor
5 |
6 | # from .build import build_sam
7 |
8 | __all__ = 'SAM', 'Predictor' # tuple or list
9 |
--------------------------------------------------------------------------------
/ultralytics/models/sam/build.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from functools import partial
10 |
11 | import torch
12 |
13 | from ultralytics.utils.downloads import attempt_download_asset
14 |
15 | from .modules.decoders import MaskDecoder
16 | from .modules.encoders import ImageEncoderViT, PromptEncoder
17 | from .modules.sam import Sam
18 | from .modules.tiny_encoder import TinyViT
19 | from .modules.transformer import TwoWayTransformer
20 |
21 |
22 | def build_sam_vit_h(checkpoint=None):
23 | """Build and return a Segment Anything Model (SAM) h-size model."""
24 | return _build_sam(
25 | encoder_embed_dim=1280,
26 | encoder_depth=32,
27 | encoder_num_heads=16,
28 | encoder_global_attn_indexes=[7, 15, 23, 31],
29 | checkpoint=checkpoint,
30 | )
31 |
32 |
33 | def build_sam_vit_l(checkpoint=None):
34 | """Build and return a Segment Anything Model (SAM) l-size model."""
35 | return _build_sam(
36 | encoder_embed_dim=1024,
37 | encoder_depth=24,
38 | encoder_num_heads=16,
39 | encoder_global_attn_indexes=[5, 11, 17, 23],
40 | checkpoint=checkpoint,
41 | )
42 |
43 |
44 | def build_sam_vit_b(checkpoint=None):
45 | """Build and return a Segment Anything Model (SAM) b-size model."""
46 | return _build_sam(
47 | encoder_embed_dim=768,
48 | encoder_depth=12,
49 | encoder_num_heads=12,
50 | encoder_global_attn_indexes=[2, 5, 8, 11],
51 | checkpoint=checkpoint,
52 | )
53 |
54 |
55 | def build_mobile_sam(checkpoint=None):
56 | """Build and return Mobile Segment Anything Model (Mobile-SAM)."""
57 | return _build_sam(
58 | encoder_embed_dim=[64, 128, 160, 320],
59 | encoder_depth=[2, 2, 6, 2],
60 | encoder_num_heads=[2, 4, 5, 10],
61 | encoder_global_attn_indexes=None,
62 | mobile_sam=True,
63 | checkpoint=checkpoint,
64 | )
65 |
66 |
67 | def _build_sam(encoder_embed_dim,
68 | encoder_depth,
69 | encoder_num_heads,
70 | encoder_global_attn_indexes,
71 | checkpoint=None,
72 | mobile_sam=False):
73 | """Builds the selected SAM model architecture."""
74 | prompt_embed_dim = 256
75 | image_size = 1024
76 | vit_patch_size = 16
77 | image_embedding_size = image_size // vit_patch_size
78 | image_encoder = (TinyViT(
79 | img_size=1024,
80 | in_chans=3,
81 | num_classes=1000,
82 | embed_dims=encoder_embed_dim,
83 | depths=encoder_depth,
84 | num_heads=encoder_num_heads,
85 | window_sizes=[7, 7, 14, 7],
86 | mlp_ratio=4.0,
87 | drop_rate=0.0,
88 | drop_path_rate=0.0,
89 | use_checkpoint=False,
90 | mbconv_expand_ratio=4.0,
91 | local_conv_size=3,
92 | layer_lr_decay=0.8,
93 | ) if mobile_sam else ImageEncoderViT(
94 | depth=encoder_depth,
95 | embed_dim=encoder_embed_dim,
96 | img_size=image_size,
97 | mlp_ratio=4,
98 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
99 | num_heads=encoder_num_heads,
100 | patch_size=vit_patch_size,
101 | qkv_bias=True,
102 | use_rel_pos=True,
103 | global_attn_indexes=encoder_global_attn_indexes,
104 | window_size=14,
105 | out_chans=prompt_embed_dim,
106 | ))
107 | sam = Sam(
108 | image_encoder=image_encoder,
109 | prompt_encoder=PromptEncoder(
110 | embed_dim=prompt_embed_dim,
111 | image_embedding_size=(image_embedding_size, image_embedding_size),
112 | input_image_size=(image_size, image_size),
113 | mask_in_chans=16,
114 | ),
115 | mask_decoder=MaskDecoder(
116 | num_multimask_outputs=3,
117 | transformer=TwoWayTransformer(
118 | depth=2,
119 | embedding_dim=prompt_embed_dim,
120 | mlp_dim=2048,
121 | num_heads=8,
122 | ),
123 | transformer_dim=prompt_embed_dim,
124 | iou_head_depth=3,
125 | iou_head_hidden_dim=256,
126 | ),
127 | pixel_mean=[123.675, 116.28, 103.53],
128 | pixel_std=[58.395, 57.12, 57.375],
129 | )
130 | if checkpoint is not None:
131 | checkpoint = attempt_download_asset(checkpoint)
132 | with open(checkpoint, 'rb') as f:
133 | state_dict = torch.load(f)
134 | sam.load_state_dict(state_dict)
135 | sam.eval()
136 | # sam.load_state_dict(torch.load(checkpoint), strict=True)
137 | # sam.eval()
138 | return sam
139 |
140 |
141 | sam_model_map = {
142 | 'sam_h.pt': build_sam_vit_h,
143 | 'sam_l.pt': build_sam_vit_l,
144 | 'sam_b.pt': build_sam_vit_b,
145 | 'mobile_sam.pt': build_mobile_sam, }
146 |
147 |
148 | def build_sam(ckpt='sam_b.pt'):
149 | """Build a SAM model specified by ckpt."""
150 | model_builder = None
151 | for k in sam_model_map.keys():
152 | if ckpt.endswith(k):
153 | model_builder = sam_model_map.get(k)
154 |
155 | if not model_builder:
156 |         raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
157 |
158 | return model_builder(ckpt)
159 |
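`build_sam` dispatches on the checkpoint filename suffix via `sam_model_map`, so only the four listed names resolve; anything else raises `FileNotFoundError`. A usage sketch (the weights are fetched by `attempt_download_asset` if not present locally):

```python
from ultralytics.models.sam.build import build_sam

sam = build_sam('sam_b.pt')  # ViT-B image encoder with 256-dim prompt embeddings
print(sum(p.numel() for p in sam.parameters()))  # rough parameter count
```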
--------------------------------------------------------------------------------
/ultralytics/models/sam/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | SAM model interface
4 | """
5 |
6 | from pathlib import Path
7 |
8 | from ultralytics.engine.model import Model
9 | from ultralytics.utils.torch_utils import model_info
10 |
11 | from .build import build_sam
12 | from .predict import Predictor
13 |
14 |
15 | class SAM(Model):
16 | """
17 | SAM model interface.
18 | """
19 |
20 | def __init__(self, model='sam_b.pt') -> None:
21 | if model and Path(model).suffix not in ('.pt', '.pth'):
22 | raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
23 | super().__init__(model=model, task='segment')
24 |
25 | def _load(self, weights: str, task=None):
26 | self.model = build_sam(weights)
27 |
28 | def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
29 | """Predicts and returns segmentation masks for given image or video source."""
30 | overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
31 | kwargs.update(overrides)
32 | prompts = dict(bboxes=bboxes, points=points, labels=labels)
33 | return super().predict(source, stream, prompts=prompts, **kwargs)
34 |
35 | def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
36 | """Calls the 'predict' function with given arguments to perform object detection."""
37 | return self.predict(source, stream, bboxes, points, labels, **kwargs)
38 |
39 | def info(self, detailed=False, verbose=True):
40 | """
41 | Logs model info.
42 |
43 | Args:
44 | detailed (bool): Show detailed information about model.
45 | verbose (bool): Controls verbosity.
46 | """
47 | return model_info(self.model, detailed=detailed, verbose=verbose)
48 |
49 | @property
50 | def task_map(self):
51 | return {'segment': {'predictor': Predictor}}
52 |
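A prompt-driven usage sketch for the `SAM` interface above; `bboxes`, `points`, and `labels` are forwarded to the predictor as prompts (the coordinates are illustrative):

```python
from ultralytics import SAM

model = SAM('sam_b.pt')
model.info()  # log model summary
results = model('ultralytics/assets/bus.jpg', bboxes=[100, 100, 400, 500])    # box prompt
results = model('ultralytics/assets/bus.jpg', points=[300, 300], labels=[1])  # point prompt
```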
--------------------------------------------------------------------------------
/ultralytics/models/sam/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
--------------------------------------------------------------------------------
/ultralytics/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.models.yolo import classify, detect, pose, segment
4 |
5 | from .model import YOLO
6 |
7 | __all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/classify/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
4 | from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
5 | from ultralytics.models.yolo.classify.val import ClassificationValidator, val
6 |
7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/classify/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.engine.predictor import BasePredictor
6 | from ultralytics.engine.results import Results
7 | from ultralytics.utils import DEFAULT_CFG, ROOT
8 |
9 |
10 | class ClassificationPredictor(BasePredictor):
11 |
12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
13 | super().__init__(cfg, overrides, _callbacks)
14 | self.args.task = 'classify'
15 |
16 | def preprocess(self, img):
17 | """Converts input image to model-compatible data type."""
18 | if not isinstance(img, torch.Tensor):
19 | img = torch.stack([self.transforms(im) for im in img], dim=0)
20 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
21 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
22 |
23 | def postprocess(self, preds, img, orig_imgs):
24 | """Post-processes predictions to return Results objects."""
25 | results = []
26 | for i, pred in enumerate(preds):
27 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
28 | path = self.batch[0]
29 | img_path = path[i] if isinstance(path, list) else path
30 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
31 |
32 | return results
33 |
34 |
35 | def predict(cfg=DEFAULT_CFG, use_python=False):
36 | """Run YOLO model predictions on input images/videos."""
37 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
38 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
39 | else 'https://ultralytics.com/images/bus.jpg'
40 |
41 | args = dict(model=model, source=source)
42 | if use_python:
43 | from ultralytics import YOLO
44 | YOLO(model)(**args)
45 | else:
46 | predictor = ClassificationPredictor(overrides=args)
47 | predictor.predict_cli()
48 |
49 |
50 | if __name__ == '__main__':
51 | predict()
52 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/classify/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.data import ClassificationDataset, build_dataloader
6 | from ultralytics.engine.validator import BaseValidator
7 | from ultralytics.utils import DEFAULT_CFG, LOGGER
8 | from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
9 | from ultralytics.utils.plotting import plot_images
10 |
11 |
12 | class ClassificationValidator(BaseValidator):
13 |
14 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
15 | """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
16 | super().__init__(dataloader, save_dir, pbar, args, _callbacks)
17 | self.targets = None
18 | self.pred = None
19 | self.args.task = 'classify'
20 | self.metrics = ClassifyMetrics()
21 |
22 | def get_desc(self):
23 | """Returns a formatted string summarizing classification metrics."""
24 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
25 |
26 | def init_metrics(self, model):
27 | """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
28 | self.names = model.names
29 | self.nc = len(model.names)
30 | self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
31 | self.pred = []
32 | self.targets = []
33 |
34 | def preprocess(self, batch):
35 | """Preprocesses input batch and returns it."""
36 | batch['img'] = batch['img'].to(self.device, non_blocking=True)
37 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
38 | batch['cls'] = batch['cls'].to(self.device)
39 | return batch
40 |
41 | def update_metrics(self, preds, batch):
42 | """Updates running metrics with model predictions and batch targets."""
43 | n5 = min(len(self.model.names), 5)
44 | self.pred.append(preds.argsort(1, descending=True)[:, :n5])
45 | self.targets.append(batch['cls'])
46 |
47 | def finalize_metrics(self, *args, **kwargs):
48 | """Finalizes metrics of the model such as confusion_matrix and speed."""
49 | self.confusion_matrix.process_cls_preds(self.pred, self.targets)
50 | if self.args.plots:
51 | for normalize in True, False:
52 | self.confusion_matrix.plot(save_dir=self.save_dir,
53 | names=self.names.values(),
54 | normalize=normalize,
55 | on_plot=self.on_plot)
56 | self.metrics.speed = self.speed
57 | self.metrics.confusion_matrix = self.confusion_matrix
58 |
59 | def get_stats(self):
60 | """Returns a dictionary of metrics obtained by processing targets and predictions."""
61 | self.metrics.process(self.targets, self.pred)
62 | return self.metrics.results_dict
63 |
64 | def build_dataset(self, img_path):
65 | return ClassificationDataset(root=img_path, args=self.args, augment=False)
66 |
67 | def get_dataloader(self, dataset_path, batch_size):
68 | """Builds and returns a data loader for classification tasks with given parameters."""
69 | dataset = self.build_dataset(dataset_path)
70 | return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
71 |
72 | def print_results(self):
73 | """Prints evaluation metrics for YOLO object detection model."""
74 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
75 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
76 |
77 | def plot_val_samples(self, batch, ni):
78 | """Plot validation image samples."""
79 | plot_images(
80 | images=batch['img'],
81 | batch_idx=torch.arange(len(batch['img'])),
82 | cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
83 | fname=self.save_dir / f'val_batch{ni}_labels.jpg',
84 | names=self.names,
85 | on_plot=self.on_plot)
86 |
87 | def plot_predictions(self, batch, preds, ni):
88 | """Plots predicted bounding boxes on input images and saves the result."""
89 | plot_images(batch['img'],
90 | batch_idx=torch.arange(len(batch['img'])),
91 | cls=torch.argmax(preds, dim=1),
92 | fname=self.save_dir / f'val_batch{ni}_pred.jpg',
93 | names=self.names,
94 | on_plot=self.on_plot) # pred
95 |
96 |
97 | def val(cfg=DEFAULT_CFG, use_python=False):
98 | """Validate YOLO model using custom data."""
99 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
100 | data = cfg.data or 'mnist160'
101 |
102 | args = dict(model=model, data=data)
103 | if use_python:
104 | from ultralytics import YOLO
105 | YOLO(model).val(**args)
106 | else:
107 | validator = ClassificationValidator(args=args)
108 | validator(model=args['model'])
109 |
110 |
111 | if __name__ == '__main__':
112 | val()
113 |
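`update_metrics` stores each image's top-5 predicted class indices and `ClassifyMetrics.process` reduces them to top-1/top-5 accuracy. A self-contained sketch of that reduction (tensors invented for illustration):

```python
import torch

preds = torch.tensor([[0, 3, 1, 4, 2],   # per-image class indices sorted by descending confidence
                      [2, 0, 1, 3, 4]])
targets = torch.tensor([0, 1])
correct = preds == targets[:, None]      # (n_images, 5) boolean hits
top1 = correct[:, 0].float().mean()      # only image 0 is right at rank 0 -> 0.5
top5 = correct.any(1).float().mean()     # both targets appear somewhere in the top-5 -> 1.0
print(top1.item(), top5.item())          # 0.5 1.0
```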
--------------------------------------------------------------------------------
/ultralytics/models/yolo/detect/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .predict import DetectionPredictor, predict
4 | from .train import DetectionTrainer, train
5 | from .val import DetectionValidator, val
6 |
7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/detect/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.engine.predictor import BasePredictor
6 | from ultralytics.engine.results import Results
7 | from ultralytics.utils import DEFAULT_CFG, ROOT, ops
8 |
9 |
10 | class DetectionPredictor(BasePredictor):
11 |
12 | def postprocess(self, preds, img, orig_imgs):
13 | """Post-processes predictions and returns a list of Results objects."""
14 | preds = ops.non_max_suppression(preds,
15 | self.args.conf,
16 | self.args.iou,
17 | agnostic=self.args.agnostic_nms,
18 | max_det=self.args.max_det,
19 | classes=self.args.classes)
20 |
21 | results = []
22 | for i, pred in enumerate(preds):
23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
24 | if not isinstance(orig_imgs, torch.Tensor):
25 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
26 | path = self.batch[0]
27 | img_path = path[i] if isinstance(path, list) else path
28 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
29 | return results
30 |
31 |
32 | def predict(cfg=DEFAULT_CFG, use_python=False):
33 | """Runs YOLO model inference on input image(s)."""
34 | model = cfg.model or 'yolov8n.pt'
35 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
36 | else 'https://ultralytics.com/images/bus.jpg'
37 |
38 | args = dict(model=model, source=source)
39 | if use_python:
40 | from ultralytics import YOLO
41 | YOLO(model)(**args)
42 | else:
43 | predictor = DetectionPredictor(overrides=args)
44 | predictor.predict_cli()
45 |
46 |
47 | if __name__ == '__main__':
48 | predict()
49 |
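All the NMS knobs surfaced in `postprocess` (`conf`, `iou`, `agnostic_nms`, `max_det`, `classes`) can be set through the high-level predict call. A usage sketch:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
# Keep only 'person' (class 0) detections above 0.5 confidence, class-agnostic NMS at IoU 0.7
results = model.predict('https://ultralytics.com/images/bus.jpg',
                        conf=0.5, iou=0.7, agnostic_nms=True, max_det=100, classes=[0])
print(len(results[0].boxes))
```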
--------------------------------------------------------------------------------
/ultralytics/models/yolo/detect/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from copy import copy
4 |
5 | import numpy as np
6 |
7 | from ultralytics.data import build_dataloader, build_yolo_dataset
8 | from ultralytics.engine.trainer import BaseTrainer
9 | from ultralytics.models import yolo
10 | from ultralytics.nn.tasks import DetectionModel
11 | from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
12 | from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
13 | from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
14 |
15 |
16 | class DetectionTrainer(BaseTrainer):
17 |
18 | def build_dataset(self, img_path, mode='train', batch=None):
19 | """
20 | Build YOLO Dataset.
21 |
22 | Args:
23 | img_path (str): Path to the folder containing images.
24 |             mode (str): `train` or `val` mode; users can customize different augmentations for each mode.
25 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
26 | """
27 | gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
28 | return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
29 |
30 | def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
31 | """Construct and return dataloader."""
32 | assert mode in ['train', 'val']
33 | with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
34 | dataset = self.build_dataset(dataset_path, mode, batch_size)
35 | shuffle = mode == 'train'
36 | if getattr(dataset, 'rect', False) and shuffle:
37 | LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
38 | shuffle = False
39 | workers = self.args.workers if mode == 'train' else self.args.workers * 2
40 | return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
41 |
42 | def preprocess_batch(self, batch):
43 | """Preprocesses a batch of images by scaling and converting to float."""
44 | batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
45 | return batch
46 |
47 | def set_model_attributes(self):
48 | """nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
49 | # self.args.box *= 3 / nl # scale to layers
50 | # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
51 | # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
52 | self.model.nc = self.data['nc'] # attach number of classes to model
53 | self.model.names = self.data['names'] # attach class names to model
54 | self.model.args = self.args # attach hyperparameters to model
55 | # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
56 |
57 | def get_model(self, cfg=None, weights=None, verbose=True):
58 | """Return a YOLO detection model."""
59 | model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
60 | if weights:
61 | model.load(weights)
62 | return model
63 |
64 | def get_validator(self):
65 | """Returns a DetectionValidator for YOLO model validation."""
66 | self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
67 | return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
68 |
69 | def label_loss_items(self, loss_items=None, prefix='train'):
70 | """
71 | Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
72 | segmentation & detection
73 | """
74 | keys = [f'{prefix}/{x}' for x in self.loss_names]
75 | if loss_items is not None:
76 | loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
77 | return dict(zip(keys, loss_items))
78 | else:
79 | return keys
80 |
81 | def progress_string(self):
82 | """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
83 | return ('\n' + '%11s' *
84 | (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
85 |
86 | def plot_training_samples(self, batch, ni):
87 | """Plots training samples with their annotations."""
88 | plot_images(images=batch['img'],
89 | batch_idx=batch['batch_idx'],
90 | cls=batch['cls'].squeeze(-1),
91 | bboxes=batch['bboxes'],
92 | paths=batch['im_file'],
93 | fname=self.save_dir / f'train_batch{ni}.jpg',
94 | on_plot=self.on_plot)
95 |
96 | def plot_metrics(self):
97 | """Plots metrics from a CSV file."""
98 | plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
99 |
100 | def plot_training_labels(self):
101 | """Create a labeled training plot of the YOLO model."""
102 | boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
103 | cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
104 | plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
105 |
106 |
107 | def train(cfg=DEFAULT_CFG, use_python=False):
108 | """Train and optimize YOLO model given training data and device."""
109 | model = cfg.model or 'yolov8n.pt'
110 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
111 | device = cfg.device if cfg.device is not None else ''
112 |
113 | args = dict(model=model, data=data, device=device)
114 | if use_python:
115 | from ultralytics import YOLO
116 | YOLO(model).train(**args)
117 | else:
118 | trainer = DetectionTrainer(overrides=args)
119 | trainer.train()
120 |
121 |
122 | if __name__ == '__main__':
123 | train()
124 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.engine.model import Model
4 | from ultralytics.models import yolo # noqa
5 | from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
6 |
7 |
8 | class YOLO(Model):
9 | """
10 | YOLO (You Only Look Once) object detection model.
11 | """
12 |
13 | @property
14 | def task_map(self):
15 | """Map head to model, trainer, validator, and predictor classes"""
16 | return {
17 | 'classify': {
18 | 'model': ClassificationModel,
19 | 'trainer': yolo.classify.ClassificationTrainer,
20 | 'validator': yolo.classify.ClassificationValidator,
21 | 'predictor': yolo.classify.ClassificationPredictor, },
22 | 'detect': {
23 | 'model': DetectionModel,
24 | 'trainer': yolo.detect.DetectionTrainer,
25 | 'validator': yolo.detect.DetectionValidator,
26 | 'predictor': yolo.detect.DetectionPredictor, },
27 | 'segment': {
28 | 'model': SegmentationModel,
29 | 'trainer': yolo.segment.SegmentationTrainer,
30 | 'validator': yolo.segment.SegmentationValidator,
31 | 'predictor': yolo.segment.SegmentationPredictor, },
32 | 'pose': {
33 | 'model': PoseModel,
34 | 'trainer': yolo.pose.PoseTrainer,
35 | 'validator': yolo.pose.PoseValidator,
36 | 'predictor': yolo.pose.PosePredictor, }, }
37 |
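`task_map` is what lets one `YOLO` class serve four tasks: the task is inferred from the checkpoint (or YAML), and the matching model/trainer/validator/predictor classes are looked up from this table. For example:

```python
from ultralytics import YOLO

model = YOLO('yolov8n-seg.pt')  # task inferred as 'segment' from the checkpoint
print(model.task)               # 'segment' -> predictions route through SegmentationPredictor
```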
--------------------------------------------------------------------------------
/ultralytics/models/yolo/pose/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .predict import PosePredictor, predict
4 | from .train import PoseTrainer, train
5 | from .val import PoseValidator, val
6 |
7 | __all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/pose/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.engine.results import Results
4 | from ultralytics.models.yolo.detect.predict import DetectionPredictor
5 | from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, ops
6 |
7 |
8 | class PosePredictor(DetectionPredictor):
9 |
10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
11 | super().__init__(cfg, overrides, _callbacks)
12 | self.args.task = 'pose'
13 | if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
14 | LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
15 | 'See https://github.com/ultralytics/ultralytics/issues/4031.')
16 |
17 | def postprocess(self, preds, img, orig_imgs):
18 | """Return detection results for a given input image or list of images."""
19 | preds = ops.non_max_suppression(preds,
20 | self.args.conf,
21 | self.args.iou,
22 | agnostic=self.args.agnostic_nms,
23 | max_det=self.args.max_det,
24 | classes=self.args.classes,
25 | nc=len(self.model.names))
26 |
27 | results = []
28 | for i, pred in enumerate(preds):
29 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
30 | shape = orig_img.shape
31 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
32 | pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
33 | pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape)
34 | path = self.batch[0]
35 | img_path = path[i] if isinstance(path, list) else path
36 | results.append(
37 | Results(orig_img=orig_img,
38 | path=img_path,
39 | names=self.model.names,
40 | boxes=pred[:, :6],
41 | keypoints=pred_kpts))
42 | return results
43 |
44 |
45 | def predict(cfg=DEFAULT_CFG, use_python=False):
46 | """Runs YOLO to predict objects in an image or video."""
47 | model = cfg.model or 'yolov8n-pose.pt'
48 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
49 | else 'https://ultralytics.com/images/bus.jpg'
50 |
51 | args = dict(model=model, source=source)
52 | if use_python:
53 | from ultralytics import YOLO
54 | YOLO(model)(**args)
55 | else:
56 | predictor = PosePredictor(overrides=args)
57 | predictor.predict_cli()
58 |
59 |
60 | if __name__ == '__main__':
61 | predict()
62 |
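The `predict()` helper above is the CLI-oriented entry point; in scripts, `PosePredictor` is usually reached through the high-level `YOLO` API. A minimal usage sketch (the demo image URL is the repo's usual fallback asset; the keypoint layout assumes the default COCO `kpt_shape` of (17, 3)):

```python
from ultralytics import YOLO

# PosePredictor.postprocess runs under the hood: NMS, box scaling, keypoint reshaping
results = YOLO('yolov8n-pose.pt')('https://ultralytics.com/images/bus.jpg')
for r in results:
    print(r.keypoints.shape)  # (num_persons, 17, 3) with COCO weights: x, y, confidence
```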
--------------------------------------------------------------------------------
/ultralytics/models/yolo/pose/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from copy import copy
4 |
5 | from ultralytics.models import yolo
6 | from ultralytics.nn.tasks import PoseModel
7 | from ultralytics.utils import DEFAULT_CFG, LOGGER
8 | from ultralytics.utils.plotting import plot_images, plot_results
9 |
10 |
11 | class PoseTrainer(yolo.detect.DetectionTrainer):
12 |
13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
14 | """Initialize a PoseTrainer object with specified configurations and overrides."""
15 | if overrides is None:
16 | overrides = {}
17 | overrides['task'] = 'pose'
18 | super().__init__(cfg, overrides, _callbacks)
19 |
20 | if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
21 | LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
22 | 'See https://github.com/ultralytics/ultralytics/issues/4031.')
23 |
24 | def get_model(self, cfg=None, weights=None, verbose=True):
25 | """Get pose estimation model with specified configuration and weights."""
26 | model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
27 | if weights:
28 | model.load(weights)
29 |
30 | return model
31 |
32 | def set_model_attributes(self):
33 | """Sets keypoints shape attribute of PoseModel."""
34 | super().set_model_attributes()
35 | self.model.kpt_shape = self.data['kpt_shape']
36 |
37 | def get_validator(self):
38 | """Returns an instance of the PoseValidator class for validation."""
39 | self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
40 | return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
41 |
42 | def plot_training_samples(self, batch, ni):
43 | """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
44 | images = batch['img']
45 | kpts = batch['keypoints']
46 | cls = batch['cls'].squeeze(-1)
47 | bboxes = batch['bboxes']
48 | paths = batch['im_file']
49 | batch_idx = batch['batch_idx']
50 | plot_images(images,
51 | batch_idx,
52 | cls,
53 | bboxes,
54 | kpts=kpts,
55 | paths=paths,
56 | fname=self.save_dir / f'train_batch{ni}.jpg',
57 | on_plot=self.on_plot)
58 |
59 | def plot_metrics(self):
60 | """Plots training/val metrics."""
61 | plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
62 |
63 |
64 | def train(cfg=DEFAULT_CFG, use_python=False):
65 | """Train the YOLO model on the given data and device."""
66 | model = cfg.model or 'yolov8n-pose.yaml'
67 | data = cfg.data or 'coco8-pose.yaml'
68 | device = cfg.device if cfg.device is not None else ''
69 |
70 | args = dict(model=model, data=data, device=device)
71 | if use_python:
72 | from ultralytics import YOLO
73 | YOLO(model).train(**args)
74 | else:
75 | trainer = PoseTrainer(overrides=args)
76 | trainer.train()
77 |
78 |
79 | if __name__ == '__main__':
80 | train()
81 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/segment/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .predict import SegmentationPredictor, predict
4 | from .train import SegmentationTrainer, train
5 | from .val import SegmentationValidator, val
6 |
7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/models/yolo/segment/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.engine.results import Results
6 | from ultralytics.models.yolo.detect.predict import DetectionPredictor
7 | from ultralytics.utils import DEFAULT_CFG, ROOT, ops
8 |
9 |
10 | class SegmentationPredictor(DetectionPredictor):
11 |
12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
13 | super().__init__(cfg, overrides, _callbacks)
14 | self.args.task = 'segment'
15 |
16 | def postprocess(self, preds, img, orig_imgs):
17 | """TODO: filter by classes."""
18 | p = ops.non_max_suppression(preds[0],
19 | self.args.conf,
20 | self.args.iou,
21 | agnostic=self.args.agnostic_nms,
22 | max_det=self.args.max_det,
23 | nc=len(self.model.names),
24 | classes=self.args.classes)
25 | results = []
26 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output has length 3 for PyTorch models but is a single protos tensor for exported models
27 | for i, pred in enumerate(p):
28 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
29 | path = self.batch[0]
30 | img_path = path[i] if isinstance(path, list) else path
31 | if not len(pred): # save empty boxes
32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
33 | continue
34 | if self.args.retina_masks:
35 | if not isinstance(orig_imgs, torch.Tensor):
36 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
37 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
38 | else:
39 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
40 | if not isinstance(orig_imgs, torch.Tensor):
41 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
42 | results.append(
43 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
44 | return results
45 |
46 |
47 | def predict(cfg=DEFAULT_CFG, use_python=False):
48 | """Runs YOLO object detection on an image or video source."""
49 | model = cfg.model or 'yolov8n-seg.pt'
50 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
51 | else 'https://ultralytics.com/images/bus.jpg'
52 |
53 | args = dict(model=model, source=source)
54 | if use_python:
55 | from ultralytics import YOLO
56 | YOLO(model)(**args)
57 | else:
58 | predictor = SegmentationPredictor(overrides=args)
59 | predictor.predict_cli()
60 |
61 |
62 | if __name__ == '__main__':
63 | predict()
64 |
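As the `retina_masks` branch in `postprocess` shows, masks are either upsampled from prototype resolution (`process_mask`) or computed against the original image size (`process_mask_native`). A minimal usage sketch, assuming the usual demo image URL:

```python
from ultralytics import YOLO

# retina_masks=True takes the process_mask_native path for sharper, full-resolution masks
results = YOLO('yolov8n-seg.pt')('https://ultralytics.com/images/bus.jpg', retina_masks=True)
masks = results[0].masks  # Masks object, or None when nothing was detected
if masks is not None:
    print(masks.data.shape)  # (num_instances, H, W) binary masks
```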
--------------------------------------------------------------------------------
/ultralytics/models/yolo/segment/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from copy import copy
4 |
5 | from ultralytics.models import yolo
6 | from ultralytics.nn.tasks import SegmentationModel
7 | from ultralytics.utils import DEFAULT_CFG, RANK
8 | from ultralytics.utils.plotting import plot_images, plot_results
9 |
10 |
11 | class SegmentationTrainer(yolo.detect.DetectionTrainer):
12 |
13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
14 | """Initialize a SegmentationTrainer object with given arguments."""
15 | if overrides is None:
16 | overrides = {}
17 | overrides['task'] = 'segment'
18 | super().__init__(cfg, overrides, _callbacks)
19 |
20 | def get_model(self, cfg=None, weights=None, verbose=True):
21 | """Return SegmentationModel initialized with specified config and weights."""
22 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
23 | if weights:
24 | model.load(weights)
25 |
26 | return model
27 |
28 | def get_validator(self):
29 | """Return an instance of SegmentationValidator for validation of YOLO model."""
30 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
31 | return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
32 |
33 | def plot_training_samples(self, batch, ni):
34 | """Creates a plot of training sample images with labels and box coordinates."""
35 | plot_images(batch['img'],
36 | batch['batch_idx'],
37 | batch['cls'].squeeze(-1),
38 | batch['bboxes'],
39 | batch['masks'],
40 | paths=batch['im_file'],
41 | fname=self.save_dir / f'train_batch{ni}.jpg',
42 | on_plot=self.on_plot)
43 |
44 | def plot_metrics(self):
45 | """Plots training/val metrics."""
46 | plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
47 |
48 |
49 | def train(cfg=DEFAULT_CFG, use_python=False):
50 | """Train a YOLO segmentation model based on passed arguments."""
51 | model = cfg.model or 'yolov8n-seg.pt'
52 | data = cfg.data or 'coco128-seg.yaml'
53 | device = cfg.device if cfg.device is not None else ''
54 |
55 | args = dict(model=model, data=data, device=device)
56 | if use_python:
57 | from ultralytics import YOLO
58 | YOLO(model).train(**args)
59 | else:
60 | trainer = SegmentationTrainer(overrides=args)
61 | trainer.train()
62 |
63 |
64 | if __name__ == '__main__':
65 | train()
66 |
--------------------------------------------------------------------------------
/ultralytics/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
4 | attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
5 | yaml_model_load)
6 |
7 | __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
8 | 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
9 | 'BaseModel')
10 |
--------------------------------------------------------------------------------
/ultralytics/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Ultralytics modules. Visualize with:
4 |
5 | from ultralytics.nn.modules import *
6 | import torch
7 | import os
8 |
9 | x = torch.ones(1, 128, 40, 40)
10 | m = Conv(128, 128)
11 | f = f'{m._get_name()}.onnx'
12 | torch.onnx.export(m, x, f)
13 | os.system(f'onnxsim {f} {f} && open {f}')
14 | """
15 |
16 | from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
17 | HGBlock, HGStem, Proto, RepC3)
18 | from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
19 | GhostConv, LightConv, RepConv, SpatialAttention)
20 | from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
21 | from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
22 | MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
23 |
24 | __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
25 | 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
26 | 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
27 | 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
28 | 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
29 | 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
30 |
--------------------------------------------------------------------------------
/ultralytics/nn/modules/utils.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Module utils
4 | """
5 |
6 | import copy
7 | import math
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn.init import uniform_
14 |
15 | __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
16 |
17 |
18 | def _get_clones(module, n):
19 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
20 |
21 |
22 | def bias_init_with_prob(prior_prob=0.01):
23 | """initialize conv/fc bias value according to a given probability value."""
24 | return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init
25 |
26 |
27 | def linear_init_(module):
28 | bound = 1 / math.sqrt(module.weight.shape[0])
29 | uniform_(module.weight, -bound, bound)
30 | if hasattr(module, 'bias') and module.bias is not None:
31 | uniform_(module.bias, -bound, bound)
32 |
33 |
34 | def inverse_sigmoid(x, eps=1e-5):
35 | x = x.clamp(min=0, max=1)
36 | x1 = x.clamp(min=eps)
37 | x2 = (1 - x).clamp(min=eps)
38 | return torch.log(x1 / x2)
39 |
40 |
41 | def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
42 | sampling_locations: torch.Tensor,
43 | attention_weights: torch.Tensor) -> torch.Tensor:
44 | """
45 | Multi-scale deformable attention.
46 | https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
47 | """
48 |
49 | bs, _, num_heads, embed_dims = value.shape
50 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
51 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
52 | sampling_grids = 2 * sampling_locations - 1
53 | sampling_value_list = []
54 | for level, (H_, W_) in enumerate(value_spatial_shapes):
55 | # bs, H_*W_, num_heads, embed_dims ->
56 | # bs, H_*W_, num_heads*embed_dims ->
57 | # bs, num_heads*embed_dims, H_*W_ ->
58 | # bs*num_heads, embed_dims, H_, W_
59 | value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
60 | # bs, num_queries, num_heads, num_points, 2 ->
61 | # bs, num_heads, num_queries, num_points, 2 ->
62 | # bs*num_heads, num_queries, num_points, 2
63 | sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
64 | # bs*num_heads, embed_dims, num_queries, num_points
65 | sampling_value_l_ = F.grid_sample(value_l_,
66 | sampling_grid_l_,
67 | mode='bilinear',
68 | padding_mode='zeros',
69 | align_corners=False)
70 | sampling_value_list.append(sampling_value_l_)
71 | # (bs, num_queries, num_heads, num_levels, num_points) ->
72 | # (bs, num_heads, num_queries, num_levels, num_points) ->
73 | # (bs, num_heads, 1, num_queries, num_levels*num_points)
74 | attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
75 | num_levels * num_points)
76 | output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
77 | bs, num_heads * embed_dims, num_queries))
78 | return output.transpose(1, 2).contiguous()
79 |
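Two quick sanity checks on the functions above: `inverse_sigmoid` is the clamped logit, so composing it with `torch.sigmoid` recovers the input away from the clamp boundaries, and `multi_scale_deformable_attn_pytorch` maps its four inputs to `(bs, num_queries, num_heads * embed_dims)`. A minimal sketch with small random tensors:

```python
import torch

from ultralytics.nn.modules.utils import inverse_sigmoid, multi_scale_deformable_attn_pytorch

# inverse_sigmoid is log(x / (1 - x)) with eps clamping
x = torch.tensor([0.1, 0.5, 0.9])
assert torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-4)

# Shape check for multi-scale deformable attention over two feature levels
bs, heads, dims, queries, points = 1, 2, 4, 3, 2
shapes = [(8, 8), (4, 4)]
value = torch.rand(bs, sum(h * w for h, w in shapes), heads, dims)
locs = torch.rand(bs, queries, heads, len(shapes), points, 2)  # normalized sampling locations
weights = torch.rand(bs, queries, heads, len(shapes), points).softmax(-1)
out = multi_scale_deformable_attn_pytorch(value, shapes, locs, weights)
assert out.shape == (bs, queries, heads * dims)
```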
--------------------------------------------------------------------------------
/ultralytics/trackers/README.md:
--------------------------------------------------------------------------------
1 | # Tracker
2 |
3 | ## Supported Trackers
4 |
5 | - [x] ByteTracker
6 | - [x] BoT-SORT
7 |
8 | ## Usage
9 |
10 | ### Python Interface
11 |
12 | You can use the Python interface to track objects using the YOLO model.
13 |
14 | ```python
15 | from ultralytics import YOLO
16 |
17 | model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt
18 | model.track(
19 | source="video/streams",
20 | stream=True,
21 | tracker="botsort.yaml", # or 'bytetrack.yaml'
22 | show=True,
23 | )
24 | ```
25 |
26 | You can get the IDs of the tracked objects using the following code:
27 |
28 | ```python
29 | from ultralytics import YOLO
30 |
31 | model = YOLO("yolov8n.pt")
32 |
33 | for result in model.track(source="video.mp4"):
34 | print(
35 | result.boxes.id.cpu().numpy().astype(int)
36 | ) # this will print the IDs of the tracked objects in the frame
37 | ```
38 |
39 | If you run the tracker on a folder of images, or loop over video frames yourself, pass the `persist` parameter to tell the model that consecutive frames belong to the same sequence. Without it, the model creates a new tracker object on every call, so the same object receives a different ID in each frame; with `persist=True` the existing tracker is reused and IDs stay consistent.
40 |
41 | ```python
42 | import cv2
43 | from ultralytics import YOLO
44 |
45 | cap = cv2.VideoCapture("video.mp4")
46 | model = YOLO("yolov8n.pt")
47 | while True:
48 | ret, frame = cap.read()
49 | if not ret:
50 | break
51 | results = model.track(frame, persist=True)
52 | boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
53 | ids = results[0].boxes.id.cpu().numpy().astype(int)
54 | for box, track_id in zip(boxes, ids):
55 | cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
56 | cv2.putText(
57 | frame,
58 | f"Id {id}",
59 | (box[0], box[1]),
60 | cv2.FONT_HERSHEY_SIMPLEX,
61 | 1,
62 | (0, 0, 255),
63 | 2,
64 | )
65 | cv2.imshow("frame", frame)
66 | if cv2.waitKey(1) & 0xFF == ord("q"):
67 | break
68 | ```
69 |
70 | ## Change Tracker Parameters
71 |
72 | You can change the tracker parameters by editing the YAML config files (`botsort.yaml`, `bytetrack.yaml`) located in the `ultralytics/cfg/trackers` folder.
73 |
74 | ## Command Line Interface (CLI)
75 |
76 | You can also use the command line interface to track objects using the YOLO model.
77 |
78 | ```bash
79 | yolo detect track source=... tracker=...
80 | yolo segment track source=... tracker=...
81 | yolo pose track source=... tracker=...
82 | ```
83 |
84 | By default, trackers use the configuration files in `ultralytics/cfg/trackers`.
85 | You can also pass a modified tracker config file; use the shipped configs in `ultralytics/cfg/trackers`
86 | as a starting point, as sketched below.
87 |
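Following up on the modified-config note above, a minimal sketch of pointing the tracker at a custom file (`my_bytetrack.yaml` is a placeholder name; copy one of the shipped configs and edit its thresholds):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# tracker accepts a path to a custom config copied from ultralytics/cfg/trackers
model.track(source="video.mp4", tracker="my_bytetrack.yaml", show=True)
```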
--------------------------------------------------------------------------------
/ultralytics/trackers/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .bot_sort import BOTSORT
4 | from .byte_tracker import BYTETracker
5 | from .track import register_tracker
6 |
7 | __all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
8 |
--------------------------------------------------------------------------------
/ultralytics/trackers/basetrack.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 |
7 |
8 | class TrackState:
9 | """Enumeration of possible object tracking states."""
10 |
11 | New = 0
12 | Tracked = 1
13 | Lost = 2
14 | Removed = 3
15 |
16 |
17 | class BaseTrack:
18 | """Base class for object tracking, handling basic track attributes and operations."""
19 |
20 | _count = 0
21 |
22 | track_id = 0
23 | is_activated = False
24 | state = TrackState.New
25 |
26 | history = OrderedDict()
27 | features = []
28 | curr_feature = None
29 | score = 0
30 | start_frame = 0
31 | frame_id = 0
32 | time_since_update = 0
33 |
34 | # Multi-camera
35 | location = (np.inf, np.inf)
36 |
37 | @property
38 | def end_frame(self):
39 | """Return the last frame ID of the track."""
40 | return self.frame_id
41 |
42 | @staticmethod
43 | def next_id():
44 | """Increment and return the global track ID counter."""
45 | BaseTrack._count += 1
46 | return BaseTrack._count
47 |
48 | def activate(self, *args):
49 | """Activate the track with the provided arguments."""
50 | raise NotImplementedError
51 |
52 | def predict(self):
53 | """Predict the next state of the track."""
54 | raise NotImplementedError
55 |
56 | def update(self, *args, **kwargs):
57 | """Update the track with new observations."""
58 | raise NotImplementedError
59 |
60 | def mark_lost(self):
61 | """Mark the track as lost."""
62 | self.state = TrackState.Lost
63 |
64 | def mark_removed(self):
65 | """Mark the track as removed."""
66 | self.state = TrackState.Removed
67 |
68 | @staticmethod
69 | def reset_id():
70 | """Reset the global track ID counter."""
71 | BaseTrack._count = 0
72 |
--------------------------------------------------------------------------------
/ultralytics/trackers/bot_sort.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from collections import deque
4 |
5 | import numpy as np
6 |
7 | from .basetrack import TrackState
8 | from .byte_tracker import BYTETracker, STrack
9 | from .utils import matching
10 | from .utils.gmc import GMC
11 | from .utils.kalman_filter import KalmanFilterXYWH
12 |
13 |
14 | class BOTrack(STrack):
15 | shared_kalman = KalmanFilterXYWH()
16 |
17 | def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
18 | """Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
19 | super().__init__(tlwh, score, cls)
20 |
21 | self.smooth_feat = None
22 | self.curr_feat = None
23 | if feat is not None:
24 | self.update_features(feat)
25 | self.features = deque([], maxlen=feat_history)
26 | self.alpha = 0.9
27 |
28 | def update_features(self, feat):
29 | """Update features vector and smooth it using exponential moving average."""
30 | feat /= np.linalg.norm(feat)
31 | self.curr_feat = feat
32 | if self.smooth_feat is None:
33 | self.smooth_feat = feat
34 | else:
35 | self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
36 | self.features.append(feat)
37 | self.smooth_feat /= np.linalg.norm(self.smooth_feat)
38 |
39 | def predict(self):
40 | """Predicts the mean and covariance using Kalman filter."""
41 | mean_state = self.mean.copy()
42 | if self.state != TrackState.Tracked:
43 | mean_state[6] = 0
44 | mean_state[7] = 0
45 |
46 | self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
47 |
48 | def re_activate(self, new_track, frame_id, new_id=False):
49 | """Reactivates a track with updated features and optionally assigns a new ID."""
50 | if new_track.curr_feat is not None:
51 | self.update_features(new_track.curr_feat)
52 | super().re_activate(new_track, frame_id, new_id)
53 |
54 | def update(self, new_track, frame_id):
55 | """Update the YOLOv8 instance with new track and frame ID."""
56 | if new_track.curr_feat is not None:
57 | self.update_features(new_track.curr_feat)
58 | super().update(new_track, frame_id)
59 |
60 | @property
61 | def tlwh(self):
62 | """Get current position in bounding box format `(top left x, top left y,
63 | width, height)`.
64 | """
65 | if self.mean is None:
66 | return self._tlwh.copy()
67 | ret = self.mean[:4].copy()
68 | ret[:2] -= ret[2:] / 2
69 | return ret
70 |
71 | @staticmethod
72 | def multi_predict(stracks):
73 | """Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
74 | if len(stracks) <= 0:
75 | return
76 | multi_mean = np.asarray([st.mean.copy() for st in stracks])
77 | multi_covariance = np.asarray([st.covariance for st in stracks])
78 | for i, st in enumerate(stracks):
79 | if st.state != TrackState.Tracked:
80 | multi_mean[i][6] = 0
81 | multi_mean[i][7] = 0
82 | multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
83 | for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
84 | stracks[i].mean = mean
85 | stracks[i].covariance = cov
86 |
87 | def convert_coords(self, tlwh):
88 | """Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
89 | return self.tlwh_to_xywh(tlwh)
90 |
91 | @staticmethod
92 | def tlwh_to_xywh(tlwh):
93 | """Convert bounding box to format `(center x, center y, width,
94 | height)`.
95 | """
96 | ret = np.asarray(tlwh).copy()
97 | ret[:2] += ret[2:] / 2
98 | return ret
99 |
100 |
101 | class BOTSORT(BYTETracker):
102 |
103 | def __init__(self, args, frame_rate=30):
104 | """Initialize YOLOv8 object with ReID module and GMC algorithm."""
105 | super().__init__(args, frame_rate)
106 | # ReID module
107 | self.proximity_thresh = args.proximity_thresh
108 | self.appearance_thresh = args.appearance_thresh
109 |
110 | if args.with_reid:
111 | # BoT-SORT(reid) is not supported yet
112 | self.encoder = None
113 | # self.gmc = GMC(method=args.cmc_method, verbose=[args.name, args.ablation])
114 | self.gmc = GMC(method=args.cmc_method)
115 |
116 | def get_kalmanfilter(self):
117 | """Returns an instance of KalmanFilterXYWH for object tracking."""
118 | return KalmanFilterXYWH()
119 |
120 | def init_track(self, dets, scores, cls, img=None):
121 | """Initialize track with detections, scores, and classes."""
122 | if len(dets) == 0:
123 | return []
124 | if self.args.with_reid and self.encoder is not None:
125 | features_keep = self.encoder.inference(img, dets)
126 | return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
127 | else:
128 | return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
129 |
130 | def get_dists(self, tracks, detections):
131 | """Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
132 | dists = matching.iou_distance(tracks, detections)
133 | dists_mask = (dists > self.proximity_thresh)
134 |
135 | # TODO: mot20
136 | # if not self.args.mot20:
137 | dists = matching.fuse_score(dists, detections)
138 |
139 | if self.args.with_reid and self.encoder is not None:
140 | emb_dists = matching.embedding_distance(tracks, detections) / 2.0
141 | emb_dists[emb_dists > self.appearance_thresh] = 1.0
142 | emb_dists[dists_mask] = 1.0
143 | dists = np.minimum(dists, emb_dists)
144 | return dists
145 |
146 | def multi_predict(self, tracks):
147 | """Predict and track multiple objects with YOLOv8 model."""
148 | BOTrack.multi_predict(tracks)
149 |
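`update_features` above maintains a unit-norm exponential moving average of ReID embeddings. A small numeric sketch of that smoothing step, using the same `alpha = 0.9`:

```python
import numpy as np

alpha = 0.9
smooth = np.array([1.0, 0.0])  # previous smoothed feature (already unit norm)
feat = np.array([0.0, 1.0])  # new feature, normalized on arrival
smooth = alpha * smooth + (1 - alpha) * feat  # EMA: heavily favors track history
smooth /= np.linalg.norm(smooth)  # re-normalize, as update_features does
print(smooth)  # ~[0.994, 0.110]: one new observation only nudges the track embedding
```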
--------------------------------------------------------------------------------
/ultralytics/trackers/track.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from functools import partial
4 |
5 | import torch
6 |
7 | from ultralytics.utils import IterableSimpleNamespace, yaml_load
8 | from ultralytics.utils.checks import check_yaml
9 |
10 | from .bot_sort import BOTSORT
11 | from .byte_tracker import BYTETracker
12 |
13 | TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
14 |
15 |
16 | def on_predict_start(predictor, persist=False):
17 | """
18 | Initialize trackers for object tracking during prediction.
19 |
20 | Args:
21 | predictor (object): The predictor object to initialize trackers for.
22 | persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
23 |
24 | Raises:
25 | AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
26 | """
27 | if hasattr(predictor, 'trackers') and persist:
28 | return
29 | tracker = check_yaml(predictor.args.tracker)
30 | cfg = IterableSimpleNamespace(**yaml_load(tracker))
31 | assert cfg.tracker_type in ['bytetrack', 'botsort'], \
32 | f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
33 | trackers = []
34 | for _ in range(predictor.dataset.bs):
35 | tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
36 | trackers.append(tracker)
37 | predictor.trackers = trackers
38 |
39 |
40 | def on_predict_postprocess_end(predictor):
41 | """Postprocess detected boxes and update with object tracking."""
42 | bs = predictor.dataset.bs
43 | im0s = predictor.batch[1]
44 | for i in range(bs):
45 | det = predictor.results[i].boxes.cpu().numpy()
46 | if len(det) == 0:
47 | continue
48 | tracks = predictor.trackers[i].update(det, im0s[i])
49 | if len(tracks) == 0:
50 | continue
51 | idx = tracks[:, -1].astype(int)
52 | predictor.results[i] = predictor.results[i][idx]
53 | predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
54 |
55 |
56 | def register_tracker(model, persist):
57 | """
58 | Register tracking callbacks to the model for object tracking during prediction.
59 |
60 | Args:
61 | model (object): The model object to register tracking callbacks for.
62 | persist (bool): Whether to persist the trackers if they already exist.
63 |
64 | """
65 | model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
66 | model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
67 |
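`register_tracker` is what `model.track(...)` uses internally to wire these two callbacks in. A minimal sketch of registering them manually on a `YOLO` model (normally you would just call `model.track`; the default predictor config points `tracker` at `botsort.yaml`):

```python
from ultralytics import YOLO
from ultralytics.trackers import register_tracker

model = YOLO('yolov8n.pt')
register_tracker(model, persist=True)  # adds on_predict_start / on_predict_postprocess_end hooks
for r in model.predict('video.mp4', stream=True):
    print(r.boxes.id)  # track IDs appended by on_predict_postprocess_end (None if no tracks)
```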
--------------------------------------------------------------------------------
/ultralytics/trackers/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
--------------------------------------------------------------------------------
/ultralytics/trackers/utils/matching.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import numpy as np
4 | import scipy
5 | from scipy.spatial.distance import cdist
6 |
7 | from ultralytics.utils.metrics import bbox_ioa
8 |
9 | try:
10 | import lap # for linear_assignment
11 |
12 | assert lap.__version__ # verify package is not directory
13 | except (ImportError, AssertionError, AttributeError):
14 | from ultralytics.utils.checks import check_requirements
15 |
16 | check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx
17 | import lap
18 |
19 |
20 | def linear_assignment(cost_matrix, thresh, use_lap=True):
21 | """
22 | Perform linear assignment using scipy or lap.lapjv.
23 |
24 | Args:
25 | cost_matrix (np.ndarray): The matrix containing cost values for assignments.
26 | thresh (float): Threshold for considering an assignment valid.
27 | use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
28 |
29 | Returns:
30 | (tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'.
31 | """
32 |
33 | if cost_matrix.size == 0:
34 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
35 |
36 | if use_lap:
37 | _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
38 | matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
39 | unmatched_a = np.where(x < 0)[0]
40 | unmatched_b = np.where(y < 0)[0]
41 | else:
42 | # Scipy linear sum assignment is NOT working correctly, DO NOT USE
43 | y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x
44 | matches = np.asarray([[i, xi] for i, xi in enumerate(x) if cost_matrix[i, xi] <= thresh])
45 | unmatched = np.ones(cost_matrix.shape)
46 | for i, xi in matches:
47 | unmatched[i, xi] = 0.0
48 | unmatched_a = np.where(unmatched.all(1))[0]
49 | unmatched_b = np.where(unmatched.all(0))[0]
50 |
51 | return matches, unmatched_a, unmatched_b
52 |
53 |
54 | def iou_distance(atracks, btracks):
55 | """
56 | Compute cost based on Intersection over Union (IoU) between tracks.
57 |
58 | Args:
59 | atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
60 | btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.
61 |
62 | Returns:
63 | (np.ndarray): Cost matrix computed based on IoU.
64 | """
65 |
66 | if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
67 | or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
68 | atlbrs = atracks
69 | btlbrs = btracks
70 | else:
71 | atlbrs = [track.tlbr for track in atracks]
72 | btlbrs = [track.tlbr for track in btracks]
73 |
74 | ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
75 | if len(atlbrs) and len(btlbrs):
76 | ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32),
77 | np.ascontiguousarray(btlbrs, dtype=np.float32),
78 | iou=True)
79 | return 1 - ious # cost matrix
80 |
81 |
82 | def embedding_distance(tracks, detections, metric='cosine'):
83 | """
84 | Compute distance between tracks and detections based on embeddings.
85 |
86 | Args:
87 | tracks (list[STrack]): List of tracks.
88 | detections (list[BaseTrack]): List of detections.
89 | metric (str, optional): Metric for distance computation. Defaults to 'cosine'.
90 |
91 | Returns:
92 | (np.ndarray): Cost matrix computed based on embeddings.
93 | """
94 |
95 | cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
96 | if cost_matrix.size == 0:
97 | return cost_matrix
98 | det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
99 | # for i, track in enumerate(tracks):
100 | # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
101 | track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
102 | cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features
103 | return cost_matrix
104 |
105 |
106 | def fuse_score(cost_matrix, detections):
107 | """
108 | Fuses cost matrix with detection scores to produce a single similarity matrix.
109 |
110 | Args:
111 | cost_matrix (np.ndarray): The matrix containing cost values for assignments.
112 | detections (list[BaseTrack]): List of detections with scores.
113 |
114 | Returns:
115 | (np.ndarray): Fused similarity matrix.
116 | """
117 |
118 | if cost_matrix.size == 0:
119 | return cost_matrix
120 | iou_sim = 1 - cost_matrix
121 | det_scores = np.array([det.score for det in detections])
122 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
123 | fuse_sim = iou_sim * det_scores
124 | return 1 - fuse_sim # fuse_cost
125 |
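A worked example of `linear_assignment` on a 2x2 cost matrix, assuming the `lap` package that the try/except at the top of this file ensures is installed:

```python
import numpy as np

from ultralytics.trackers.utils.matching import linear_assignment

cost = np.array([[0.1, 0.9],
                 [0.8, 0.2]], dtype=np.float32)
matches, unmatched_a, unmatched_b = linear_assignment(cost, thresh=0.5)
print(matches)  # [[0, 0], [1, 1]]: each row pairs with its cheap column
print(unmatched_a, unmatched_b)  # both empty: every track and detection was assigned
```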
--------------------------------------------------------------------------------
/ultralytics/utils/autobatch.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
4 | """
5 |
6 | from copy import deepcopy
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
12 | from ultralytics.utils.torch_utils import profile
13 |
14 |
15 | def check_train_batch_size(model, imgsz=640, amp=True):
16 | """
17 | Check YOLO training batch size using the autobatch() function.
18 |
19 | Args:
20 | model (torch.nn.Module): YOLO model to check batch size for.
21 | imgsz (int): Image size used for training.
22 | amp (bool): If True, use automatic mixed precision (AMP) for training.
23 |
24 | Returns:
25 | (int): Optimal batch size computed using the autobatch() function.
26 | """
27 |
28 | with torch.cuda.amp.autocast(amp):
29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
30 |
31 |
32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=DEFAULT_CFG.batch):
33 | """
34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
35 |
36 | Args:
37 | model (torch.nn.module): YOLO model to compute batch size for.
38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
41 |
42 | Returns:
43 | (int): The optimal batch size.
44 | """
45 |
46 | # Check device
47 | prefix = colorstr('AutoBatch: ')
48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
49 | device = next(model.parameters()).device # get model device
50 | if device.type == 'cpu':
51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
52 | return batch_size
53 | if torch.backends.cudnn.benchmark:
54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
55 | return batch_size
56 |
57 | # Inspect CUDA memory
58 | gb = 1 << 30 # bytes to GiB (1024 ** 3)
59 | d = str(device).upper() # 'CUDA:0'
60 | properties = torch.cuda.get_device_properties(device) # device properties
61 | t = properties.total_memory / gb # GiB total
62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved
63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated
64 | f = t - (r + a) # GiB free
65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
66 |
67 | # Profile batch sizes
68 | batch_sizes = [1, 2, 4, 8, 16]
69 | try:
70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
71 | results = profile(img, model, n=3, device=device)
72 |
73 | # Fit a solution
74 | y = [x[2] for x in results if x] # memory [2]
75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
76 | b = int((f * fraction - p[1]) / p[0])  # solve p[0] * b + p[1] = f * fraction for the optimal batch size b
77 | if None in results: # some sizes failed
78 | i = results.index(None) # first fail index
79 | if b >= batch_sizes[i]: # y intercept above failure point
80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point
81 | if b < 1 or b > 1024: # b outside of safe range
82 | b = batch_size
83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
84 |
85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
87 | return b
88 | except Exception as e:
89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
90 | return batch_size
91 |
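The fit in `autobatch` is a first-degree polynomial through (batch size, memory) samples, solved for the batch size whose predicted memory hits the target fraction of free memory. A standalone numeric sketch with made-up measurements:

```python
import numpy as np

batch_sizes = [1, 2, 4, 8, 16]
y = [0.5, 0.9, 1.7, 3.3, 6.5]  # hypothetical GiB used at each batch size
p = np.polyfit(batch_sizes, y, deg=1)  # slope ~0.4 GiB/image, intercept ~0.1 GiB
f, fraction = 10.0, 0.67  # hypothetical free GiB and target utilization
b = int((f * fraction - p[1]) / p[0])  # solve 0.4*b + 0.1 = 6.7  ->  b = 16
print(b)
```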
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
4 |
5 | __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
6 |
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/base.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Base callbacks
4 | """
5 |
6 | from collections import defaultdict
7 | from copy import deepcopy
8 |
9 | # Trainer callbacks ----------------------------------------------------------------------------------------------------
10 |
11 |
12 | def on_pretrain_routine_start(trainer):
13 | """Called before the pretraining routine starts."""
14 | pass
15 |
16 |
17 | def on_pretrain_routine_end(trainer):
18 | """Called after the pretraining routine ends."""
19 | pass
20 |
21 |
22 | def on_train_start(trainer):
23 | """Called when the training starts."""
24 | pass
25 |
26 |
27 | def on_train_epoch_start(trainer):
28 | """Called at the start of each training epoch."""
29 | pass
30 |
31 |
32 | def on_train_batch_start(trainer):
33 | """Called at the start of each training batch."""
34 | pass
35 |
36 |
37 | def optimizer_step(trainer):
38 | """Called when the optimizer takes a step."""
39 | pass
40 |
41 |
42 | def on_before_zero_grad(trainer):
43 | """Called before the gradients are set to zero."""
44 | pass
45 |
46 |
47 | def on_train_batch_end(trainer):
48 | """Called at the end of each training batch."""
49 | pass
50 |
51 |
52 | def on_train_epoch_end(trainer):
53 | """Called at the end of each training epoch."""
54 | pass
55 |
56 |
57 | def on_fit_epoch_end(trainer):
58 | """Called at the end of each fit epoch (train + val)."""
59 | pass
60 |
61 |
62 | def on_model_save(trainer):
63 | """Called when the model is saved."""
64 | pass
65 |
66 |
67 | def on_train_end(trainer):
68 | """Called when the training ends."""
69 | pass
70 |
71 |
72 | def on_params_update(trainer):
73 | """Called when the model parameters are updated."""
74 | pass
75 |
76 |
77 | def teardown(trainer):
78 | """Called during the teardown of the training process."""
79 | pass
80 |
81 |
82 | # Validator callbacks --------------------------------------------------------------------------------------------------
83 |
84 |
85 | def on_val_start(validator):
86 | """Called when the validation starts."""
87 | pass
88 |
89 |
90 | def on_val_batch_start(validator):
91 | """Called at the start of each validation batch."""
92 | pass
93 |
94 |
95 | def on_val_batch_end(validator):
96 | """Called at the end of each validation batch."""
97 | pass
98 |
99 |
100 | def on_val_end(validator):
101 | """Called when the validation ends."""
102 | pass
103 |
104 |
105 | # Predictor callbacks --------------------------------------------------------------------------------------------------
106 |
107 |
108 | def on_predict_start(predictor):
109 | """Called when the prediction starts."""
110 | pass
111 |
112 |
113 | def on_predict_batch_start(predictor):
114 | """Called at the start of each prediction batch."""
115 | pass
116 |
117 |
118 | def on_predict_batch_end(predictor):
119 | """Called at the end of each prediction batch."""
120 | pass
121 |
122 |
123 | def on_predict_postprocess_end(predictor):
124 | """Called after the post-processing of the prediction ends."""
125 | pass
126 |
127 |
128 | def on_predict_end(predictor):
129 | """Called when the prediction ends."""
130 | pass
131 |
132 |
133 | # Exporter callbacks ---------------------------------------------------------------------------------------------------
134 |
135 |
136 | def on_export_start(exporter):
137 | """Called when the model export starts."""
138 | pass
139 |
140 |
141 | def on_export_end(exporter):
142 | """Called when the model export ends."""
143 | pass
144 |
145 |
146 | default_callbacks = {
147 | # Run in trainer
148 | 'on_pretrain_routine_start': [on_pretrain_routine_start],
149 | 'on_pretrain_routine_end': [on_pretrain_routine_end],
150 | 'on_train_start': [on_train_start],
151 | 'on_train_epoch_start': [on_train_epoch_start],
152 | 'on_train_batch_start': [on_train_batch_start],
153 | 'optimizer_step': [optimizer_step],
154 | 'on_before_zero_grad': [on_before_zero_grad],
155 | 'on_train_batch_end': [on_train_batch_end],
156 | 'on_train_epoch_end': [on_train_epoch_end],
157 | 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
158 | 'on_model_save': [on_model_save],
159 | 'on_train_end': [on_train_end],
160 | 'on_params_update': [on_params_update],
161 | 'teardown': [teardown],
162 |
163 | # Run in validator
164 | 'on_val_start': [on_val_start],
165 | 'on_val_batch_start': [on_val_batch_start],
166 | 'on_val_batch_end': [on_val_batch_end],
167 | 'on_val_end': [on_val_end],
168 |
169 | # Run in predictor
170 | 'on_predict_start': [on_predict_start],
171 | 'on_predict_batch_start': [on_predict_batch_start],
172 | 'on_predict_postprocess_end': [on_predict_postprocess_end],
173 | 'on_predict_batch_end': [on_predict_batch_end],
174 | 'on_predict_end': [on_predict_end],
175 |
176 | # Run in exporter
177 | 'on_export_start': [on_export_start],
178 | 'on_export_end': [on_export_end]}
179 |
180 |
181 | def get_default_callbacks():
182 | """
183 | Return a copy of the default_callbacks dictionary with lists as default values.
184 |
185 | Returns:
186 | (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
187 | """
188 | return defaultdict(list, deepcopy(default_callbacks))
189 |
190 |
191 | def add_integration_callbacks(instance):
192 | """
193 | Add integration callbacks from various sources to the instance's callbacks.
194 |
195 | Args:
196 | instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
197 | of callback lists.
198 | """
199 | from .clearml import callbacks as clearml_cb
200 | from .comet import callbacks as comet_cb
201 | from .dvc import callbacks as dvc_cb
202 | from .hub import callbacks as hub_cb
203 | from .mlflow import callbacks as mlflow_cb
204 | from .neptune import callbacks as neptune_cb
205 | from .raytune import callbacks as tune_cb
206 | from .tensorboard import callbacks as tensorboard_cb
207 | from .wb import callbacks as wb_cb
208 |
209 | for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb:
210 | for k, v in x.items():
211 | if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
212 | instance.callbacks[k].append(v) # callback[name].append(func)
213 |
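Every key in `default_callbacks` is a hook name that downstream users can extend. A minimal sketch of attaching a custom function to one of these hooks through the `YOLO` API (the print body is purely illustrative):

```python
from ultralytics import YOLO


def log_epoch(trainer):
    """Custom hook: runs at the end of every training epoch."""
    print(f'finished epoch {trainer.epoch}')


model = YOLO('yolov8n.pt')
model.add_callback('on_train_epoch_end', log_epoch)  # appended to that hook's callback list
```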
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/clearml.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import re
4 |
5 | import matplotlib.image as mpimg
6 | import matplotlib.pyplot as plt
7 |
8 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
9 | from ultralytics.utils.torch_utils import model_info_for_loggers
10 |
11 | try:
12 | import clearml
13 | from clearml import Task
14 | from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
15 | from clearml.binding.matplotlib_bind import PatchedMatplotlib
16 |
17 | assert hasattr(clearml, '__version__') # verify package is not directory
18 | assert not TESTS_RUNNING # do not log pytest
19 | assert SETTINGS['clearml'] is True # verify integration is enabled
20 | except (ImportError, AssertionError):
21 | clearml = None
22 |
23 |
24 | def _log_debug_samples(files, title='Debug Samples') -> None:
25 | """
26 | Log files (images) as debug samples in the ClearML task.
27 |
28 | Args:
29 | files (list): A list of file paths in PosixPath format.
30 | title (str): A title that groups together images with the same values.
31 | """
32 | task = Task.current_task()
33 | if task:
34 | for f in files:
35 | if f.exists():
36 | it = re.search(r'_batch(\d+)', f.name)
37 | iteration = int(it.groups()[0]) if it else 0
38 | task.get_logger().report_image(title=title,
39 | series=f.name.replace(it.group(), ''),
40 | local_path=str(f),
41 | iteration=iteration)
42 |
43 |
44 | def _log_plot(title, plot_path) -> None:
45 | """
46 | Log an image as a plot in the plot section of ClearML.
47 |
48 | Args:
49 | title (str): The title of the plot.
50 | plot_path (str): The path to the saved image file.
51 | """
52 | img = mpimg.imread(plot_path)
53 | fig = plt.figure()
54 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
55 | ax.imshow(img)
56 |
57 | Task.current_task().get_logger().report_matplotlib_figure(title=title,
58 | series='',
59 | figure=fig,
60 | report_interactive=False)
61 |
62 |
63 | def on_pretrain_routine_start(trainer):
64 | """Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
65 | try:
66 | task = Task.current_task()
67 | if task:
68 | # Make sure the automatic pytorch and matplotlib bindings are disabled!
69 | # We are logging these plots and model files manually in the integration
70 | PatchPyTorchModelIO.update_current_task(None)
71 | PatchedMatplotlib.update_current_task(None)
72 | else:
73 | task = Task.init(project_name=trainer.args.project or 'YOLOv8',
74 | task_name=trainer.args.name,
75 | tags=['YOLOv8'],
76 | output_uri=True,
77 | reuse_last_task_id=False,
78 | auto_connect_frameworks={
79 | 'pytorch': False,
80 | 'matplotlib': False})
81 | LOGGER.warning('ClearML initialized a new task. If you want to run remotely, '
82 | 'please run clearml-init and connect your arguments before initializing YOLO.')
83 | task.connect(vars(trainer.args), name='General')
84 | except Exception as e:
85 | LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
86 |
87 |
88 | def on_train_epoch_end(trainer):
89 | """Logs debug samples for the first epoch of YOLO training and reports current training progress."""
90 | task = Task.current_task()
91 | if task:
92 | # Log debug samples (mosaics) once, on the first epoch
93 | if trainer.epoch == 1:
94 | _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
95 | # Report the current training progress
96 | for k, v in trainer.validator.metrics.results_dict.items():
97 | task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
98 |
99 |
100 | def on_fit_epoch_end(trainer):
101 | """Reports model information to logger at the end of an epoch."""
102 | task = Task.current_task()
103 | if task:
104 | # You should have access to the validation bboxes under jdict
105 | task.get_logger().report_scalar(title='Epoch Time',
106 | series='Epoch Time',
107 | value=trainer.epoch_time,
108 | iteration=trainer.epoch)
109 | if trainer.epoch == 0:
110 | for k, v in model_info_for_loggers(trainer).items():
111 | task.get_logger().report_single_value(k, v)
112 |
113 |
114 | def on_val_end(validator):
115 | """Logs validation results including labels and predictions."""
116 | if Task.current_task():
117 | # Log val_labels and val_pred
118 | _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
119 |
120 |
121 | def on_train_end(trainer):
122 | """Logs final model and its name on training completion."""
123 | task = Task.current_task()
124 | if task:
125 | # Log final results, CM matrix + PR plots
126 | files = [
127 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
128 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
129 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
130 | for f in files:
131 | _log_plot(title=f.stem, plot_path=f)
132 | # Report final metrics
133 | for k, v in trainer.validator.metrics.results_dict.items():
134 | task.get_logger().report_single_value(k, v)
135 | # Log the final model
136 | task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
137 |
138 |
139 | callbacks = {
140 | 'on_pretrain_routine_start': on_pretrain_routine_start,
141 | 'on_train_epoch_end': on_train_epoch_end,
142 | 'on_fit_epoch_end': on_fit_epoch_end,
143 | 'on_val_end': on_val_end,
144 | 'on_train_end': on_train_end} if clearml else {}
145 |
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/dvc.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import os
4 | import re
5 | from pathlib import Path
6 |
7 | import pkg_resources as pkg
8 |
9 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
10 | from ultralytics.utils.torch_utils import model_info_for_loggers
11 |
12 | try:
13 | from importlib.metadata import version
14 |
15 | import dvclive
16 |
17 | assert not TESTS_RUNNING # do not log pytest
18 | assert SETTINGS['dvc'] is True # verify integration is enabled
19 |
20 | ver = version('dvclive')
21 | if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
22 | LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
23 | dvclive = None # noqa: F811
24 | except (ImportError, AssertionError, TypeError):
25 | dvclive = None
26 |
27 | # DVCLive logger instance
28 | live = None
29 | _processed_plots = {}
30 |
31 | # `on_fit_epoch_end` is called on final validation (probably needs to be fixed)
32 | # for now this is the way we distinguish final evaluation of the best model vs
33 | # last epoch validation
34 | _training_epoch = False
35 |
36 |
37 | def _log_images(path, prefix=''):
38 | if live:
39 | name = path.name
40 |
41 | # Group images by batch to enable sliders in UI
42 | if m := re.search(r'_batch(\d+)', name):
43 | ni = m.group(1)
44 | new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
45 | name = (Path(new_stem) / ni).with_suffix(path.suffix)
46 |
47 | live.log_image(os.path.join(prefix, name), path)
48 |
49 |
50 | def _log_plots(plots, prefix=''):
51 | for name, params in plots.items():
52 | timestamp = params['timestamp']
53 | if _processed_plots.get(name) != timestamp:
54 | _log_images(name, prefix)
55 | _processed_plots[name] = timestamp
56 |
57 |
58 | def _log_confusion_matrix(validator):
59 | targets = []
60 | preds = []
61 | matrix = validator.confusion_matrix.matrix
62 | names = list(validator.names.values())
63 | if validator.confusion_matrix.task == 'detect':
64 | names += ['background']
65 |
66 | for ti, pred in enumerate(matrix.T.astype(int)):
67 | for pi, num in enumerate(pred):
68 | targets.extend([names[ti]] * num)
69 | preds.extend([names[pi]] * num)
70 |
71 | live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
72 |
73 |
74 | def on_pretrain_routine_start(trainer):
75 | try:
76 | global live
77 | live = dvclive.Live(save_dvc_exp=True, cache_images=True)
78 | LOGGER.info(
79 | f'DVCLive is detected and auto logging is enabled (can be disabled in the {SETTINGS.file} with `dvc: false`).'
80 | )
81 | except Exception as e:
82 | LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
83 |
84 |
85 | def on_pretrain_routine_end(trainer):
86 | _log_plots(trainer.plots, 'train')
87 |
88 |
89 | def on_train_start(trainer):
90 | if live:
91 | live.log_params(trainer.args)
92 |
93 |
94 | def on_train_epoch_start(trainer):
95 | global _training_epoch
96 | _training_epoch = True
97 |
98 |
99 | def on_fit_epoch_end(trainer):
100 | global _training_epoch
101 | if live and _training_epoch:
102 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
103 | for metric, value in all_metrics.items():
104 | live.log_metric(metric, value)
105 |
106 | if trainer.epoch == 0:
107 | for metric, value in model_info_for_loggers(trainer).items():
108 | live.log_metric(metric, value, plot=False)
109 |
110 | _log_plots(trainer.plots, 'train')
111 | _log_plots(trainer.validator.plots, 'val')
112 |
113 | live.next_step()
114 | _training_epoch = False
115 |
116 |
117 | def on_train_end(trainer):
118 | if live:
119 | # At the end, log the best metrics; the trainer runs the validator on the best model internally.
120 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
121 | for metric, value in all_metrics.items():
122 | live.log_metric(metric, value, plot=False)
123 |
124 | _log_plots(trainer.plots, 'val')
125 | _log_plots(trainer.validator.plots, 'val')
126 | _log_confusion_matrix(trainer.validator)
127 |
128 | if trainer.best.exists():
129 | live.log_artifact(trainer.best, copy=True, type='model')
130 |
131 | live.end()
132 |
133 |
134 | callbacks = {
135 | 'on_pretrain_routine_start': on_pretrain_routine_start,
136 | 'on_pretrain_routine_end': on_pretrain_routine_end,
137 | 'on_train_start': on_train_start,
138 | 'on_train_epoch_start': on_train_epoch_start,
139 | 'on_fit_epoch_end': on_fit_epoch_end,
140 | 'on_train_end': on_train_end} if dvclive else {}
141 |
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/hub.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import json
4 | from time import time
5 |
6 | from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
7 | from ultralytics.utils import LOGGER, SETTINGS
8 | from ultralytics.utils.torch_utils import model_info_for_loggers
9 |
10 |
11 | def on_pretrain_routine_end(trainer):
12 | """Logs info before starting timer for upload rate limit."""
13 | session = getattr(trainer, 'hub_session', None)
14 | if session:
15 | # Start timer for upload rate limit
16 | LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
17 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
18 |
19 |
20 | def on_fit_epoch_end(trainer):
21 | """Uploads training progress metrics at the end of each epoch."""
22 | session = getattr(trainer, 'hub_session', None)
23 | if session:
24 | # Upload metrics after val end
25 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
26 | if trainer.epoch == 0:
27 | all_plots = {**all_plots, **model_info_for_loggers(trainer)}
28 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
29 | if time() - session.timers['metrics'] > session.rate_limits['metrics']:
30 | session.upload_metrics()
31 | session.timers['metrics'] = time() # reset timer
32 | session.metrics_queue = {} # reset queue
33 |
34 |
35 | def on_model_save(trainer):
36 | """Saves checkpoints to Ultralytics HUB with rate limiting."""
37 | session = getattr(trainer, 'hub_session', None)
38 | if session:
39 | # Upload checkpoints with rate limiting
40 | is_best = trainer.best_fitness == trainer.fitness
41 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
42 | LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
43 | session.upload_model(trainer.epoch, trainer.last, is_best)
44 | session.timers['ckpt'] = time() # reset timer
45 |
46 |
47 | def on_train_end(trainer):
48 | """Upload final model and metrics to Ultralytics HUB at the end of training."""
49 | session = getattr(trainer, 'hub_session', None)
50 | if session:
51 |         # Upload final model and metrics with exponential backoff
52 | LOGGER.info(f'{PREFIX}Syncing final model...')
53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
54 | session.alive = False # stop heartbeats
55 | LOGGER.info(f'{PREFIX}Done ✅\n'
56 | f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
57 |
58 |
59 | def on_train_start(trainer):
60 | """Run events on train start."""
61 | events(trainer.args)
62 |
63 |
64 | def on_val_start(validator):
65 | """Runs events on validation start."""
66 | events(validator.args)
67 |
68 |
69 | def on_predict_start(predictor):
70 | """Run events on predict start."""
71 | events(predictor.args)
72 |
73 |
74 | def on_export_start(exporter):
75 | """Run events on export start."""
76 | events(exporter.args)
77 |
78 |
79 | callbacks = {
80 | 'on_pretrain_routine_end': on_pretrain_routine_end,
81 | 'on_fit_epoch_end': on_fit_epoch_end,
82 | 'on_model_save': on_model_save,
83 | 'on_train_end': on_train_end,
84 | 'on_train_start': on_train_start,
85 | 'on_val_start': on_val_start,
86 | 'on_predict_start': on_predict_start,
87 | 'on_export_start': on_export_start} if SETTINGS['hub'] is True else {} # verify enabled
88 |
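89 | # Upload cadence: metrics queue up per epoch and are flushed only once
90 | # session.rate_limits['metrics'] seconds have passed, so several epochs of
91 | # metrics may be batched into a single upload_metrics() call.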
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/mlflow.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import os
4 | import re
5 | from pathlib import Path
6 |
7 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
8 |
9 | try:
10 | import mlflow
11 |
12 | assert not TESTS_RUNNING # do not log pytest
13 | assert hasattr(mlflow, '__version__') # verify package is not directory
14 | assert SETTINGS['mlflow'] is True # verify integration is enabled
15 | except (ImportError, AssertionError):
16 | mlflow = None
17 |
18 |
19 | def on_pretrain_routine_end(trainer):
20 | """Logs training parameters to MLflow."""
21 | global mlflow, run, experiment_name
22 |
23 | if os.environ.get('MLFLOW_TRACKING_URI') is None:
24 | mlflow = None
25 |
26 | if mlflow:
27 | mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
28 | mlflow.set_tracking_uri(mlflow_location)
29 |
30 | experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
31 | run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
32 | experiment = mlflow.get_experiment_by_name(experiment_name)
33 | if experiment is None:
34 | mlflow.create_experiment(experiment_name)
35 | mlflow.set_experiment(experiment_name)
36 |
37 |         prefix = colorstr('MLflow: ')
38 | try:
39 | run, active_run = mlflow, mlflow.active_run()
40 | if not active_run:
41 |                 active_run = mlflow.start_run(run_name=run_name)  # experiment set above; avoids AttributeError when the experiment was just created
42 | LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}')
43 | run.log_params(vars(trainer.model.args))
44 | except Exception as err:
45 |             LOGGER.error(f'{prefix}Failed to initialize - {repr(err)}')
46 |             LOGGER.warning(f'{prefix}Continuing without MLflow')
47 |
48 |
49 | def on_fit_epoch_end(trainer):
50 | """Logs training metrics to Mlflow."""
51 | if mlflow:
52 | metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
53 | run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
54 |
55 |
56 | def on_train_end(trainer):
57 | """Called at end of train loop to log model artifact info."""
58 | if mlflow:
59 | root_dir = Path(__file__).resolve().parents[3]
60 | run.log_artifact(trainer.last)
61 | run.log_artifact(trainer.best)
62 | run.pyfunc.log_model(artifact_path=experiment_name,
63 | code_path=[str(root_dir)],
64 | artifacts={'model_path': str(trainer.save_dir)},
65 | python_model=run.pyfunc.PythonModel())
66 |
67 |
68 | callbacks = {
69 | 'on_pretrain_routine_end': on_pretrain_routine_end,
70 | 'on_fit_epoch_end': on_fit_epoch_end,
71 | 'on_train_end': on_train_end} if mlflow else {}
72 |
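73 | # The integration is driven by environment variables; a minimal sketch with a
74 | # hypothetical local tracking server:
75 | #   export MLFLOW_TRACKING_URI=http://127.0.0.1:5000
76 | #   export MLFLOW_EXPERIMENT_NAME=my-experiment  # optional, falls back to trainer.args.project
77 | #   export MLFLOW_RUN=my-run                     # optional, falls back to trainer.args.name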
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/neptune.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import matplotlib.image as mpimg
4 | import matplotlib.pyplot as plt
5 |
6 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
7 | from ultralytics.utils.torch_utils import model_info_for_loggers
8 |
9 | try:
10 | import neptune
11 | from neptune.types import File
12 |
13 | assert not TESTS_RUNNING # do not log pytest
14 | assert hasattr(neptune, '__version__')
15 | assert SETTINGS['neptune'] is True # verify integration is enabled
16 | except (ImportError, AssertionError):
17 | neptune = None
18 |
19 | run = None # NeptuneAI experiment logger instance
20 |
21 |
22 | def _log_scalars(scalars, step=0):
23 | """Log scalars to the NeptuneAI experiment logger."""
24 | if run:
25 | for k, v in scalars.items():
26 | run[k].append(value=v, step=step)
27 |
28 |
29 | def _log_images(imgs_dict, group=''):
30 | """Log scalars to the NeptuneAI experiment logger."""
31 | if run:
32 | for k, v in imgs_dict.items():
33 | run[f'{group}/{k}'].upload(File(v))
34 |
35 |
36 | def _log_plot(title, plot_path):
37 | """Log plots to the NeptuneAI experiment logger."""
38 | """
39 | Log image as plot in the plot section of NeptuneAI
40 |
41 | arguments:
42 | title (str) Title of the plot
43 | plot_path (PosixPath or str) Path to the saved image file
44 | """
45 | img = mpimg.imread(plot_path)
46 | fig = plt.figure()
47 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
48 | ax.imshow(img)
49 | run[f'Plots/{title}'].upload(fig)
50 |
51 |
52 | def on_pretrain_routine_start(trainer):
53 | """Callback function called before the training routine starts."""
54 | try:
55 | global run
56 | run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
57 | run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
58 | except Exception as e:
59 | LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
60 |
61 |
62 | def on_train_epoch_end(trainer):
63 | """Callback function called at end of each training epoch."""
64 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
65 | _log_scalars(trainer.lr, trainer.epoch + 1)
66 | if trainer.epoch == 1:
67 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
68 |
69 |
70 | def on_fit_epoch_end(trainer):
71 | """Callback function called at end of each fit (train+val) epoch."""
72 | if run and trainer.epoch == 0:
73 | run['Configuration/Model'] = model_info_for_loggers(trainer)
74 | _log_scalars(trainer.metrics, trainer.epoch + 1)
75 |
76 |
77 | def on_val_end(validator):
78 | """Callback function called at end of each validation."""
79 | if run:
80 | # Log val_labels and val_pred
81 | _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
82 |
83 |
84 | def on_train_end(trainer):
85 | """Callback function called at end of training."""
86 | if run:
87 | # Log final results, CM matrix + PR plots
88 | files = [
89 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
90 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
91 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
92 | for f in files:
93 | _log_plot(title=f.stem, plot_path=f)
94 | # Log the final model
95 | run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
96 | trainer.best)))
97 |
98 |
99 | callbacks = {
100 | 'on_pretrain_routine_start': on_pretrain_routine_start,
101 | 'on_train_epoch_end': on_train_epoch_end,
102 | 'on_fit_epoch_end': on_fit_epoch_end,
103 | 'on_val_end': on_val_end,
104 | 'on_train_end': on_train_end} if neptune else {}
105 |
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/raytune.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.utils import SETTINGS
4 |
5 | try:
6 | import ray
7 | from ray import tune
8 | from ray.air import session
9 |
10 | assert SETTINGS['raytune'] is True # verify integration is enabled
11 | except (ImportError, AssertionError):
12 | tune = None
13 |
14 |
15 | def on_fit_epoch_end(trainer):
16 | """Sends training metrics to Ray Tune at end of each epoch."""
17 | if ray.tune.is_session_enabled():
18 | metrics = trainer.metrics
19 | metrics['epoch'] = trainer.epoch
20 | session.report(metrics)
21 |
22 |
23 | callbacks = {
24 |     'on_fit_epoch_end': on_fit_epoch_end} if tune else {}
25 |
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/tensorboard.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
4 |
5 | try:
6 | from torch.utils.tensorboard import SummaryWriter
7 |
8 | assert not TESTS_RUNNING # do not log pytest
9 | assert SETTINGS['tensorboard'] is True # verify integration is enabled
10 |
11 | # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
12 | except (ImportError, AssertionError, TypeError):
13 | SummaryWriter = None
14 |
15 | writer = None # TensorBoard SummaryWriter instance
16 |
17 |
18 | def _log_scalars(scalars, step=0):
19 | """Logs scalar values to TensorBoard."""
20 | if writer:
21 | for k, v in scalars.items():
22 | writer.add_scalar(k, v, step)
23 |
24 |
25 | def on_pretrain_routine_start(trainer):
26 | """Initialize TensorBoard logging with SummaryWriter."""
27 | if SummaryWriter:
28 | try:
29 | global writer
30 | writer = SummaryWriter(str(trainer.save_dir))
31 | prefix = colorstr('TensorBoard: ')
32 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
33 | except Exception as e:
34 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
35 |
36 |
37 | def on_batch_end(trainer):
38 | """Logs scalar statistics at the end of a training batch."""
39 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
40 |
41 |
42 | def on_fit_epoch_end(trainer):
43 | """Logs epoch metrics at end of training epoch."""
44 | _log_scalars(trainer.metrics, trainer.epoch + 1)
45 |
46 |
47 | callbacks = {
48 | 'on_pretrain_routine_start': on_pretrain_routine_start,
49 | 'on_fit_epoch_end': on_fit_epoch_end,
50 | 'on_batch_end': on_batch_end}
51 |
--------------------------------------------------------------------------------
/ultralytics/utils/callbacks/wb.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.utils import SETTINGS, TESTS_RUNNING
4 | from ultralytics.utils.torch_utils import model_info_for_loggers
5 |
6 | try:
7 | import wandb as wb
8 |
9 | assert hasattr(wb, '__version__')
10 | assert not TESTS_RUNNING # do not log pytest
11 | assert SETTINGS['wandb'] is True # verify integration is enabled
12 | except (ImportError, AssertionError):
13 | wb = None
14 |
15 | _processed_plots = {}
16 |
17 |
18 | def _log_plots(plots, step):
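19 |     """Log plot images to W&B, skipping plots whose timestamp was already logged."""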
19 | for name, params in plots.items():
20 | timestamp = params['timestamp']
21 | if _processed_plots.get(name) != timestamp:
22 | wb.run.log({name.stem: wb.Image(str(name))}, step=step)
23 | _processed_plots[name] = timestamp
24 |
25 |
26 | def on_pretrain_routine_start(trainer):
27 | """Initiate and start project if module is present."""
28 | wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
29 |
30 |
31 | def on_fit_epoch_end(trainer):
32 | """Logs training metrics and model information at the end of an epoch."""
33 | wb.run.log(trainer.metrics, step=trainer.epoch + 1)
34 | _log_plots(trainer.plots, step=trainer.epoch + 1)
35 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
36 | if trainer.epoch == 0:
37 | wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
38 |
39 |
40 | def on_train_epoch_end(trainer):
41 | """Log metrics and save images at the end of each training epoch."""
42 | wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
43 | wb.run.log(trainer.lr, step=trainer.epoch + 1)
44 | if trainer.epoch == 1:
45 | _log_plots(trainer.plots, step=trainer.epoch + 1)
46 |
47 |
48 | def on_train_end(trainer):
49 | """Save the best model as an artifact at end of training."""
50 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
51 | _log_plots(trainer.plots, step=trainer.epoch + 1)
52 | art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
53 | if trainer.best.exists():
54 | art.add_file(trainer.best)
55 | wb.run.log_artifact(art, aliases=['best'])
56 |
57 |
58 | callbacks = {
59 | 'on_pretrain_routine_start': on_pretrain_routine_start,
60 | 'on_train_epoch_end': on_train_epoch_end,
61 | 'on_fit_epoch_end': on_fit_epoch_end,
62 | 'on_train_end': on_train_end} if wb else {}
63 |
--------------------------------------------------------------------------------
/ultralytics/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import os
4 | import re
5 | import shutil
6 | import socket
7 | import sys
8 | import tempfile
9 | from pathlib import Path
10 |
11 | from . import USER_CONFIG_DIR
12 | from .torch_utils import TORCH_1_9
13 |
14 |
15 | def find_free_network_port() -> int:
16 | """Finds a free port on localhost.
17 |
18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the
19 | `MASTER_PORT` environment variable.
20 | """
21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
22 | s.bind(('127.0.0.1', 0))
23 | return s.getsockname()[1] # port
24 |
25 |
26 | def generate_ddp_file(trainer):
27 | """Generates a DDP file and returns its file name."""
28 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
29 |
30 | content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
31 | from {module} import {name}
32 | from ultralytics.utils import DEFAULT_CFG_DICT
33 |
34 | cfg = DEFAULT_CFG_DICT.copy()
35 | cfg.update(save_dir='') # handle the extra key 'save_dir'
36 | trainer = {name}(cfg=cfg, overrides=overrides)
37 | trainer.train()'''
38 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
39 | with tempfile.NamedTemporaryFile(prefix='_temp_',
40 | suffix=f'{id(trainer)}.py',
41 | mode='w+',
42 | encoding='utf-8',
43 | dir=USER_CONFIG_DIR / 'DDP',
44 | delete=False) as file:
45 | file.write(content)
46 | return file.name
47 |
48 |
49 | def generate_ddp_command(world_size, trainer):
50 | """Generates and returns command for distributed training."""
51 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
52 | if not trainer.resume:
53 | shutil.rmtree(trainer.save_dir) # remove the save_dir
54 | file = str(Path(sys.argv[0]).resolve())
55 |     safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$')  # allowed characters and maximum of 128 characters
56 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
57 | file = generate_ddp_file(trainer)
58 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
59 | port = find_free_network_port()
60 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
61 | return cmd, file
62 |
63 |
64 | def ddp_cleanup(trainer, file):
65 | """Delete temp file if created."""
66 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file
67 | os.remove(file)
68 |
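69 | # For world_size=2 the generated command looks roughly like (illustrative port):
70 | #   python -m torch.distributed.run --nproc_per_node 2 --master_port 51234 train.py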
--------------------------------------------------------------------------------
/ultralytics/utils/errors.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.utils import emojis
4 |
5 |
6 | class HUBModelError(Exception):
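7 |     """Raised when a requested Ultralytics HUB model is not found."""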
7 |
8 | def __init__(self, message='Model not found. Please check model URL and try again.'):
9 | """Create an exception for when a model is not found."""
10 | super().__init__(emojis(message))
11 |
--------------------------------------------------------------------------------
/ultralytics/utils/files.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import contextlib
4 | import glob
5 | import os
6 | import shutil
7 | import tempfile
8 | from contextlib import contextmanager
9 | from datetime import datetime
10 | from pathlib import Path
11 |
12 |
13 | class WorkingDirectory(contextlib.ContextDecorator):
14 | """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
15 |
16 | def __init__(self, new_dir):
17 | """Sets the working directory to 'new_dir' upon instantiation."""
18 | self.dir = new_dir # new dir
19 | self.cwd = Path.cwd().resolve() # current dir
20 |
21 | def __enter__(self):
22 | """Changes the current directory to the specified directory."""
23 | os.chdir(self.dir)
24 |
25 | def __exit__(self, exc_type, exc_val, exc_tb): # noqa
26 | """Restore the current working directory on context exit."""
27 | os.chdir(self.cwd)
28 |
29 |
30 | @contextmanager
31 | def spaces_in_path(path):
32 | """
33 | Context manager to handle paths with spaces in their names.
34 | If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path,
35 | executes the context code block, then copies the file/directory back to its original location.
36 |
37 | Args:
38 | path (str | Path): The original path.
39 |
40 | Yields:
41 | (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
42 |
43 | Example:
44 | ```python
45 | with spaces_in_path('/path/with spaces') as new_path:
46 | # your code here
47 | ```
48 | """
49 |
50 | # If path has spaces, replace them with underscores
51 | if ' ' in str(path):
52 | string = isinstance(path, str) # input type
53 | path = Path(path)
54 |
55 | # Create a temporary directory and construct the new path
56 | with tempfile.TemporaryDirectory() as tmp_dir:
57 | tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
58 |
59 | # Copy file/directory
60 | if path.is_dir():
61 | # tmp_path.mkdir(parents=True, exist_ok=True)
62 | shutil.copytree(path, tmp_path)
63 | elif path.is_file():
64 | tmp_path.parent.mkdir(parents=True, exist_ok=True)
65 | shutil.copy2(path, tmp_path)
66 |
67 | try:
68 | # Yield the temporary path
69 | yield str(tmp_path) if string else tmp_path
70 |
71 | finally:
72 | # Copy file/directory back
73 | if tmp_path.is_dir():
74 | shutil.copytree(tmp_path, path, dirs_exist_ok=True)
75 | elif tmp_path.is_file():
76 | shutil.copy2(tmp_path, path) # Copy back the file
77 |
78 | else:
79 | # If there are no spaces, just yield the original path
80 | yield path
81 |
82 |
83 | def increment_path(path, exist_ok=False, sep='', mkdir=False):
84 | """
85 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
86 |
87 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
88 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
89 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
90 | directory if it does not already exist.
91 |
92 | Args:
93 | path (str, pathlib.Path): Path to increment.
94 |         exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
95 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
96 | mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
97 |
98 | Returns:
99 | (pathlib.Path): Incremented path.
100 | """
101 | path = Path(path) # os-agnostic
102 | if path.exists() and not exist_ok:
103 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
104 |
105 |         # Try successive numbered paths until a free one is found
106 |         for n in range(2, 9999):
107 |             p = f'{path}{sep}{n}{suffix}'  # increment path
108 |             if not os.path.exists(p):
109 | break
110 | path = Path(p)
111 |
112 | if mkdir:
113 | path.mkdir(parents=True, exist_ok=True) # make directory
114 |
115 | return path
116 |
117 |
118 | def file_age(path=__file__):
119 | """Return days since last file update."""
120 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
121 | return dt.days # + dt.seconds / 86400 # fractional days
122 |
123 |
124 | def file_date(path=__file__):
125 | """Return human-readable file modification date, i.e. '2021-3-26'."""
126 | t = datetime.fromtimestamp(Path(path).stat().st_mtime)
127 | return f'{t.year}-{t.month}-{t.day}'
128 |
129 |
130 | def file_size(path):
131 | """Return file/dir size (MB)."""
132 | if isinstance(path, (str, Path)):
133 | mb = 1 << 20 # bytes to MiB (1024 ** 2)
134 | path = Path(path)
135 | if path.is_file():
136 | return path.stat().st_size / mb
137 | elif path.is_dir():
138 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
139 | return 0.0
140 |
141 |
142 | def get_latest_run(search_dir='.'):
143 | """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
144 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
145 | return max(last_list, key=os.path.getctime) if last_list else ''
146 |
147 |
148 | def make_dirs(dir='new_dir/'):
149 | """Create directories."""
150 | dir = Path(dir)
151 | if dir.exists():
152 | shutil.rmtree(dir) # delete dir
153 | for p in dir, dir / 'labels', dir / 'images':
154 | p.mkdir(parents=True, exist_ok=True) # make dir
155 | return dir
156 |
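157 | # Quick sketch of increment_path behavior (hypothetical paths):
158 | if __name__ == '__main__':
159 |     print(increment_path('runs/exp'))   # 'runs/exp' if free, else 'runs/exp2', 'runs/exp3', ...
160 |     print(increment_path('data.txt'))   # file suffix preserved: 'data.txt' -> 'data2.txt'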
--------------------------------------------------------------------------------
/ultralytics/utils/patches.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Monkey patches to update/extend functionality of existing functions.
4 | """
5 |
6 | from pathlib import Path
7 |
8 | import cv2
9 | import numpy as np
10 | import torch
11 |
12 | # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
13 | _imshow = cv2.imshow # copy to avoid recursion errors
14 |
15 |
16 | def imread(filename, flags=cv2.IMREAD_COLOR):
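17 |     """Read an image from disk with np.fromfile + cv2.imdecode so non-ASCII paths work."""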
17 | return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
18 |
19 |
20 | def imwrite(filename, img):
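21 |     """Write an image via cv2.imencode + ndarray.tofile (non-ASCII-path safe); returns True on success."""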
21 | try:
22 | cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
23 | return True
24 | except Exception:
25 | return False
26 |
27 |
28 | def imshow(path, im):
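29 |     """Display an image, escaping the window title so non-ASCII characters render safely."""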
29 | _imshow(path.encode('unicode_escape').decode(), im)
30 |
31 |
32 | # PyTorch functions ----------------------------------------------------------------------------------------------------
33 | _torch_save = torch.save # copy to avoid recursion errors
34 |
35 |
36 | def torch_save(*args, **kwargs):
37 | """Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
38 | try:
39 | import dill as pickle
40 | except ImportError:
41 | import pickle
42 |
43 | if 'pickle_module' not in kwargs:
44 | kwargs['pickle_module'] = pickle
45 | return _torch_save(*args, **kwargs)
46 |
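47 | # Minimal sketch of why the dill fallback matters: plain pickle cannot serialize
48 | # lambdas, so a checkpoint holding one fails without dill installed.
49 | if __name__ == '__main__':
50 |     ckpt = {'weights': torch.zeros(2), 'act': lambda x: x * 2}  # lambda breaks plain pickle
51 |     torch_save(ckpt, 'ckpt_with_lambda.pt')  # works when dill is available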
--------------------------------------------------------------------------------
/ultralytics/utils/tuner.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.cfg import TASK2DATA, TASK2METRIC
4 | from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
5 |
6 |
7 | def run_ray_tune(model,
8 | space: dict = None,
9 | grace_period: int = 10,
10 | gpu_per_trial: int = None,
11 | max_samples: int = 10,
12 | **train_args):
13 | """
14 | Runs hyperparameter tuning using Ray Tune.
15 |
16 | Args:
17 | model (YOLO): Model to run the tuner on.
18 | space (dict, optional): The hyperparameter search space. Defaults to None.
19 | grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
20 | gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
21 | max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
22 | train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
23 |
24 | Returns:
25 | (dict): A dictionary containing the results of the hyperparameter search.
26 |
27 | Raises:
28 | ModuleNotFoundError: If Ray Tune is not installed.
29 | """
30 | if train_args is None:
31 | train_args = {}
32 |
33 | try:
34 | from ray import tune
35 | from ray.air import RunConfig
36 | from ray.air.integrations.wandb import WandbLoggerCallback
37 | from ray.tune.schedulers import ASHAScheduler
38 | except ImportError:
39 | raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"')
40 |
41 | try:
42 | import wandb
43 |
44 | assert hasattr(wandb, '__version__')
45 | except (ImportError, AssertionError):
46 | wandb = False
47 |
48 | default_space = {
49 | # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
50 | 'lr0': tune.uniform(1e-5, 1e-1),
51 | 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
52 | 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
53 | 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
54 | 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
55 | 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
56 | 'box': tune.uniform(0.02, 0.2), # box loss gain
57 | 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
58 | 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
59 | 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
60 | 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
61 | 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
62 | 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
63 | 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
64 | 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
65 | 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
66 | 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
67 | 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
68 |         'mosaic': tune.uniform(0.0, 1.0),  # image mosaic (probability)
69 | 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
70 | 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
71 |
72 | def _tune(config):
73 | """
74 | Trains the YOLO model with the specified hyperparameters and additional arguments.
75 |
76 | Args:
77 | config (dict): A dictionary of hyperparameters to use for training.
78 |
79 | Returns:
80 | None.
81 | """
82 | model._reset_callbacks()
83 | config.update(train_args)
84 | model.train(**config)
85 |
86 | # Get search space
87 | if not space:
88 | space = default_space
89 | LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
90 |
91 | # Get dataset
92 | data = train_args.get('data', TASK2DATA[model.task])
93 | space['data'] = data
94 | if 'data' not in train_args:
95 | LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
96 |
97 | # Define the trainable function with allocated resources
98 | trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
99 |
100 | # Define the ASHA scheduler for hyperparameter search
101 | asha_scheduler = ASHAScheduler(time_attr='epoch',
102 | metric=TASK2METRIC[model.task],
103 | mode='max',
104 | max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
105 | grace_period=grace_period,
106 | reduction_factor=3)
107 |
108 | # Define the callbacks for the hyperparameter search
109 | tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
110 |
111 | # Create the Ray Tune hyperparameter search tuner
112 | tuner = tune.Tuner(trainable_with_resources,
113 | param_space=space,
114 | tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
115 | run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
116 |
117 | # Run the hyperparameter search
118 | tuner.fit()
119 |
120 | # Return the results of the hyperparameter search
121 | return tuner.get_results()
122 |
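123 | # Minimal usage sketch (assumes a local 'yolov8n.pt' checkpoint; 'epochs' is
124 | # forwarded to model.train() through **train_args):
125 | if __name__ == '__main__':
126 |     from ultralytics import YOLO
127 |     results = run_ray_tune(YOLO('yolov8n.pt'), max_samples=4, epochs=10)
128 |     print(results)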
--------------------------------------------------------------------------------
/ultralytics/yolo/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from . import v8
4 |
5 | __all__ = 'v8', # tuple or list
6 |
--------------------------------------------------------------------------------
/ultralytics/yolo/cfg/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | from ultralytics.utils import LOGGER
5 |
6 | # Set modules in sys.modules under their old name
7 | sys.modules['ultralytics.yolo.cfg'] = importlib.import_module('ultralytics.cfg')
8 |
9 | LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.cfg' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
10 | "Please use 'ultralytics.cfg' instead.")
11 |
--------------------------------------------------------------------------------
/ultralytics/yolo/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | from ultralytics.utils import LOGGER
5 |
6 | # Set modules in sys.modules under their old name
7 | sys.modules['ultralytics.yolo.data'] = importlib.import_module('ultralytics.data')
8 | # This is for updating old cls models, or the way in following warning won't work.
9 | sys.modules['ultralytics.yolo.data.augment'] = importlib.import_module('ultralytics.data.augment')
10 |
11 | DATA_WARNING = """WARNING ⚠️ 'ultralytics.yolo.data' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.data' instead.
12 | Note this warning may be related to loading older models. You can update your model to current structure with:
13 | import torch
14 | ckpt = torch.load("model.pt") # applies to both official and custom models
15 | torch.save(ckpt, "updated-model.pt")
16 | """
17 | LOGGER.warning(DATA_WARNING)
18 |
--------------------------------------------------------------------------------
/ultralytics/yolo/engine/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | from ultralytics.utils import LOGGER
5 |
6 | # Set modules in sys.modules under their old name
7 | sys.modules['ultralytics.yolo.engine'] = importlib.import_module('ultralytics.engine')
8 |
9 | LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.engine' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
10 | "Please use 'ultralytics.engine' instead.")
11 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | from ultralytics.utils import LOGGER
5 |
6 | # Set modules in sys.modules under their old name
7 | sys.modules['ultralytics.yolo.utils'] = importlib.import_module('ultralytics.utils')
8 |
9 | UTILS_WARNING = """WARNING ⚠️ 'ultralytics.yolo.utils' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.utils' instead.
10 | Note this warning may be related to loading older models. You can update your model to current structure with:
11 | import torch
12 | ckpt = torch.load("model.pt") # applies to both official and custom models
13 | torch.save(ckpt, "updated-model.pt")
14 | """
15 | LOGGER.warning(UTILS_WARNING)
16 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | from ultralytics.utils import LOGGER
5 |
6 | # Set modules in sys.modules under their old name
7 | sys.modules['ultralytics.yolo.v8'] = importlib.import_module('ultralytics.models.yolo')
8 |
9 | LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.v8' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
10 | "Please use 'ultralytics.models.yolo' instead.")
11 |
--------------------------------------------------------------------------------
/y_prune/prune1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ultralytics import YOLO
4 | from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
5 | import numpy as np
6 |
7 | def prune_conv(conv1, conv2, threshold):
8 |     gamma = conv1.bn.weight.data.detach()
9 |     beta = conv1.bn.bias.detach()
10 |     keep_idxs = []
11 |     local_threshold = threshold
12 |     while len(keep_idxs) < 8:
13 |         keep_idx = torch.where(gamma.abs() >= local_threshold)[0]
14 |         keep_idx = torch.ceil(torch.tensor(len(keep_idx) / 8)) * 8  # round the kept-channel count up to a multiple of 8
15 |         new_threshold = torch.sort(gamma.abs(), descending=True)[0][int(keep_idx - 1)]
16 |         keep_idxs = torch.where(gamma.abs() >= new_threshold)[0]  # indices of the channels kept after pruning
17 |         local_threshold = new_threshold * 0.5
18 |     n = len(keep_idxs)
19 |     print(f"keep rate for this layer: {n / len(gamma) * 100:.2f}%")
20 |     conv1.bn.weight.data = gamma[keep_idxs]  # adjust conv1's BN layer
21 |     conv1.bn.bias.data = beta[keep_idxs]
22 |     conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
23 |     conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
24 |     conv1.bn.num_features = n
25 |     conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]  # adjust conv1's conv layer
26 |     conv1.conv.out_channels = n
27 |     if conv1.conv.bias is not None:
28 |         conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
29 |     if not isinstance(conv2, list):  # normalize conv2 to a list; callers may pass one module or several (one-to-many)
30 |         conv2 = [conv2]
31 |     for item in conv2:
32 |         if item is not None:
33 |             if isinstance(item, Conv):  # find the conv layer inside conv2 and adjust its input channels
34 |                 conv = item.conv
35 |             else:
36 |                 conv = item
37 |             conv.in_channels = n  # match conv2's input channels to conv1's pruned output channels
38 |             conv.weight.data = conv.weight.data[:, keep_idxs]
39 |
40 | def prune(m1, m2, threshold):
41 |     if isinstance(m1, C2f):
42 |         m1 = m1.cv2
43 |     if not isinstance(m2, list):
44 |         m2 = [m2]
45 |     for i, item in enumerate(m2):
46 |         if isinstance(item, (C2f, SPPF)):
47 |             m2[i] = item.cv1
48 |     prune_conv(m1, m2, threshold)
49 |
50 | if __name__ == "__main__":
51 |     yolo = YOLO("weights/best.pt")  # load the trained model to prune
52 | model = yolo.model
53 | ws = []
54 | bs = []
55 | for name, m in model.named_modules():
56 | if isinstance(m, nn.BatchNorm2d):
57 | w = m.weight.abs().detach()
58 | b = m.bias.abs().detach()
59 | ws.append(w)
60 | bs.append(b)
61 | print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
62 |     factor = 0.8  # pruning rate, set to 0.8 for now; a smaller rate applied over several pruning passes is recommended
63 | ws = torch.cat(ws)
64 | threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
65 |
66 | seq = model.model
67 |     # For the head part, don't prune the low-level layers yet; they may contain important information
68 |     for i in range(3, 9):
69 |         if i in [6, 4, 9]:
70 |             continue
71 |         prune(seq[i], seq[i + 1], threshold)
72 |     # Handle layers 15, 18 and 21: layers 15 and 18 feed both the Detect head and a conv in the next layer, while 21 connects directly to Detect
73 |     for n, i in enumerate([15, 18, 21]):
74 |         if i != 21:
75 |             prune(seq[i], [seq[i + 1], seq[-1].cv2[n][0], seq[-1].cv3[n][0]], threshold)  # C2f-conv, C2f-Detect
76 |         prune_conv(seq[-1].cv2[n][0], seq[-1].cv2[n][1], threshold)  # Detect.cv2
77 |         prune_conv(seq[-1].cv2[n][1], seq[-1].cv2[n][2], threshold)  # Detect.cv2
78 |         prune_conv(seq[-1].cv3[n][0], seq[-1].cv3[n][1], threshold)  # Detect.cv3
79 |         prune_conv(seq[-1].cv3[n][1], seq[-1].cv3[n][2], threshold)  # Detect.cv3
80 |     # Walk the model and prune inside each Bottleneck
81 |     for name, m in model.named_modules():
82 |         if isinstance(m, Bottleneck):
83 |             prune_conv(m.cv1, m.cv2, threshold)
84 | for name, p in yolo.model.named_parameters():
85 | p.requires_grad = True
86 |     yolo.train(data="ultralytics/cfg/datasets/VOC.yaml", epochs=300, workers=8, batch=32)
87 | print("done")
88 |
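89 | # Worked example of the multiple-of-8 rounding in prune_conv (made-up count):
90 | # if 27 gammas clear the threshold, ceil(27 / 8) * 8 = 32, so the threshold is
91 | # relaxed to the 32nd-largest |gamma| and 32 channels are kept for the layer.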
--------------------------------------------------------------------------------
/y_prune/prune2.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YellenYeason/yolov8-prune/b5bf7381ae3b88b848f2b525507b919c8a97f004/y_prune/prune2.py
--------------------------------------------------------------------------------
/y_prune/y_train.py:
--------------------------------------------------------------------------------
1 | from ultralytics import YOLO
2 | model = YOLO("yolov8n.pt")  # load a pretrained model (recommended for training)
3 | model.train(data="ultralytics/cfg/datasets/VOC.yaml", epochs=3, batch=32, workers=8)
--------------------------------------------------------------------------------