├── .github ├── CODEOWNERS ├── dependabot.yml └── workflows │ ├── publish-docs.yml │ ├── publish-pre-release.yml │ ├── publish-release.yml │ ├── test-doc.yml │ └── uv-test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── assets │ ├── logo-trackers-black.svg │ ├── logo-trackers-violet.svg │ └── logo-trackers-white.svg ├── index.md ├── overrides │ ├── partials │ │ └── comments.html │ └── stylesheets │ │ └── style.css └── trackers │ └── core │ ├── deepsort │ └── tracker.md │ ├── reid │ └── reid.md │ └── sort │ └── tracker.md ├── mkdocs.yml ├── mypy.ini ├── pyproject.toml ├── test └── core │ ├── __init__.py │ └── reid │ ├── __init__.py │ └── dataset │ ├── __init__.py │ ├── test_base.py │ └── test_market_1501.py ├── tox.ini ├── trackers ├── __init__.py ├── core │ ├── __init__.py │ ├── base.py │ ├── deepsort │ │ ├── __init__.py │ │ ├── kalman_box_tracker.py │ │ └── tracker.py │ ├── reid │ │ ├── __init__.py │ │ ├── callbacks.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── market_1501.py │ │ │ └── utils.py │ │ ├── metrics.py │ │ └── model.py │ └── sort │ │ ├── __init__.py │ │ ├── kalman_box_tracker.py │ │ └── tracker.py ├── log.py └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── downloader.py │ ├── sort_utils.py │ └── torch_utils.py └── uv.lock /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in 2 | # the repo. They will be requested for review when someone 3 | # opens a pull request. 4 | * @soumik12345 @SkalskiP @onuralpszr 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | commit-message: 9 | prefix: ⬆️ 10 | target-branch: "main" 11 | # Python 12 | - package-ecosystem: "uv" 13 | directory: "/" 14 | schedule: 15 | interval: "daily" 16 | commit-message: 17 | prefix: ⬆️ 18 | target-branch: "main" 19 | -------------------------------------------------------------------------------- /.github/workflows/publish-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | release: 9 | types: [published] 10 | 11 | # Ensure only one concurrent deployment 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.event_name == 'push' && github.ref}} 14 | cancel-in-progress: true 15 | 16 | # Restrict permissions by default 17 | permissions: 18 | contents: write # Required for committing to gh-pages 19 | pages: write # Required for deploying to Pages 20 | pull-requests: write # Required for PR comments 21 | 22 | jobs: 23 | deploy: 24 | name: Publish Docs 25 | runs-on: ubuntu-latest 26 | timeout-minutes: 10 27 | strategy: 28 | matrix: 29 | python-version: ["3.10"] 30 | steps: 31 | - name: 📥 Checkout the repository 32 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 33 | with: 34 | fetch-depth: 0 35 | 36 | - name: 🐍 Install uv and set Python ${{ matrix.python-version }} 37 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v6.1.0 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | activate-environment: true 41 | 
42 | - name: 🔑 Create GitHub App token (mkdocs) 43 | id: mkdocs_token 44 | uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6 45 | with: 46 | app-id: ${{ secrets.MKDOCS_APP_ID }} 47 | private-key: ${{ secrets.MKDOCS_PEM }} 48 | owner: roboflow 49 | repositories: mkdocs-material-insiders 50 | 51 | - name: 🏗️ Install dependencies 52 | run: | 53 | uv pip install -r pyproject.toml --group docs 54 | # Install mkdocs-material-insiders using the GitHub App token 55 | uv pip install "git+https://roboflow:${{ steps.mkdocs_token.outputs.token }}@github.com/roboflow/mkdocs-material-insiders.git@9.5.49-insiders-4.53.14#egg=mkdocs-material[imaging]" 56 | 57 | - name: ⚙️ Configure git for github-actions 58 | run: | 59 | git config --global user.name "github-actions[bot]" 60 | git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" 61 | 62 | - name: 🚀 Deploy Development Docs 63 | if: (github.event_name == 'push' && github.ref == 'refs/heads/main') || github.event_name == 'workflow_dispatch' 64 | run: | 65 | MKDOCS_GIT_COMMITTERS_APIKEY=${{ secrets.GITHUB_TOKEN }} uv run mike deploy --push develop 66 | 67 | - name: 🚀 Deploy Release Docs 68 | if: github.event_name == 'release' && github.event.action == 'published' 69 | run: | 70 | latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`) 71 | MKDOCS_GIT_COMMITTERS_APIKEY=${{ secrets.GITHUB_TOKEN }} uv run mike deploy --push --update-aliases $latest_tag latest 72 | -------------------------------------------------------------------------------- /.github/workflows/publish-pre-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Trackers Pre-Releases to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+[0-9]+a[0-9]" 7 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+[0-9]+b[0-9]" 8 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+[0-9]+rc[0-9]" 9 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+a[0-9]" 10 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+b[0-9]" 11 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+rc[0-9]" 12 | - "[0-9]+.[0-9]+.[0-9]+a[0-9]" 13 | - "[0-9]+.[0-9]+.[0-9]+b[0-9]" 14 | - "[0-9]+.[0-9]+.[0-9]+rc[0-9]" 15 | workflow_dispatch: 16 | 17 | permissions: {} # Explicitly remove all permissions by default 18 | 19 | jobs: 20 | publish-pre-release: 21 | name: Publish Pre-release Package 22 | runs-on: ubuntu-latest 23 | environment: 24 | name: test 25 | url: https://pypi.org/project/trackers/ 26 | timeout-minutes: 10 27 | permissions: 28 | id-token: write # Required for PyPI publishing 29 | contents: read # Required for checkout 30 | strategy: 31 | matrix: 32 | python-version: ["3.10"] 33 | steps: 34 | - name: 📥 Checkout the repository 35 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 36 | 37 | - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} 38 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v6.1.0 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | activate-environment: true 42 | 43 | - name: 🏗️ Build source and wheel distributions 44 | run: | 45 | uv pip install -r pyproject.toml --group build 46 | uv build 47 | uv run twine check --strict dist/* 48 | 49 | - name: 🚀 Publish to PyPi 50 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 51 | with: 52 | attestations: true 53 | -------------------------------------------------------------------------------- /.github/workflows/publish-release.yml: 
-------------------------------------------------------------------------------- 1 | name: Publish Trackers Releases to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - "[0-9]+.[0-9]+[0-9]+.[0-9]+[0-9]" 7 | - "[0-9]+.[0-9]+[0-9]+.[0-9]" 8 | - "[0-9]+.[0-9]+.[0-9]" 9 | workflow_dispatch: 10 | 11 | permissions: {} # Explicitly remove all permissions by default 12 | 13 | jobs: 14 | publish-release: 15 | name: Publish Release Package 16 | runs-on: ubuntu-latest 17 | environment: 18 | name: release 19 | url: https://pypi.org/project/trackers/ 20 | timeout-minutes: 10 21 | permissions: 22 | id-token: write # Required for PyPI publishing 23 | contents: read # Required for checkout 24 | strategy: 25 | matrix: 26 | python-version: ["3.10"] 27 | steps: 28 | - name: 📥 Checkout the repository 29 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 30 | 31 | - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} 32 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v6.1.0 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | activate-environment: true 36 | 37 | - name: 🏗️ Build source and wheel distributions 38 | run: | 39 | uv pip install -r pyproject.toml --group build 40 | uv build 41 | uv run twine check --strict dist/* 42 | 43 | - name: 🚀 Publish to PyPi 44 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 45 | with: 46 | attestations: true 47 | -------------------------------------------------------------------------------- /.github/workflows/test-doc.yml: -------------------------------------------------------------------------------- 1 | name: 🧪 Docs Test WorkFlow 📚 2 | 3 | on: 4 | pull_request: 5 | branches: [main, develop] 6 | 7 | # Restrict permissions by default 8 | permissions: 9 | contents: read # Required for checkout 10 | checks: write # Required for test reporting 11 | 12 | jobs: 13 | docs-build-test: 14 | name: Test docs build 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 10 17 | strategy: 18 | matrix: 19 | python-version: ["3.10"] 20 | steps: 21 | - name: 📥 Checkout the repository 22 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 23 | with: 24 | fetch-depth: 0 25 | 26 | - name: 🐍 Install uv and set Python ${{ matrix.python-version }} 27 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v6.1.0 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | activate-environment: true 31 | 32 | - name: 🏗️ Install dependencies 33 | run: uv pip install -r pyproject.toml --group docs --python-version ${{ matrix.python-version }} 34 | 35 | - name: 🧪 Test Docs Build 36 | run: uv run mkdocs build --verbose 37 | -------------------------------------------------------------------------------- /.github/workflows/uv-test.yml: -------------------------------------------------------------------------------- 1 | name: 🔧 Pytest/Test Workflow 2 | 3 | on: 4 | pull_request: 5 | branches: [main, develop] 6 | 7 | jobs: 8 | run-tests: 9 | name: Import Test and Pytest Run 10 | timeout-minutes: 10 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest, windows-latest, macos-latest] 15 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 16 | 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - name: 📥 Checkout the repository 20 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 21 | 22 | - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} 23 | uses: 
astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v6.1.0 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | activate-environment: true 27 | # TODO(https://github.com/astral-sh/setup-uv/issues/226): Remove this. 28 | prune-cache: ${{ matrix.os != 'windows-latest' }} 29 | 30 | - name: 🚀 Install Packages 31 | run: uv pip install -r pyproject.toml --group dev --group docs --extra cpu --extra reid 32 | 33 | - name: 🧪 Run the Import test 34 | run: uv run python -c "import trackers" 35 | 36 | - name: 🧪 Run the Test 37 | run: uv run pytest 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # Installer logs 30 | pip-log.txt 31 | pip-delete-this-directory.txt 32 | 33 | # Unit test / coverage reports 34 | htmlcov/ 35 | .tox/ 36 | .nox/ 37 | .coverage 38 | .coverage.* 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | *.cover 43 | *.py,cover 44 | .hypothesis/ 45 | .pytest_cache/ 46 | cover/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | db.sqlite3 56 | db.sqlite3-journal 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | .pybuilder/ 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # IPython 76 | profile_default/ 77 | ipython_config.py 78 | 79 | # pyenv 80 | # For a library or package, you might want to ignore these files since the code is 81 | # intended to run in multiple environments; otherwise, check them in: 82 | # .python-version 83 | 84 | # pipenv 85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 88 | # install all needed dependencies. 89 | #Pipfile.lock 90 | 91 | # UV 92 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 93 | # This is especially recommended for binary packages to ensure reproducibility, and is more 94 | # commonly ignored for libraries. 95 | #uv.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # pytype static type analyzer 147 | .pytype/ 148 | 149 | # Cython debug symbols 150 | cython_debug/ 151 | 152 | # PyCharm 153 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 154 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 155 | # and can be added to the global gitignore or merged into this file. For a more nuclear 156 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 157 | #.idea/ 158 | 159 | # Ruff stuff: 160 | .ruff_cache/ 161 | 162 | # PyPI configuration file 163 | .pypirc 164 | 165 | # Repository-specific stuff 166 | .ipynb_checkpoints/ 167 | .idea/ 168 | test.py 169 | **.pt 170 | **.pth 171 | .DS_Store 172 | data/ 173 | logs/ 174 | runs/ 175 | wandb/ 176 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | 2 | ci: 3 | autofix_prs: true 4 | autoupdate_schedule: weekly 5 | autofix_commit_msg: "fix(pre_commit): 🎨 auto format pre-commit hooks" 6 | autoupdate_commit_msg: "chore(pre_commit): ⬆ pre_commit autoupdate" 7 | 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v5.0.0 11 | hooks: 12 | - id: trailing-whitespace 13 | exclude: test/.*\.py 14 | - id: check-executables-have-shebangs 15 | - id: check-toml 16 | - id: check-case-conflict 17 | - id: check-added-large-files 18 | - id: detect-private-key 19 | - id: pretty-format-json 20 | args: ['--autofix', '--no-sort-keys', '--indent=4'] 21 | exclude: /.*\.ipynb 22 | - id: end-of-file-fixer 23 | - id: mixed-line-ending 24 | 25 | - repo: https://github.com/PyCQA/bandit 26 | rev: '1.8.3' 27 | hooks: 28 | - id: bandit 29 | args: ["-c", "pyproject.toml"] 30 | additional_dependencies: ["bandit[toml]"] 31 | 32 | - repo: https://github.com/astral-sh/ruff-pre-commit 33 | rev: v0.11.12 34 | hooks: 35 | - id: ruff 36 | args: [--fix, --exit-non-zero-on-fix] 37 | - id: ruff-format 38 | types_or: [ python, pyi, jupyter] 39 | 40 | - repo: https://github.com/pre-commit/mirrors-mypy 41 | rev: 'v1.16.0' 42 | hooks: 43 | - id: mypy 44 | additional_dependencies: [numpy,types-aiofiles] 45 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socioeconomic status, 9 | nationality, personal 
appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | - Demonstrating empathy and kindness toward other people 21 | - Being respectful of differing opinions, viewpoints, and experiences 22 | - Giving and gracefully accepting constructive feedback 23 | - Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | - Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | - The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | - Trolling, insulting or derogatory comments, and personal or political attacks 33 | - Public or private harassment 34 | - Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | - Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | community-reports@roboflow.com. 64 | 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. 
No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.1, available at 120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][mozilla coc]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][faq]. Translations are available at 127 | [https://www.contributor-covenant.org/translations][translations]. 128 | 129 | [faq]: https://www.contributor-covenant.org/faq 130 | [homepage]: https://www.contributor-covenant.org 131 | [mozilla coc]: https://github.com/mozilla/diversity 132 | [translations]: https://www.contributor-covenant.org/translations 133 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 134 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Trackers 2 | 3 | Thank you for your interest in contributing to the Trackers library! Your help—whether it’s fixing bugs, improving documentation, or adding new algorithms—is essential to the success of the project. We’re building this library with the goal of making state-of-the-art object tracking accessible under a fully open license. 4 | 5 | ## Table of Contents 6 | 7 | 1. [How to Contribute](#how-to-contribute) 8 | 2. [CLA Signing](#cla-signing) 9 | 3. [Clean Room Requirements](#clean-room-requirements) 10 | 4. [Google-Style Docstrings and Type Hints](#google-style-docstrings-and-type-hints) 11 | 5. [Reporting Bugs](#reporting-bugs) 12 | 6. [License](#license) 13 | 14 | ## How to Contribute 15 | 16 | Contributions come in many forms: improving features, fixing bugs, suggesting ideas, improving documentation, or adding new tracking methods. Here’s a high-level overview to get you started: 17 | 18 | 1. [Fork the Repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo): Click the “Fork” button on our GitHub page to create your own copy. 19 | 2. 
[Clone Locally](https://docs.github.com/en/enterprise-server@3.11/repositories/creating-and-managing-repositories/cloning-a-repository): Download your fork to your local development environment. 20 | 3. [Create a Branch](https://docs.github.com/en/desktop/making-changes-in-a-branch/managing-branches-in-github-desktop): Use a descriptive name to create a new branch: 21 | 22 | ```bash 23 | git checkout -b feature/your-descriptive-name 24 | ``` 25 | 26 | 4. Develop Your Changes: Make your updates, ensuring your commit messages clearly describe your modifications. 27 | 5. [Commit and Push](https://docs.github.com/en/desktop/making-changes-in-a-branch/committing-and-reviewing-changes-to-your-project-in-github-desktop): Run: 28 | 29 | ```bash 30 | git add . 31 | git commit -m "A brief description of your changes" 32 | git push -u origin your-descriptive-name 33 | ``` 34 | 35 | 6. [Open a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request): Submit your pull request against the main development branch. Please detail your changes and link any related issues. 36 | 37 | Before merging, check that all tests pass and that your changes adhere to our development and documentation standards. 38 | 39 | ## CLA Signing 40 | 41 | In order to maintain the integrity of our project, every pull request must include a signed Contributor License Agreement (CLA). This confirms that your contributions are properly licensed under our Apache 2.0 License. After opening your pull request, simply add a comment stating: 42 | 43 | ``` 44 | I have read the CLA Document and I sign the CLA. 45 | ``` 46 | 47 | This step is essential before any merge can occur. 48 | 49 | ## Clean Room Requirements 50 | 51 | Trackers package is developed under the Apache 2.0 license, which allows for wide adoption, commercial use, and integration with other open-source tools. However, many object tracking methods released alongside academic papers are published under more restrictive licenses (GPL, AGPL, etc.), which limit redistribution or usage in commercial contexts. 52 | 53 | To ensure Trackers remains fully open and legally safe to use: 54 | 55 | - All algorithms must be clean room re-implementations, meaning they are developed from scratch without referencing restricted source code. 56 | - You must not copy, adapt, or even consult source code under restrictive licenses. 57 | 58 | You can use the following as reference: 59 | 60 | - The original academic papers that describe the algorithm. 61 | - Existing implementations released under permissive open-source licenses (Apache 2.0, MIT, BSD, etc.). 62 | 63 | If in doubt about whether a license is compatible, please ask before proceeding. By contributing to this project and signing the CLA, you confirm that your work complies with these guidelines and that you understand the importance of maintaining a clean licensing chain. 64 | 65 | ## Google-Style Docstrings and Type Hints 66 | 67 | For clarity and maintainability, any new functions or classes must include [Google-style docstrings](https://google.github.io/styleguide/pyguide.html) and use Python type hints. Type hints are mandatory in all function definitions, ensuring explicit parameter and return type declarations. These docstrings should clearly explain parameters, return types, and provide usage examples when applicable. 
68 | 69 | For example: 70 | 71 | ```python 72 | def sample_function(param1: int, param2: int = 10) -> bool: 73 | """ 74 | Provides a brief description of function behavior. 75 | 76 | Args: 77 | param1 (int): Explanation of the first parameter. 78 | param2 (int): Explanation of the second parameter, defaulting to 10. 79 | 80 | Returns: 81 | bool: True if the operation succeeds, otherwise False. 82 | 83 | Examples: 84 | >>> sample_function(5, 10) 85 | True 86 | """ 87 | return param1 == param2 88 | ``` 89 | 90 | Following this pattern helps ensure consistency throughout the codebase. 91 | 92 | ## Reporting Bugs 93 | 94 | Bug reports are vital for continued improvement. When reporting an issue, please include a clear, minimal reproducible example that demonstrates the problem. Detailed bug reports assist us in swiftly diagnosing and addressing issues. 95 | 96 | ## License 97 | 98 | By contributing to Trackers, you agree that your contributions will be licensed under the Apache 2.0 License as specified in our [LICENSE](/LICENSE) file. 99 | 100 | Thank you for helping us build a reliable, open-source tracking library. We’re excited to collaborate with you! 101 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | trackers
3 | trackers logo 4 | 5 | [![version](https://badge.fury.io/py/trackers.svg)](https://badge.fury.io/py/trackers) 6 | [![downloads](https://img.shields.io/pypi/dm/trackers)](https://pypistats.org/packages/trackers) 7 | [![license](https://img.shields.io/badge/license-Apache%202.0-blue)](https://github.com/roboflow/trackers/blob/main/LICENSE.md) 8 | [![python-version](https://img.shields.io/pypi/pyversions/trackers)](https://badge.fury.io/py/trackers) 9 | 10 | [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VT_FYIe3kborhWrfKKBqqfR0EjQeQNiO?usp=sharing) 11 | [![discord](https://img.shields.io/discord/1159501506232451173?logo=discord&label=discord&labelColor=fff&color=5865f2&link=https%3A%2F%2Fdiscord.gg%2FGbfgXGJ8Bk)](https://discord.gg/GbfgXGJ8Bk) 12 |
13 | 14 | ## Hello 15 | 16 | `trackers` is a unified library offering clean room re-implementations of leading multi-object tracking algorithms. Its modular design allows you to easily swap trackers and integrate them with object detectors from various libraries like `inference`, `ultralytics`, or `transformers`. 17 | 18 |
| Tracker | Paper | MOTA | Year | Status | Colab |
| :--- | :--- | :--- | :--- | :--- | :--- |
| SORT | arXiv | 74.6 | 2016 | ✅ | colab |
| DeepSORT | arXiv | 75.4 | 2017 | ✅ | colab |
| ByteTrack | arXiv | 77.8 | 2021 | 🚧 | 🚧 |
| OC-SORT | arXiv | 75.9 | 2022 | 🚧 | 🚧 |
| BoT-SORT | arXiv | 77.8 | 2022 | 🚧 | 🚧 |
74 | 75 | https://github.com/user-attachments/assets/eef9b00a-cfe4-40f7-a495-954550e3ef1f 76 | 77 | ## Installation 78 | 79 | Pip install the `trackers` package in a [**Python>=3.9**](https://www.python.org/) environment. 80 | 81 | ```bash 82 | pip install trackers 83 | ``` 84 | 85 |
86 | install from source 87 | 88 |
89 | 90 | By installing `trackers` from source, you can explore the most recent features and enhancements that have not yet been officially released. Please note that these updates are still in development and may not be as stable as the latest published release. 91 | 92 | ```bash 93 | pip install git+https://github.com/roboflow/trackers.git 94 | ``` 95 | 96 |
97 | 98 | ## Quickstart 99 | 100 | With a modular design, `trackers` lets you combine object detectors from different libraries with the tracker of your choice. Here's how you can use `SORTTracker` with various detectors: 101 | 102 | ```python 103 | import supervision as sv 104 | from trackers import SORTTracker 105 | from inference import get_model 106 | 107 | tracker = SORTTracker() 108 | model = get_model(model_id="yolov11m-640") 109 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 110 | 111 | def callback(frame, _): 112 | result = model.infer(frame)[0] 113 | detections = sv.Detections.from_inference(result) 114 | detections = tracker.update(detections) 115 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 116 | 117 | sv.process_video( 118 | source_path="", 119 | target_path="", 120 | callback=callback, 121 | ) 122 | ``` 123 | 124 |
125 | run with ultralytics 126 | 127 |
128 | 129 | ```python 130 | import supervision as sv 131 | from trackers import SORTTracker 132 | from ultralytics import YOLO 133 | 134 | tracker = SORTTracker() 135 | model = YOLO("yolo11m.pt") 136 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 137 | 138 | def callback(frame, _): 139 | result = model(frame)[0] 140 | detections = sv.Detections.from_ultralytics(result) 141 | detections = tracker.update(detections) 142 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 143 | 144 | sv.process_video( 145 | source_path="", 146 | target_path="", 147 | callback=callback, 148 | ) 149 | ``` 150 | 151 |
152 | 153 |
154 | run with transformers 155 | 156 |
157 | 158 | ```python 159 | import torch 160 | import supervision as sv 161 | from trackers import SORTTracker 162 | from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor 163 | 164 | tracker = SORTTracker() 165 | image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd") 166 | model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd") 167 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 168 | 169 | def callback(frame, _): 170 | inputs = image_processor(images=frame, return_tensors="pt") 171 | with torch.no_grad(): 172 | outputs = model(**inputs) 173 | 174 | h, w, _ = frame.shape 175 | results = image_processor.post_process_object_detection( 176 | outputs, 177 | target_sizes=torch.tensor([(h, w)]), 178 | threshold=0.5 179 | )[0] 180 | 181 | detections = sv.Detections.from_transformers( 182 | transformers_results=results, 183 | id2label=model.config.id2label 184 | ) 185 | 186 | detections = tracker.update(detections) 187 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 188 | 189 | sv.process_video( 190 | source_path="", 191 | target_path="", 192 | callback=callback, 193 | ) 194 | ``` 195 | 196 |
197 | 198 | ## License 199 | 200 | The code is released under the [Apache 2.0 license](https://github.com/roboflow/trackers/blob/main/LICENSE). 201 | 202 | ## Contribution 203 | 204 | We welcome all contributions—whether it’s reporting issues, suggesting features, or submitting pull requests. Please read our [contributor guidelines](https://github.com/roboflow/trackers/blob/main/CONTRIBUTING.md) to learn about our processes and best practices. 205 | -------------------------------------------------------------------------------- /docs/assets/logo-trackers-black.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/assets/logo-trackers-violet.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/assets/logo-trackers-white.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | comments: true 3 | --- 4 | 5 |
6 | 7 | Trackers Logo 8 | 9 |
10 | 11 | version license python-version 12 | 13 |
14 | 15 | colab discord 16 | 17 |
18 | 19 | `trackers` is a unified library offering clean room re-implementations of leading multi-object tracking algorithms. Its modular design allows you to easily swap trackers and integrate them with object detectors from various libraries like `inference`, `ultralytics`, or `transformers`. 20 | 21 |
| Tracker | Paper | MOTA | Year | Status | Colab |
| :--- | :--- | :--- | :--- | :--- | :--- |
| SORT | arXiv | 74.6 | 2016 | ✅ | colab |
| DeepSORT | arXiv | 75.4 | 2017 | ✅ | colab |
| ByteTrack | arXiv | 77.8 | 2021 | 🚧 | 🚧 |
| OC-SORT | arXiv | 75.9 | 2022 | 🚧 | 🚧 |
| BoT-SORT | arXiv | 77.8 | 2022 | 🚧 | 🚧 |
77 | 78 | # Installation 79 | 80 | You can install `trackers` in a [**Python>=3.9**](https://www.python.org/) environment. 81 | 82 | !!! example "Basic Installation" 83 | 84 | === "pip" 85 | ```bash 86 | pip install trackers 87 | ``` 88 | 89 | === "poetry" 90 | ```bash 91 | poetry add trackers 92 | ``` 93 | 94 | === "uv" 95 | ```bash 96 | uv pip install trackers 97 | ``` 98 | 99 | !!! example "Hardware Acceleration" 100 | 101 | === "CPU" 102 | ```bash 103 | pip install "trackers[cpu]" 104 | ``` 105 | 106 | === "CUDA 11.8" 107 | ```bash 108 | pip install "trackers[cu118]" 109 | ``` 110 | 111 | === "CUDA 12.4" 112 | ```bash 113 | pip install "trackers[cu124]" 114 | ``` 115 | 116 | === "CUDA 12.6" 117 | ```bash 118 | pip install "trackers[cu126]" 119 | ``` 120 | 121 | === "ROCm 6.1" 122 | ```bash 123 | pip install "trackers[rocm61]" 124 | ``` 125 | 126 | === "ROCm 6.2.4" 127 | ```bash 128 | pip install "trackers[rocm624]" 129 | ``` 130 | 131 | # Quickstart 132 | 133 | With a modular design, `trackers` lets you combine object detectors from different libraries with the tracker of your choice. Here's how you can use `SORTTracker` with various detectors: 134 | 135 | === "inference" 136 | 137 | ```python hl_lines="2 5 12" 138 | import supervision as sv 139 | from trackers import SORTTracker 140 | from inference import get_model 141 | 142 | tracker = SORTTracker() 143 | model = get_model(model_id="yolov11m-640") 144 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 145 | 146 | def callback(frame, _): 147 | result = model.infer(frame)[0] 148 | detections = sv.Detections.from_inference(result) 149 | detections = tracker.update(detections) 150 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 151 | 152 | sv.process_video( 153 | source_path="", 154 | target_path="", 155 | callback=callback, 156 | ) 157 | ``` 158 | 159 | === "rf-detr" 160 | 161 | ```python hl_lines="2 5 11" 162 | import supervision as sv 163 | from trackers import SORTTracker 164 | from rfdetr import RFDETRBase 165 | 166 | tracker = SORTTracker() 167 | model = RFDETRBase() 168 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 169 | 170 | def callback(frame, _): 171 | detections = model.predict(frame) 172 | detections = tracker.update(detections) 173 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 174 | 175 | sv.process_video( 176 | source_path="", 177 | target_path="", 178 | callback=callback, 179 | ) 180 | ``` 181 | 182 | === "ultralytics" 183 | 184 | ```python hl_lines="2 5 12" 185 | import supervision as sv 186 | from trackers import SORTTracker 187 | from ultralytics import YOLO 188 | 189 | tracker = SORTTracker() 190 | model = YOLO("yolo11m.pt") 191 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 192 | 193 | def callback(frame, _): 194 | result = model(frame)[0] 195 | detections = sv.Detections.from_ultralytics(result) 196 | detections = tracker.update(detections) 197 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 198 | 199 | sv.process_video( 200 | source_path="", 201 | target_path="", 202 | callback=callback, 203 | ) 204 | ``` 205 | 206 | === "transformers" 207 | 208 | ```python hl_lines="3 6 28" 209 | import torch 210 | import supervision as sv 211 | from trackers import SORTTracker 212 | from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor 213 | 214 | tracker = SORTTracker() 215 | processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd") 216 | model = 
RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd") 217 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 218 | 219 | def callback(frame, _): 220 | inputs = processor(images=frame, return_tensors="pt") 221 | with torch.no_grad(): 222 | outputs = model(**inputs) 223 | 224 | h, w, _ = frame.shape 225 | results = processor.post_process_object_detection( 226 | outputs, 227 | target_sizes=torch.tensor([(h, w)]), 228 | threshold=0.5 229 | )[0] 230 | 231 | detections = sv.Detections.from_transformers( 232 | transformers_results=results, 233 | id2label=model.config.id2label 234 | ) 235 | 236 | detections = tracker.update(detections) 237 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 238 | 239 | sv.process_video( 240 | source_path="", 241 | target_path="", 242 | callback=callback, 243 | ) 244 | ``` 245 | -------------------------------------------------------------------------------- /docs/overrides/partials/comments.html: -------------------------------------------------------------------------------- 1 | {% if page.meta.comments %} 2 |

{{ lang.t("meta.comments") }}

3 | 4 | 20 | 21 | 22 | 50 | {% endif %} 51 | -------------------------------------------------------------------------------- /docs/overrides/stylesheets/style.css: -------------------------------------------------------------------------------- 1 | th, td { 2 | border: 1px solid var(--md-typeset-table-color); 3 | } 4 | 5 | .md-typeset__table { 6 | line-height: 1.5; 7 | } 8 | 9 | .md-typeset__table table:not([class]) { 10 | font-size: 0.6rem; 11 | border-collapse: collapse; 12 | } 13 | 14 | .md-typeset__table table:not([class]) td, 15 | .md-typeset__table table:not([class]) th { 16 | padding: 10px; 17 | } 18 | -------------------------------------------------------------------------------- /docs/trackers/core/deepsort/tracker.md: -------------------------------------------------------------------------------- 1 | --- 2 | comments: true 3 | --- 4 | 5 | # DeepSORT 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-1703.07402-b31b1b.svg)](https://arxiv.org/abs/1703.07402) 8 | [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-track-objects-with-deepsort-tracker.ipynb) 9 | 10 | ## Overview 11 | 12 | DeepSORT extends the original [SORT](../sort/tracker.md) algorithm by integrating appearance information through a deep association metric. While maintaining the core Kalman filtering and Hungarian algorithm components from SORT, DeepSORT adds a convolutional neural network (CNN) trained on large-scale person re-identification datasets to extract appearance features from detected objects. This integration allows the tracker to maintain object identities through longer periods of occlusion, effectively reducing identity switches compared to the original SORT. DeepSORT operates with a dual-metric approach, combining motion information (Mahalanobis distance) with appearance similarity (cosine distance in feature space) to improve data association decisions. It also introduces a matching cascade that prioritizes recently seen tracks, enhancing robustness during occlusions. Most of the computational complexity is offloaded to an offline pre-training stage, allowing the online tracking component to run efficiently at approximately 20Hz, making it suitable for real-time applications while achieving competitive tracking performance with significantly improved identity preservation. 
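The dual-metric association described above can be sketched in a few lines. The snippet below is only a conceptual illustration of how a DeepSORT-style tracker blends the Mahalanobis (motion) and cosine (appearance) distances into a single gated cost matrix; the blend weight and gate values shown here are illustrative, and `DeepSORTTracker` handles this association internally, so you never implement it yourself.

```python
import numpy as np

# Illustrative values: DeepSORT gates motion with a chi-square threshold
# (9.4877 is the 95% quantile for 4 degrees of freedom) and appearance with
# a maximum cosine distance; the blend weight is a tunable parameter.
LAMBDA = 0.2
MOTION_GATE = 9.4877
APPEARANCE_GATE = 0.4
INFEASIBLE_COST = 1e5

def combined_cost(motion_dist: np.ndarray, appearance_dist: np.ndarray) -> np.ndarray:
    """Blend motion and appearance distances for track-detection association.

    Both inputs are (num_tracks, num_detections) matrices. Pairs that fail
    either gate receive a prohibitively large cost, so the Hungarian
    assignment step effectively never matches them.
    """
    cost = LAMBDA * motion_dist + (1.0 - LAMBDA) * appearance_dist
    infeasible = (motion_dist > MOTION_GATE) | (appearance_dist > APPEARANCE_GATE)
    return np.where(infeasible, INFEASIBLE_COST, cost)
```

In the examples below, this association logic is handled entirely inside `DeepSORTTracker`.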
13 | 14 | 15 | ## Examples 16 | 17 | === "inference" 18 | 19 | ```python hl_lines="2 5-6 13" 20 | import supervision as sv 21 | from trackers import DeepSORTTracker, ReIDModel 22 | from inference import get_model 23 | 24 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 25 | tracker = DeepSORTTracker(reid_model=reid_model) 26 | model = get_model(model_id="yolov11m-640") 27 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 28 | 29 | def callback(frame, _): 30 | result = model.infer(frame)[0] 31 | detections = sv.Detections.from_inference(result) 32 | detections = tracker.update(detections, frame) 33 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 34 | 35 | sv.process_video( 36 | source_path="", 37 | target_path="", 38 | callback=callback, 39 | ) 40 | ``` 41 | 42 | === "rf-detr" 43 | 44 | ```python hl_lines="2 5-6 12" 45 | import supervision as sv 46 | from trackers import DeepSORTTracker, ReIDModel 47 | from rfdetr import RFDETRBase 48 | 49 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 50 | tracker = DeepSORTTracker(reid_model=reid_model) 51 | model = RFDETRBase() 52 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 53 | 54 | def callback(frame, _): 55 | detections = model.predict(frame) 56 | detections = tracker.update(detections, frame) 57 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 58 | 59 | sv.process_video( 60 | source_path="", 61 | target_path="", 62 | callback=callback, 63 | ) 64 | ``` 65 | 66 | === "ultralytics" 67 | 68 | ```python hl_lines="2 5-6 13" 69 | import supervision as sv 70 | from trackers import DeepSORTTracker, ReIDModel 71 | from ultralytics import YOLO 72 | 73 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 74 | tracker = DeepSORTTracker(reid_model=reid_model) 75 | model = YOLO("yolo11m.pt") 76 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 77 | 78 | def callback(frame, _): 79 | result = model(frame)[0] 80 | detections = sv.Detections.from_ultralytics(result) 81 | detections = tracker.update(detections, frame) 82 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 83 | 84 | sv.process_video( 85 | source_path="", 86 | target_path="", 87 | callback=callback, 88 | ) 89 | ``` 90 | 91 | === "transformers" 92 | 93 | ```python hl_lines="3 6-7 29" 94 | import torch 95 | import supervision as sv 96 | from trackers import DeepSORTTracker, ReIDModel 97 | from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor 98 | 99 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 100 | tracker = DeepSORTTracker(reid_model=reid_model) 101 | processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd") 102 | model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd") 103 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 104 | 105 | def callback(frame, _): 106 | inputs = processor(images=frame, return_tensors="pt") 107 | with torch.no_grad(): 108 | outputs = model(**inputs) 109 | 110 | h, w, _ = frame.shape 111 | results = processor.post_process_object_detection( 112 | outputs, 113 | target_sizes=torch.tensor([(h, w)]), 114 | threshold=0.5 115 | )[0] 116 | 117 | detections = sv.Detections.from_transformers( 118 | transformers_results=results, 119 | id2label=model.config.id2label 120 | ) 121 | 122 | detections = tracker.update(detections, frame) 123 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 124 | 125 | sv.process_video( 
126 | source_path="", 127 | target_path="", 128 | callback=callback, 129 | ) 130 | ``` 131 | 132 | ## API 133 | 134 | !!! example "Install DeepSORT" 135 | 136 | === "CPU" 137 | ```bash 138 | pip install "trackers[reid,cpu]" 139 | ``` 140 | 141 | === "CUDA 11.8" 142 | ```bash 143 | pip install "trackers[reid,cu118]" 144 | ``` 145 | 146 | === "CUDA 12.4" 147 | ```bash 148 | pip install "trackers[reid,cu124]" 149 | ``` 150 | 151 | === "CUDA 12.6" 152 | ```bash 153 | pip install "trackers[reid,cu126]" 154 | ``` 155 | 156 | === "ROCm 6.1" 157 | ```bash 158 | pip install "trackers[reid,rocm61]" 159 | ``` 160 | 161 | === "ROCm 6.2.4" 162 | ```bash 163 | pip install "trackers[reid,rocm624]" 164 | ``` 165 | 166 | ::: trackers.core.deepsort.tracker.DeepSORTTracker 167 | -------------------------------------------------------------------------------- /docs/trackers/core/reid/reid.md: -------------------------------------------------------------------------------- 1 | --- 2 | comments: true 3 | --- 4 | 5 | # Re-Identification (ReID) 6 | 7 | Re-identification (ReID) enables object tracking systems to recognize the same object or identity across different frames—even when occlusion, appearance changes, or re-entries occur. This is essential for robust, long-term multi-object tracking. 8 | 9 | ## Installation 10 | 11 | To use ReID features in the trackers library, install the package with the appropriate dependencies for your hardware: 12 | 13 | !!! example "Install trackers with ReID support" 14 | 15 | === "CPU" 16 | ```bash 17 | pip install "trackers[reid,cpu]" 18 | ``` 19 | 20 | === "CUDA 11.8" 21 | ```bash 22 | pip install "trackers[reid,cu118]" 23 | ``` 24 | 25 | === "CUDA 12.4" 26 | ```bash 27 | pip install "trackers[reid,cu124]" 28 | ``` 29 | 30 | === "CUDA 12.6" 31 | ```bash 32 | pip install "trackers[reid,cu126]" 33 | ``` 34 | 35 | === "ROCm 6.1" 36 | ```bash 37 | pip install "trackers[reid,rocm61]" 38 | ``` 39 | 40 | === "ROCm 6.2.4" 41 | ```bash 42 | pip install "trackers[reid,rocm624]" 43 | ``` 44 | 45 | ## ReIDModel 46 | 47 | The `ReIDModel` class provides a flexible interface to extract appearance features from object detections, which can be used by trackers to associate identities across frames. 48 | 49 | ### Loading a ReIDModel 50 | 51 | You can initialize a `ReIDModel` from any supported pretrained model in the [`timm`](https://huggingface.co/docs/timm/en/index) library using the `from_timm` method. 52 | 53 | ```python 54 | from trackers import ReIDModel 55 | 56 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 57 | ``` 58 | 59 | ### Supported Models 60 | 61 | The `ReIDModel` supports all models available in the timm library. You can list available models using: 62 | 63 | ```python 64 | import timm 65 | print(timm.list_models()) 66 | ``` 67 | 68 | ### Extracting Embeddings 69 | 70 | To extract embeddings (feature vectors) from detected objects in an image frame, use the `extract_features` method. 
It crops each detected bounding box from the frame, applies necessary transforms, and passes the crops through the backbone model: 71 | 72 | ```python 73 | import cv2 74 | import supervision as sv 75 | from trackers import ReIDModel 76 | from inference import get_model 77 | 78 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 79 | model = get_model(model_id="yolov11m-640") 80 | 81 | image = cv2.imread("") 82 | 83 | result = model.infer(image)[0] 84 | detections = sv.Detections.from_inference(result) 85 | features = reid_model.extract_features(image, detections) 86 | ``` 87 | 88 | ## Tracking Integration 89 | 90 | ReID models are integrated into trackers like DeepSORT to improve identity association by providing appearance features alongside motion cues. 91 | 92 | ```python 93 | import supervision as sv 94 | from trackers import DeepSORTTracker, ReIDModel 95 | from inference import get_model 96 | 97 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 98 | tracker = DeepSORTTracker(reid_model=reid_model) 99 | model = get_model(model_id="yolov11m-640") 100 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 101 | 102 | def callback(frame, _): 103 | result = model.infer(frame)[0] 104 | detections = sv.Detections.from_inference(result) 105 | detections = tracker.update(detections, frame) 106 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 107 | 108 | sv.process_video( 109 | source_path="", 110 | target_path="", 111 | callback=callback, 112 | ) 113 | ``` 114 | 115 | This setup extracts appearance embeddings for detected objects and uses them in the tracker to maintain consistent IDs across frames. 116 | 117 | ## Training 118 | 119 | You can train a custom ReID model using the `TripletsDataset` class, which provides triplets of anchor, positive, and negative samples for metric learning. 120 | 121 | Fine-tuning a pre-trained ReID model or training one from scratch can be beneficial when: 122 | 123 | - Your target domain (specific camera angles, lighting, object appearances) differs significantly from the data the pre-trained model was exposed to. 124 | 125 | - You have a custom dataset featuring unique identities or appearance variations not covered by generic models. 126 | 127 | - You aim to boost performance for specific tracking scenarios where general models might underperform. This allows the model to learn features more specific to your data. 128 | 129 | ### Dataset Structure 130 | 131 | Prepare your dataset with the following directory structure, where each subfolder represents a unique identity: 132 | 133 | ```text 134 | root/ 135 | ├── identity_1/ 136 | │ ├── image_1.png 137 | │ ├── image_2.png 138 | │ └── image_3.png 139 | ├── identity_2/ 140 | │ ├── image_1.png 141 | │ ├── image_2.png 142 | │ └── image_3.png 143 | ├── identity_3/ 144 | │ ├── image_1.png 145 | │ ├── image_2.png 146 | │ └── image_3.png 147 | ... 148 | ``` 149 | 150 | Each folder contains images of the same object or person under different conditions.
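Once the directory is prepared, build the dataset with `TripletsDataset.from_image_directories`. If you also want a validation set, the dataset can be divided by identity with `split`; the sketch below follows the behavior exercised in the library's tests, where `len(dataset)` is the number of identities and `split_ratio` is the fraction kept for training.

```python
from trackers.core.reid.dataset.base import TripletsDataset

# Build triplets from the directory layout shown above.
dataset = TripletsDataset.from_image_directories(
    root_directory="",  # path to the root/ folder described above
)

# Keep 80% of the identities for training and the rest for validation.
train_dataset, val_dataset = dataset.split(split_ratio=0.8)
print(len(train_dataset), len(val_dataset))  # identities per split
```

A complete fine-tuning run then wraps the dataset in a PyTorch `DataLoader` and calls `ReIDModel.train`: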
151 | 152 | ```python 153 | from torch.utils.data import DataLoader 154 | from trackers.core.reid.dataset.base import TripletsDataset 155 | from trackers import ReIDModel 156 | 157 | train_dataset = TripletsDataset.from_image_directories( 158 | root_directory="", 159 | ) 160 | 161 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 162 | 163 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 164 | 165 | reid_model.train( 166 | train_loader, 167 | epochs=10, 168 | projection_dimension=len(train_dataset), 169 | freeze_backbone=True, 170 | learning_rate=5e-4, 171 | weight_decay=1e-2, 172 | checkpoint_interval=5, 173 | ) 174 | ``` 175 | 176 | ## Metrics and Monitoring 177 | 178 | During training, the model monitors metrics such as triplet loss and triplet accuracy to evaluate embedding quality. 179 | 180 | - Triplet Loss: Encourages embeddings of the same identity to be close and different identities to be far apart. 181 | 182 | - Triplet Accuracy: Measures how often the model correctly ranks positive samples closer than negatives. 183 | 184 | You can enable logging to various backends (matplotlib, TensorBoard, Weights & Biases) during training for real-time monitoring: 185 | 186 | ```python 187 | from torch.utils.data import DataLoader 188 | from trackers.core.reid.dataset.base import TripletsDataset 189 | from trackers import ReIDModel 190 | 191 | train_dataset = TripletsDataset.from_image_directories( 192 | root_directory="", 193 | ) 194 | 195 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 196 | 197 | reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k") 198 | 199 | reid_model.train( 200 | train_loader, 201 | epochs=10, 202 | projection_dimension=len(train_dataset), 203 | freeze_backbone=True, 204 | learning_rate=5e-4, 205 | weight_decay=1e-2, 206 | checkpoint_interval=5, 207 | log_to_matplotlib=True, 208 | log_to_tensorboard=True, 209 | log_to_wandb=True, 210 | ) 211 | ``` 212 | 213 | To use the logging capabilities for Matplotlib, TensorBoard, or Weights & Biases, you might need to install additional dependencies. 214 | 215 | ```bash 216 | pip install "trackers[metrics]" 217 | ``` 218 | 219 | ## Resuming from Checkpoints 220 | 221 | You can load custom-trained weights or resume training from a checkpoint: 222 | 223 | ```python 224 | from trackers import ReIDModel 225 | 226 | reid_model = ReIDModel.from_timm("") 227 | ``` 228 | 229 | ## API 230 | 231 | 232 | ::: trackers.core.reid.model.ReIDModel 233 | 234 | ::: trackers.core.reid.dataset.base.TripletsDataset 235 | -------------------------------------------------------------------------------- /docs/trackers/core/sort/tracker.md: -------------------------------------------------------------------------------- 1 | --- 2 | comments: true 3 | --- 4 | 5 | # SORT 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-1602.00763-b31b1b.svg)](https://arxiv.org/abs/1602.00763) 8 | [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-track-objects-with-sort-tracker.ipynb) 9 | 10 | ## Overview 11 | 12 | SORT (Simple Online and Realtime Tracking) is a lean, tracking-by-detection method that combines a Kalman filter for motion prediction with the Hungarian algorithm for data association. It uses object detections—commonly from a high-performing CNN-based detector—as its input, updating each tracked object’s bounding box based on linear velocity estimates. 
Because SORT relies on minimal appearance modeling (only bounding box geometry is used), it is extremely fast and can run comfortably at hundreds of frames per second. This speed and simplicity make it well suited for real-time applications in robotics or surveillance, where rapid, approximate solutions are essential. However, its reliance on frame-to-frame matching makes SORT susceptible to ID switches and less robust during long occlusions, since there is no built-in re-identification module. 13 | 14 | ## Examples 15 | 16 | === "inference" 17 | 18 | ```python hl_lines="2 5 12" 19 | import supervision as sv 20 | from trackers import SORTTracker 21 | from inference import get_model 22 | 23 | tracker = SORTTracker() 24 | model = get_model(model_id="yolov11m-640") 25 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 26 | 27 | def callback(frame, _): 28 | result = model.infer(frame)[0] 29 | detections = sv.Detections.from_inference(result) 30 | detections = tracker.update(detections) 31 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 32 | 33 | sv.process_video( 34 | source_path="", 35 | target_path="", 36 | callback=callback, 37 | ) 38 | ``` 39 | 40 | === "rf-detr" 41 | 42 | ```python hl_lines="2 5 11" 43 | import supervision as sv 44 | from trackers import SORTTracker 45 | from rfdetr import RFDETRBase 46 | 47 | tracker = SORTTracker() 48 | model = RFDETRBase() 49 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 50 | 51 | def callback(frame, _): 52 | detections = model.predict(frame) 53 | detections = tracker.update(detections) 54 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 55 | 56 | sv.process_video( 57 | source_path="", 58 | target_path="", 59 | callback=callback, 60 | ) 61 | ``` 62 | 63 | === "ultralytics" 64 | 65 | ```python hl_lines="2 5 12" 66 | import supervision as sv 67 | from trackers import SORTTracker 68 | from ultralytics import YOLO 69 | 70 | tracker = SORTTracker() 71 | model = YOLO("yolo11m.pt") 72 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 73 | 74 | def callback(frame, _): 75 | result = model(frame)[0] 76 | detections = sv.Detections.from_ultralytics(result) 77 | detections = tracker.update(detections) 78 | return annotator.annotate(frame, detections, labels=detections.tracker_id) 79 | 80 | sv.process_video( 81 | source_path="", 82 | target_path="", 83 | callback=callback, 84 | ) 85 | ``` 86 | 87 | === "transformers" 88 | 89 | ```python hl_lines="3 6 28" 90 | import torch 91 | import supervision as sv 92 | from trackers import SORTTracker 93 | from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor 94 | 95 | tracker = SORTTracker() 96 | processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd") 97 | model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd") 98 | annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) 99 | 100 | def callback(frame, _): 101 | inputs = processor(images=frame, return_tensors="pt") 102 | with torch.no_grad(): 103 | outputs = model(**inputs) 104 | 105 | h, w, _ = frame.shape 106 | results = processor.post_process_object_detection( 107 | outputs, 108 | target_sizes=torch.tensor([(h, w)]), 109 | threshold=0.5 110 | )[0] 111 | 112 | detections = sv.Detections.from_transformers( 113 | transformers_results=results, 114 | id2label=model.config.id2label 115 | ) 116 | 117 | detections = tracker.update(detections) 118 | return annotator.annotate(frame, detections, 
labels=detections.tracker_id) 119 | 120 | sv.process_video( 121 | source_path="", 122 | target_path="", 123 | callback=callback, 124 | ) 125 | ``` 126 | 127 | ## API 128 | 129 | ::: trackers.core.sort.tracker.SORTTracker 130 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Trackers 2 | site_url: https://roboflow.github.io/trackers/develop/ 3 | site_author: Roboflow 4 | site_description: A unified library for object tracking featuring clean room re-implementations of leading multi-object tracking algorithms. 5 | repo_name: roboflow/trackers 6 | repo_url: https://github.com/roboflow/trackers 7 | edit_uri: https://github.com/roboflow/trackers/tree/main/docs 8 | copyright: Roboflow 2025. All rights reserved. 9 | 10 | extra: 11 | social: 12 | - icon: fontawesome/brands/github 13 | link: https://github.com/roboflow 14 | - icon: fontawesome/brands/python 15 | link: https://pypi.org/project/trackers 16 | - icon: fontawesome/brands/docker 17 | link: https://hub.docker.com/u/roboflow 18 | - icon: fontawesome/brands/youtube 19 | link: https://www.youtube.com/roboflow 20 | - icon: fontawesome/brands/linkedin 21 | link: https://www.linkedin.com/company/roboflow-ai/ 22 | - icon: fontawesome/brands/x-twitter 23 | link: https://twitter.com/roboflow 24 | - icon: fontawesome/brands/discord 25 | link: https://discord.gg/GbfgXGJ8Bk 26 | 27 | theme: 28 | name: material 29 | custom_dir: docs/overrides/ 30 | icon: 31 | edit: material/pencil 32 | logo: assets/logo-trackers-white.svg 33 | favicon: assets/logo-trackers-black.svg 34 | features: 35 | - content.code.annotate 36 | - content.code.copy 37 | - content.code.select 38 | - content.tabs.link 39 | - content.tooltips 40 | - navigation.tracking 41 | - navigation.expand 42 | 43 | palette: 44 | # Palette toggle for light mode 45 | - media: "(prefers-color-scheme: light)" 46 | scheme: default 47 | primary: deep purple 48 | accent: deep purple 49 | toggle: 50 | icon: material/brightness-7 51 | name: Switch to dark mode 52 | # Palette toggle for dark mode 53 | - media: "(prefers-color-scheme: dark)" 54 | scheme: slate 55 | primary: deep purple 56 | accent: deep purple 57 | toggle: 58 | icon: material/brightness-4 59 | name: Switch to light mode 60 | 61 | extra_css: 62 | - stylesheets/style.css 63 | 64 | markdown_extensions: 65 | # Code syntax highlighting with line numbers and anchors 66 | - pymdownx.highlight: 67 | # Adds anchors to line numbers 68 | anchor_linenums: true 69 | # Wraps lines in span elements 70 | line_spans: __span 71 | # Adds language class to code blocks 72 | pygments_lang_class: true 73 | # Enables inline code highlighting 74 | - pymdownx.inlinehilite 75 | # Allows including content from other files 76 | - pymdownx.snippets 77 | # Enables nested code blocks and custom fences 78 | - pymdownx.superfences 79 | # Adds support for callouts/notes/warnings 80 | - admonition 81 | # Enables collapsible blocks (expandable content) 82 | - pymdownx.details 83 | # Creates tabbed content (like installation examples) 84 | - pymdownx.tabbed: 85 | # Uses an alternative styling for tabs 86 | alternate_style: true 87 | 88 | plugins: 89 | - mkdocstrings: 90 | handlers: 91 | python: 92 | options: 93 | # Controls whether to show symbol types in the table of contents 94 | show_symbol_type_toc: true 95 | # Controls whether to show symbol type in the heading 96 | show_symbol_type_heading: true 97 | # Controls whether to show the 
root heading (module/class name) 98 | show_root_heading: true 99 | # Controls whether to show the source code 100 | show_source: false 101 | # Specifies the docstring style to parse (Google style in this case) 102 | docstring_style: google 103 | # Controls the order of members (by source order in this case) 104 | members_order: source 105 | # Controls whether to sort members alphabetically 106 | sort_members: false 107 | 108 | nav: 109 | - Home: index.md 110 | - Trackers: 111 | - SORT: trackers/core/sort/tracker.md 112 | - DeepSORT: trackers/core/deepsort/tracker.md 113 | - ReID: trackers/core/reid/reid.md 114 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.9 3 | plugins = numpy.typing.mypy_plugin 4 | 5 | [mypy-requests] 6 | ignore_missing_imports = True 7 | 8 | [mypy-torch] 9 | ignore_missing_imports = True 10 | 11 | [mypy-torchvision] 12 | ignore_missing_imports = True 13 | 14 | [mypy-torchvision.transforms] 15 | ignore_missing_imports = True 16 | 17 | [mypy-firerequests] 18 | ignore_missing_imports = True 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "trackers" 3 | version = "2.0.1" 4 | description = "A unified library for object tracking featuring clean room re-implementations of leading multi-object tracking algorithms" 5 | readme = "README.md" 6 | authors = [ 7 | {name = "Piotr Skalski", email = "piotr.skalski92@gmail.com"}, 8 | {name = "Soumik Rakshit", email = "soumik@roboflow.com"}, 9 | ] 10 | license = {text = "Apache License 2.0"} 11 | requires-python = ">=3.9" 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Education", 16 | "Intended Audience :: Science/Research", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: 3.13", 24 | "Programming Language :: Python :: 3 :: Only", 25 | "Topic :: Software Development", 26 | "Topic :: Scientific/Engineering", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | "Typing :: Typed", 29 | "Operating System :: POSIX", 30 | "Operating System :: Unix", 31 | "Operating System :: MacOS", 32 | "Operating System :: Microsoft :: Windows", 33 | ] 34 | keywords = ["tracking","mot","sort","deepsort", "machine-learning", "deep-learning", "vision", "ML", "DL", "AI", "DETR", "YOLO", "Roboflow"] 35 | 36 | dependencies = [ 37 | "aiofiles>=24.1.0", 38 | "aiohttp>=3.11.16", 39 | "numpy>=2.0.2", 40 | "supervision>=0.26.0rc7", 41 | "tqdm>=4.67.1", 42 | "validators>=0.34.0", 43 | "scipy>=1.13.1", 44 | ] 45 | 46 | [project.optional-dependencies] 47 | 48 | cpu = [ 49 | "torch>=2.6.0", 50 | "torchvision>=0.21.0", 51 | ] 52 | 53 | cu126 = [ 54 | "torch>=2.6.0", 55 | "torchvision>=0.21.0", 56 | ] 57 | 58 | cu124 = [ 59 | "torch>=2.6.0", 60 | "torchvision>=0.21.0", 61 | ] 62 | 63 | cu118 = [ 64 | "torch>=2.6.0", 65 | "torchvision>=0.21.0", 66 | ] 67 | 68 | rocm61 = [ 69 | "torch>=2.6.0", 70 | "torchvision>=0.21.0", 71 | "pytorch-triton-rocm>=2.0.0", 72 | ] 73 | 74 | rocm624 = [ 75 
| "torch>=2.6.0", 76 | "torchvision>=0.21.0", 77 | "pytorch-triton-rocm>=2.0.0", 78 | ] 79 | 80 | reid = [ 81 | "safetensors>=0.5.3", 82 | "timm>=1.0.15", 83 | ] 84 | 85 | metrics = [ 86 | "matplotlib>=3.9.4", 87 | "tensorboard>=2.19.0", 88 | "wandb>=0.19.11", 89 | ] 90 | 91 | [dependency-groups] 92 | dev = [ 93 | "uv>=0.4.20", 94 | "pytest>=8.3.3", 95 | "ruff>=0.6.9", 96 | "bandit>=1.8.3", 97 | "mypy>=1.15.0", 98 | "pre-commit>=4.2.0", 99 | ] 100 | 101 | 102 | docs = [ 103 | "jupyter>=1.1.1", 104 | "mkdocs>=1.6.1", 105 | "mkdocs-glightbox>=0.4.0", 106 | "mkdocs-jupyter>=0.25.1", 107 | "mkdocs-material>=9.6.11", 108 | "mkdocs-minify-plugin>=0.8.0", 109 | "mkdocstrings>=0.29.1", 110 | "mkdocstrings-python>=1.16.10", 111 | "mike>=2.1.3", 112 | ] 113 | 114 | build = [ 115 | "twine>=5.1.1", 116 | "wheel>=0.40", 117 | "build>=0.10" 118 | ] 119 | 120 | mypy-types = [ 121 | "types-aiofiles>=24.1.0.20250326", 122 | "types-requests>=2.32.0.20250328", 123 | "types-tqdm>=4.67.0.20250417", 124 | ] 125 | 126 | 127 | [tool.uv] 128 | 129 | conflicts = [ 130 | [ 131 | { extra = "cpu" }, 132 | { extra = "cu118" }, 133 | { extra = "cu124" }, 134 | { extra = "cu126" }, 135 | { extra = "rocm61" }, 136 | { extra = "rocm624" }, 137 | ], 138 | ] 139 | 140 | [tool.uv.sources] 141 | torch = [ 142 | { index = "pytorch-cpu", extra = "cpu" }, 143 | { index = "pytorch-cu118", extra = "cu118", marker = "sys_platform != 'darwin'"}, 144 | { index = "pytorch-cu124", extra = "cu124", marker = "sys_platform != 'darwin'"}, 145 | { index = "pytorch-cu126", extra = "cu126", marker = "sys_platform != 'darwin'"}, 146 | { index = "pytorch-rocm61", extra = "rocm61", marker = "sys_platform != 'darwin'"}, 147 | { index = "pytorch-rocm624", extra = "rocm624", marker = "sys_platform != 'darwin'"}, 148 | ] 149 | torchvision = [ 150 | { index = "pytorch-cpu", extra = "cpu" }, 151 | { index = "pytorch-cu118", extra = "cu118", marker = "sys_platform != 'darwin'"}, 152 | { index = "pytorch-cu124", extra = "cu124", marker = "sys_platform != 'darwin'"}, 153 | { index = "pytorch-cu126", extra = "cu126", marker = "sys_platform != 'darwin'"}, 154 | { index = "pytorch-rocm61", extra = "rocm61", marker = "sys_platform != 'darwin'"}, 155 | { index = "pytorch-rocm624", extra = "rocm624", marker = "sys_platform != 'darwin'"}, 156 | ] 157 | 158 | pytorch-triton-rocm = [ 159 | { index = "pytorch-rocm61", extra = "rocm61", marker = "sys_platform != 'darwin'"}, 160 | { index = "pytorch-rocm624", extra = "rocm624", marker = "sys_platform != 'darwin'"}, 161 | ] 162 | 163 | [[tool.uv.index]] 164 | name = "pytorch-cpu" 165 | url = "https://download.pytorch.org/whl/cpu" 166 | explicit = true 167 | 168 | [[tool.uv.index]] 169 | name = "pytorch-cu118" 170 | url = "https://download.pytorch.org/whl/cu118" 171 | explicit = true 172 | 173 | [[tool.uv.index]] 174 | name = "pytorch-cu124" 175 | url = "https://download.pytorch.org/whl/cu124" 176 | explicit = true 177 | 178 | [[tool.uv.index]] 179 | name = "pytorch-cu126" 180 | url = "https://download.pytorch.org/whl/cu126" 181 | explicit = true 182 | 183 | [[tool.uv.index]] 184 | name = "pytorch-rocm61" 185 | url = "https://download.pytorch.org/whl/rocm6.1" 186 | explicit = true 187 | 188 | [[tool.uv.index]] 189 | name = "pytorch-rocm624" 190 | url = "https://download.pytorch.org/whl/rocm6.2.4" 191 | explicit = true 192 | 193 | 194 | [build-system] 195 | requires = ["hatchling"] 196 | build-backend = "hatchling.build" 197 | 198 | [tool.hatch.build.targets.wheel] 199 | packages = ["trackers"] 200 | 201 | [tool.bandit] 
202 | target = ["test", "trackers"] 203 | tests = ["B201", "B301", "B318", "B314", "B303", "B413", "B412"] 204 | 205 | [tool.ruff] 206 | target-version = "py39" 207 | 208 | # Exclude a variety of commonly ignored directories. 209 | exclude = [ 210 | ".bzr", 211 | ".direnv", 212 | ".eggs", 213 | ".git", 214 | ".git-rewrite", 215 | ".hg", 216 | ".mypy_cache", 217 | ".nox", 218 | ".pants.d", 219 | ".pytype", 220 | ".ruff_cache", 221 | ".svn", 222 | ".tox", 223 | ".venv", 224 | "__pypackages__", 225 | "_build", 226 | "buck-out", 227 | "build", 228 | "dist", 229 | "node_modules", 230 | "venv", 231 | "yarn-error.log", 232 | "yarn.lock", 233 | "docs", 234 | ] 235 | 236 | line-length = 88 237 | indent-width = 4 238 | 239 | [tool.ruff.lint] 240 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 241 | select = ["E", "F", "I", "A", "Q", "W", "RUF"] 242 | ignore = [] 243 | # Allow autofix for all enabled rules (when `--fix`) is provided. 244 | fixable = [ 245 | "A", 246 | "B", 247 | "C", 248 | "D", 249 | "E", 250 | "F", 251 | "G", 252 | "I", 253 | "N", 254 | "Q", 255 | "S", 256 | "T", 257 | "W", 258 | "ANN", 259 | "ARG", 260 | "BLE", 261 | "COM", 262 | "DJ", 263 | "DTZ", 264 | "EM", 265 | "ERA", 266 | "EXE", 267 | "FBT", 268 | "ICN", 269 | "INP", 270 | "ISC", 271 | "NPY", 272 | "PD", 273 | "PGH", 274 | "PIE", 275 | "PL", 276 | "PT", 277 | "PTH", 278 | "PYI", 279 | "RET", 280 | "RSE", 281 | "RUF", 282 | "SIM", 283 | "SLF", 284 | "TCH", 285 | "TID", 286 | "TRY", 287 | "UP", 288 | "YTT", 289 | ] 290 | unfixable = [] 291 | # Allow unused variables when underscore-prefixed. 292 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 293 | pylint.max-args = 5 # Default is 5 294 | 295 | [tool.ruff.lint.flake8-quotes] 296 | inline-quotes = "double" 297 | multiline-quotes = "double" 298 | docstring-quotes = "double" 299 | 300 | [tool.ruff.lint.pydocstyle] 301 | convention = "google" 302 | 303 | [tool.ruff.lint.per-file-ignores] 304 | "__init__.py" = ["E402", "F401"] 305 | 306 | [tool.ruff.lint.mccabe] 307 | # Flag errors (`C901`) whenever the complexity level exceeds 10. 308 | max-complexity = 10 309 | 310 | [tool.ruff.lint.isort] 311 | order-by-type = true 312 | no-sections = false 313 | 314 | [tool.ruff.format] 315 | # Like Black, use double quotes for strings. 316 | quote-style = "double" 317 | 318 | # Like Black, indent with spaces, rather than tabs. 319 | indent-style = "space" 320 | 321 | # Like Black, respect magic trailing commas. 322 | skip-magic-trailing-comma = false 323 | 324 | # Like Black, automatically detect the appropriate line ending. 325 | line-ending = "auto" 326 | 327 | [tool.pytest.ini_options] 328 | pythonpath = "." 
329 | testpaths = ["test"] 330 | -------------------------------------------------------------------------------- /test/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/test/core/__init__.py -------------------------------------------------------------------------------- /test/core/reid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/test/core/reid/__init__.py -------------------------------------------------------------------------------- /test/core/reid/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/test/core/reid/dataset/__init__.py -------------------------------------------------------------------------------- /test/core/reid/dataset/test_base.py: -------------------------------------------------------------------------------- 1 | from contextlib import ExitStack as DoesNotRaise 2 | 3 | import pytest 4 | 5 | from trackers.core.reid.dataset.base import TripletsDataset 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "tracker_id_to_images, exception", 10 | [ 11 | ( 12 | {"0111": []}, 13 | pytest.raises(ValueError), 14 | ), # Single tracker with no images - should raise ValueError 15 | ( 16 | {"0111": ["0111_00000000.jpg"]}, 17 | pytest.raises(ValueError), 18 | ), # Single tracker with one image - should raise ValueError 19 | ( 20 | {"0111": ["0111_00000000.jpg", "0111_00000001.jpg"]}, 21 | pytest.raises(ValueError), 22 | ), # Single tracker with multiple images - should raise ValueError 23 | ( 24 | { 25 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 26 | "0112": ["0112_00000000.jpg"], 27 | }, 28 | pytest.raises(ValueError), 29 | ), # Two trackers but one has only one image - should raise ValueError 30 | ( 31 | { 32 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 33 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 34 | }, 35 | DoesNotRaise(), 36 | ), # Two trackers with multiple images - should not raise 37 | ( 38 | { 39 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 40 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 41 | "0113": ["0113_00000000.jpg"], 42 | }, 43 | DoesNotRaise(), 44 | ), # Three trackers, one with fewer images - should validate dataset length 45 | ], 46 | ) 47 | def test_triplet_dataset_initialization(tracker_id_to_images, exception): 48 | with exception: 49 | _ = TripletsDataset(tracker_id_to_images) 50 | 51 | 52 | @pytest.mark.parametrize( 53 | "tracker_id_to_images, split_ratio, expected_train_size, expected_val_size, exception", # noqa: E501 54 | [ 55 | ( 56 | { 57 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 58 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 59 | }, 60 | 0.5, 61 | 1, 62 | 1, 63 | pytest.raises(ValueError), 64 | ), # Split results in only 1 tracker in test set - should raise ValueError 65 | ( 66 | { 67 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 68 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 69 | "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], 70 | "0114": ["0114_00000000.jpg", "0114_00000001.jpg"], 71 | "0115": ["0115_00000000.jpg", "0115_00000001.jpg"], 72 | }, 73 | 0.2, 74 | 1, 75 | 4, 76 | pytest.raises(ValueError), 77 | ), # Split results in only 1 tracker in 
test set - should raise ValueError 78 | ( 79 | { 80 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 81 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 82 | "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], 83 | "0114": ["0114_00000000.jpg", "0114_00000001.jpg"], 84 | "0115": ["0115_00000000.jpg", "0115_00000001.jpg"], 85 | }, 86 | 0.8, 87 | 4, 88 | 1, 89 | pytest.raises(ValueError), 90 | ), # Split results in only 1 tracker in val set - should raise ValueError 91 | ( 92 | { 93 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 94 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 95 | "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], 96 | "0114": ["0114_00000000.jpg", "0114_00000001.jpg"], 97 | "0115": ["0115_00000000.jpg", "0115_00000001.jpg"], 98 | }, 99 | 0.6, 100 | 3, 101 | 2, 102 | DoesNotRaise(), 103 | ), # Valid split with multiple trackers in both sets 104 | ( 105 | { 106 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 107 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 108 | "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], 109 | "0114": ["0114_00000000.jpg", "0114_00000001.jpg"], 110 | }, 111 | 0.5, 112 | 2, 113 | 2, 114 | DoesNotRaise(), 115 | ), # 50% train, 50% validation - valid 116 | ], 117 | ) 118 | def test_triplet_dataset_split( 119 | tracker_id_to_images, split_ratio, expected_train_size, expected_val_size, exception 120 | ): 121 | with exception: 122 | dataset = TripletsDataset(tracker_id_to_images) 123 | train_dataset, val_dataset = dataset.split(split_ratio=split_ratio) 124 | 125 | assert len(train_dataset) == expected_train_size, ( 126 | f"Expected train dataset size {expected_train_size}, " 127 | f"got {len(train_dataset)}" 128 | ) 129 | assert len(val_dataset) == expected_val_size, ( 130 | f"Expected validation dataset size {expected_val_size}, " 131 | f"got {len(val_dataset)}" 132 | ) 133 | 134 | 135 | @pytest.mark.parametrize( 136 | "tracker_id_to_images, tracker_id, exception", 137 | [ 138 | ( 139 | { 140 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 141 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 142 | "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], 143 | }, 144 | "0111", 145 | DoesNotRaise(), 146 | ), 147 | ], 148 | ) 149 | def test_get_triplet_image_paths(tracker_id_to_images, tracker_id, exception) -> None: 150 | with exception: 151 | dataset = TripletsDataset(tracker_id_to_images) 152 | anchor_path, positive_path, negative_path = dataset._get_triplet_image_paths( 153 | tracker_id 154 | ) 155 | 156 | assert anchor_path in tracker_id_to_images[tracker_id] 157 | assert positive_path in tracker_id_to_images[tracker_id] 158 | assert negative_path not in tracker_id_to_images[tracker_id] 159 | assert anchor_path != positive_path 160 | -------------------------------------------------------------------------------- /test/core/reid/dataset/test_market_1501.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from trackers.core.reid.dataset.market_1501 import parse_market1501_dataset 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "mock_glob_output, expected_result", 10 | [ 11 | ( 12 | # Empty dataset 13 | [], 14 | {}, 15 | ), 16 | ( 17 | # Single image for one person 18 | ["0111_00000000.jpg"], 19 | {"0111": ["0111_00000000.jpg"]}, 20 | ), 21 | ( 22 | # Multiple images for one person 23 | ["0111_00000000.jpg", "0111_00000001.jpg"], 24 | {"0111": ["0111_00000000.jpg", "0111_00000001.jpg"]}, 25 | ), 26 | ( 27 | # 
Multiple people with multiple images 28 | [ 29 | "0111_00000000.jpg", 30 | "0111_00000001.jpg", 31 | "0112_00000000.jpg", 32 | "0112_00000001.jpg", 33 | ], 34 | { 35 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 36 | "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], 37 | }, 38 | ), 39 | ( 40 | # Multiple people with varying number of images 41 | ["0111_00000000.jpg", "0111_00000001.jpg", "0112_00000000.jpg"], 42 | { 43 | "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], 44 | "0112": ["0112_00000000.jpg"], 45 | }, 46 | ), 47 | ], 48 | ) 49 | def test_parse_market1501_dataset(mock_glob_output, expected_result): 50 | with patch("glob.glob", return_value=mock_glob_output): 51 | result = parse_market1501_dataset("dummy_path") 52 | assert result == expected_result 53 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py39,py310,py311,py312,py313 3 | 4 | [testenv] 5 | changedir = test 6 | deps = pytest 7 | commands = pytest 8 | -------------------------------------------------------------------------------- /trackers/__init__.py: -------------------------------------------------------------------------------- 1 | from trackers.core.sort.tracker import SORTTracker 2 | from trackers.log import get_logger 3 | 4 | __all__ = ["SORTTracker"] 5 | 6 | logger = get_logger(__name__) 7 | 8 | try: 9 | from trackers.core.deepsort.tracker import DeepSORTTracker 10 | from trackers.core.reid.model import ReIDModel 11 | 12 | __all__.extend(["DeepSORTTracker", "ReIDModel"]) 13 | except ImportError: 14 | logger.warning( 15 | "ReIDModel dependencies not installed. ReIDModel will not be available. " 16 | "Please run `pip install trackers[reid]` and try again." 
17 | ) 18 | pass 19 | -------------------------------------------------------------------------------- /trackers/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/trackers/core/__init__.py -------------------------------------------------------------------------------- /trackers/core/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import supervision as sv 5 | 6 | 7 | class BaseTracker(ABC): 8 | @abstractmethod 9 | def update(self, detections: sv.Detections) -> sv.Detections: 10 | pass 11 | 12 | @abstractmethod 13 | def reset(self) -> None: 14 | pass 15 | 16 | 17 | class BaseTrackerWithFeatures(ABC): 18 | @abstractmethod 19 | def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections: 20 | pass 21 | 22 | @abstractmethod 23 | def reset(self) -> None: 24 | pass 25 | -------------------------------------------------------------------------------- /trackers/core/deepsort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/trackers/core/deepsort/__init__.py -------------------------------------------------------------------------------- /trackers/core/deepsort/kalman_box_tracker.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | 5 | 6 | class DeepSORTKalmanBoxTracker: 7 | """ 8 | The `DeepSORTKalmanBoxTracker` class represents the internals of a single 9 | tracked object (bounding box), with a Kalman filter to predict and update 10 | its position. It also maintains a feature vector for the object, which is 11 | used to identify the object across frames. 12 | 13 | Attributes: 14 | tracker_id (int): Unique identifier for the tracker. 15 | number_of_successful_updates (int): Number of times the object has been 16 | updated successfully. 17 | time_since_update (int): Number of frames since the last update. 18 | state (np.ndarray): State vector of the bounding box. 19 | F (np.ndarray): State transition matrix. 20 | H (np.ndarray): Measurement matrix. 21 | Q (np.ndarray): Process noise covariance matrix. 22 | R (np.ndarray): Measurement noise covariance matrix. 23 | P (np.ndarray): Error covariance matrix. 24 | features (list[np.ndarray]): List of feature vectors. 25 | count_id (int): Class variable to assign unique IDs to each tracker. 26 | 27 | Args: 28 | bbox (np.ndarray): Initial bounding box in the form [x1, y1, x2, y2]. 29 | feature (Optional[np.ndarray]): Optional initial feature vector. 30 | """ 31 | 32 | count_id = 0 33 | 34 | @classmethod 35 | def get_next_tracker_id(cls) -> int: 36 | """ 37 | Class method that returns the next available tracker ID. 38 | 39 | Returns: 40 | int: The next available tracker ID. 
41 | """ 42 | next_id = cls.count_id 43 | cls.count_id += 1 44 | return next_id 45 | 46 | def __init__(self, bbox: np.ndarray, feature: Optional[np.ndarray] = None): 47 | # Initialize with a temporary ID of -1 48 | # Will be assigned a real ID when the track is considered mature 49 | self.tracker_id = -1 50 | 51 | # Number of hits indicates how many times the object has been 52 | # updated successfully 53 | self.number_of_successful_updates = 1 54 | # Number of frames since the last update 55 | self.time_since_update = 0 56 | 57 | # For simplicity, we keep a small state vector: 58 | # (x, y, x2, y2, vx, vy, vx2, vy2). 59 | # We'll store the bounding box in "self.state" 60 | self.state = np.zeros((8, 1), dtype=np.float32) 61 | 62 | # Initialize state directly from the first detection 63 | self.state[0] = bbox[0] 64 | self.state[1] = bbox[1] 65 | self.state[2] = bbox[2] 66 | self.state[3] = bbox[3] 67 | 68 | # Basic constant velocity model 69 | self._initialize_kalman_filter() 70 | 71 | # Initialize features list 72 | self.features: list[np.ndarray] = [] 73 | if feature is not None: 74 | self.features.append(feature) 75 | 76 | def _initialize_kalman_filter(self) -> None: 77 | """ 78 | Sets up the matrices for the Kalman filter. 79 | """ 80 | # State transition matrix (F): 8x8 81 | # We assume a constant velocity model. Positions are incremented by 82 | # velocity each step. 83 | self.F = np.eye(8, dtype=np.float32) 84 | for i in range(4): 85 | self.F[i, i + 4] = 1.0 86 | 87 | # Measurement matrix (H): we directly measure x1, y1, x2, y2 88 | self.H = np.eye(4, 8, dtype=np.float32) # 4x8 89 | 90 | # Process covariance matrix (Q) 91 | self.Q = np.eye(8, dtype=np.float32) * 0.01 92 | 93 | # Measurement covariance (R): noise in detection 94 | self.R = np.eye(4, dtype=np.float32) * 0.1 95 | 96 | # Error covariance matrix (P) 97 | self.P = np.eye(8, dtype=np.float32) 98 | 99 | def predict(self) -> None: 100 | """ 101 | Predict the next state of the bounding box (applies the state transition). 102 | """ 103 | # Predict state 104 | self.state = self.F @ self.state 105 | # Predict error covariance 106 | self.P = self.F @ self.P @ self.F.T + self.Q 107 | 108 | # Increase time since update 109 | self.time_since_update += 1 110 | 111 | def update(self, bbox: np.ndarray) -> None: 112 | """ 113 | Updates the state with a new detected bounding box. 114 | 115 | Args: 116 | bbox (np.ndarray): Detected bounding box in the form [x1, y1, x2, y2]. 117 | """ 118 | self.time_since_update = 0 119 | self.number_of_successful_updates += 1 120 | 121 | # Kalman Gain 122 | S = self.H @ self.P @ self.H.T + self.R 123 | K = self.P @ self.H.T @ np.linalg.inv(S) 124 | 125 | # Residual 126 | measurement = bbox.reshape((4, 1)) 127 | y = measurement - self.H @ self.state 128 | 129 | # Update state 130 | self.state = self.state + K @ y 131 | 132 | # Update covariance 133 | identity_matrix = np.eye(8, dtype=np.float32) 134 | self.P = (identity_matrix - K @ self.H) @ self.P 135 | 136 | def get_state_bbox(self) -> np.ndarray: 137 | """ 138 | Returns the current bounding box estimate from the state vector. 139 | 140 | Returns: 141 | np.ndarray: The bounding box [x1, y1, x2, y2]. 
142 | """ 143 | return np.array( 144 | [ 145 | self.state[0], # x1 146 | self.state[1], # y1 147 | self.state[2], # x2 148 | self.state[3], # y2 149 | ], 150 | dtype=float, 151 | ).reshape(-1) 152 | 153 | def update_feature(self, feature: np.ndarray): 154 | self.features.append(feature) 155 | 156 | def get_feature(self) -> Union[np.ndarray, None]: 157 | """ 158 | Get the mean feature vector for this tracker. 159 | 160 | Returns: 161 | np.ndarray: Mean feature vector. 162 | """ 163 | if len(self.features) > 0: 164 | # Return the mean of all features, thus (in theory) capturing the 165 | # "average appearance" of the object, which should be more robust 166 | # to minor appearance changes. Otherwise, the last feature can 167 | # also be returned like the following: 168 | # return self.features[-1] 169 | return np.mean(self.features, axis=0) 170 | return None 171 | -------------------------------------------------------------------------------- /trackers/core/deepsort/tracker.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import supervision as sv 5 | from scipy.spatial.distance import cdist 6 | 7 | from trackers.core.base import BaseTrackerWithFeatures 8 | from trackers.core.deepsort.kalman_box_tracker import DeepSORTKalmanBoxTracker 9 | from trackers.core.reid import ReIDModel 10 | from trackers.utils.sort_utils import ( 11 | get_alive_trackers, 12 | get_iou_matrix, 13 | update_detections_with_track_ids, 14 | ) 15 | 16 | 17 | class DeepSORTTracker(BaseTrackerWithFeatures): 18 | """Implements DeepSORT (Deep Simple Online and Realtime Tracking). 19 | 20 | DeepSORT extends SORT by integrating appearance information using a deep 21 | learning model, improving tracking through occlusions and reducing ID switches. 22 | It combines motion (Kalman filter) and appearance cues for data association. 23 | 24 | Args: 25 | reid_model (ReIDModel): An instance of a `ReIDModel` to extract 26 | appearance features. 27 | device (Optional[str]): Device to run the feature extraction 28 | model on (e.g., 'cpu', 'cuda'). 29 | lost_track_buffer (int): Number of frames to buffer when a track is lost. 30 | Enhances occlusion handling but may increase ID switches for similar objects. 31 | frame_rate (float): Frame rate of the video (frames per second). 32 | Used to calculate the maximum time a track can be lost. 33 | track_activation_threshold (float): Detection confidence threshold 34 | for track activation. Higher values reduce false positives 35 | but might miss objects. 36 | minimum_consecutive_frames (int): Number of consecutive frames an object 37 | must be tracked to be considered 'valid'. Prevents spurious tracks but 38 | may miss short tracks. 39 | minimum_iou_threshold (float): IOU threshold for gating in the matching cascade. 40 | appearance_threshold (float): Cosine distance threshold for appearance matching. 41 | Only matches below this threshold are considered valid. 42 | appearance_weight (float): Weight (0-1) balancing motion (IOU) and appearance 43 | distance in the combined matching cost. 44 | distance_metric (str): Distance metric for appearance features (e.g., 'cosine', 45 | 'euclidean'). See `scipy.spatial.distance.cdist`. 
46 | """ # noqa: E501 47 | 48 | def __init__( 49 | self, 50 | reid_model: ReIDModel, 51 | device: Optional[str] = None, 52 | lost_track_buffer: int = 30, 53 | frame_rate: float = 30.0, 54 | track_activation_threshold: float = 0.25, 55 | minimum_consecutive_frames: int = 3, 56 | minimum_iou_threshold: float = 0.3, 57 | appearance_threshold: float = 0.7, 58 | appearance_weight: float = 0.5, 59 | distance_metric: str = "cosine", 60 | ): 61 | self.reid_model = reid_model 62 | self.lost_track_buffer = lost_track_buffer 63 | self.frame_rate = frame_rate 64 | self.minimum_consecutive_frames = minimum_consecutive_frames 65 | self.minimum_iou_threshold = minimum_iou_threshold 66 | self.track_activation_threshold = track_activation_threshold 67 | self.appearance_threshold = appearance_threshold 68 | self.appearance_weight = appearance_weight 69 | self.distance_metric = distance_metric 70 | # Calculate maximum frames without update based on lost_track_buffer and 71 | # frame_rate. This scales the buffer based on the frame rate to ensure 72 | # consistent time-based tracking across different frame rates. 73 | self.maximum_frames_without_update = int( 74 | self.frame_rate / 30.0 * self.lost_track_buffer 75 | ) 76 | 77 | self.trackers: list[DeepSORTKalmanBoxTracker] = [] 78 | 79 | def _get_appearance_distance_matrix( 80 | self, 81 | detection_features: np.ndarray, 82 | ) -> np.ndarray: 83 | """ 84 | Calculate appearance distance matrix between tracks and detections. 85 | 86 | Args: 87 | detection_features (np.ndarray): Features extracted from current detections. 88 | 89 | Returns: 90 | np.ndarray: Appearance distance matrix. 91 | """ 92 | 93 | if len(self.trackers) == 0 or len(detection_features) == 0: 94 | return np.zeros((len(self.trackers), len(detection_features))) 95 | 96 | track_features = np.array([t.get_feature() for t in self.trackers]) 97 | distance_matrix = cdist( 98 | track_features, detection_features, metric=self.distance_metric 99 | ) 100 | distance_matrix = np.clip(distance_matrix, 0, 1) 101 | 102 | return distance_matrix 103 | 104 | def _get_combined_distance_matrix( 105 | self, 106 | iou_matrix: np.ndarray, 107 | appearance_dist_matrix: np.ndarray, 108 | ) -> np.ndarray: 109 | """ 110 | Combine IOU and appearance distances into a single distance matrix. 111 | 112 | Args: 113 | iou_matrix (np.ndarray): IOU matrix between tracks and detections. 114 | appearance_dist_matrix (np.ndarray): Appearance distance matrix. 115 | 116 | Returns: 117 | np.ndarray: Combined distance matrix. 118 | """ 119 | iou_distance: np.ndarray = 1 - iou_matrix 120 | combined_dist = ( 121 | 1 - self.appearance_weight 122 | ) * iou_distance + self.appearance_weight * appearance_dist_matrix 123 | 124 | # Set high distance for IOU below threshold 125 | mask = iou_matrix < self.minimum_iou_threshold 126 | combined_dist[mask] = 1.0 127 | 128 | # Set high distance for appearance above threshold 129 | mask = appearance_dist_matrix > self.appearance_threshold 130 | combined_dist[mask] = 1.0 131 | 132 | return combined_dist 133 | 134 | def _get_associated_indices( 135 | self, 136 | iou_matrix: np.ndarray, 137 | detection_features: np.ndarray, 138 | ) -> tuple[list[tuple[int, int]], set[int], set[int]]: 139 | """ 140 | Associate detections to trackers based on both IOU and appearance. 141 | 142 | Args: 143 | iou_matrix (np.ndarray): IOU matrix between tracks and detections. 144 | detection_features (np.ndarray): Features extracted from current detections. 
145 | 146 | Returns: 147 | tuple[list[tuple[int, int]], set[int], set[int]]: Matched indices, 148 | unmatched trackers, unmatched detections. 149 | """ 150 | appearance_dist_matrix = self._get_appearance_distance_matrix( 151 | detection_features 152 | ) 153 | combined_dist = self._get_combined_distance_matrix( 154 | iou_matrix, appearance_dist_matrix 155 | ) 156 | matched_indices = [] 157 | unmatched_trackers = set(range(len(self.trackers))) 158 | unmatched_detections = set(range(len(detection_features))) 159 | 160 | if combined_dist.size > 0: 161 | row_indices, col_indices = np.where(combined_dist < 1.0) 162 | sorted_pairs = sorted( 163 | zip(map(int, row_indices), map(int, col_indices)), 164 | key=lambda x: combined_dist[x[0], x[1]], 165 | ) 166 | 167 | used_rows = set() 168 | used_cols = set() 169 | for row, col in sorted_pairs: 170 | if (row not in used_rows) and (col not in used_cols): 171 | used_rows.add(row) 172 | used_cols.add(col) 173 | matched_indices.append((row, col)) 174 | 175 | unmatched_trackers = unmatched_trackers - {int(row) for row in used_rows} 176 | unmatched_detections = unmatched_detections - { 177 | int(col) for col in used_cols 178 | } 179 | 180 | return matched_indices, unmatched_trackers, unmatched_detections 181 | 182 | def _spawn_new_trackers( 183 | self, 184 | detections: sv.Detections, 185 | detection_boxes: np.ndarray, 186 | detection_features: np.ndarray, 187 | unmatched_detections: set[int], 188 | ): 189 | """ 190 | Create new trackers for unmatched detections with confidence above threshold. 191 | 192 | Args: 193 | detections (sv.Detections): Current detections. 194 | detection_boxes (np.ndarray): Bounding boxes for detections. 195 | detection_features (np.ndarray): Features for detections. 196 | unmatched_detections (set[int]): Indices of unmatched detections. 197 | """ 198 | for detection_idx in unmatched_detections: 199 | if ( 200 | detections.confidence is None 201 | or detection_idx >= len(detections.confidence) 202 | or detections.confidence[detection_idx] 203 | >= self.track_activation_threshold 204 | ): 205 | feature = None 206 | if ( 207 | detection_features is not None 208 | and len(detection_features) > detection_idx 209 | ): 210 | feature = detection_features[detection_idx] 211 | 212 | new_tracker = DeepSORTKalmanBoxTracker( 213 | bbox=detection_boxes[detection_idx], feature=feature 214 | ) 215 | self.trackers.append(new_tracker) 216 | 217 | def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections: 218 | """Updates the tracker state with new detections and appearance features. 219 | 220 | Extracts appearance features, performs Kalman filter prediction, calculates 221 | IOU and appearance distance matrices, associates detections with tracks using 222 | a combined metric, updates matched tracks (position and appearance), and 223 | initializes new tracks for unmatched high-confidence detections. 224 | 225 | Args: 226 | detections (sv.Detections): The latest set of object detections. 227 | frame (np.ndarray): The current video frame, used for extracting 228 | appearance features from detections. 229 | 230 | Returns: 231 | sv.Detections: A copy of the input detections, augmented with assigned 232 | `tracker_id` for each successfully tracked object. Detections not 233 | associated with a track will not have a `tracker_id`. 
234 | """ 235 | if len(self.trackers) == 0 and len(detections) == 0: 236 | detections.tracker_id = np.array([], dtype=int) 237 | return detections 238 | 239 | # Convert detections to a (N x 4) array (x1, y1, x2, y2) 240 | detection_boxes = ( 241 | detections.xyxy if len(detections) > 0 else np.array([]).reshape(0, 4) 242 | ) 243 | 244 | # Extract appearance features from the frame and detections 245 | detection_features = self.reid_model.extract_features(detections, frame) 246 | 247 | # Predict new locations for existing trackers 248 | for tracker in self.trackers: 249 | tracker.predict() 250 | 251 | # Build IOU cost matrix between detections and predicted bounding boxes 252 | iou_matrix = get_iou_matrix( 253 | trackers=self.trackers, detection_boxes=detection_boxes 254 | ) 255 | 256 | # Associate detections to trackers based on IOU 257 | matched_indices, _, unmatched_detections = self._get_associated_indices( 258 | iou_matrix, detection_features 259 | ) 260 | 261 | # Update matched trackers with assigned detections 262 | for row, col in matched_indices: 263 | self.trackers[row].update(detection_boxes[col]) 264 | if detection_features is not None and len(detection_features) > col: 265 | self.trackers[row].update_feature(detection_features[col]) 266 | 267 | # Create new trackers for unmatched detections with confidence above threshold 268 | self._spawn_new_trackers( 269 | detections, detection_boxes, detection_features, unmatched_detections 270 | ) 271 | 272 | # Remove dead trackers 273 | self.trackers = get_alive_trackers( 274 | trackers=self.trackers, 275 | maximum_frames_without_update=self.maximum_frames_without_update, 276 | minimum_consecutive_frames=self.minimum_consecutive_frames, 277 | ) 278 | 279 | # Update detections with tracker IDs 280 | updated_detections = update_detections_with_track_ids( 281 | trackers=self.trackers, 282 | detections=detections, 283 | detection_boxes=detection_boxes, 284 | minimum_consecutive_frames=self.minimum_consecutive_frames, 285 | minimum_iou_threshold=self.minimum_iou_threshold, 286 | ) 287 | 288 | return updated_detections 289 | 290 | def reset(self) -> None: 291 | """Resets the tracker's internal state. 292 | 293 | Clears all active tracks and resets the track ID counter. 294 | """ 295 | self.trackers = [] 296 | DeepSORTKalmanBoxTracker.count_id = 0 297 | -------------------------------------------------------------------------------- /trackers/core/reid/__init__.py: -------------------------------------------------------------------------------- 1 | from trackers.log import get_logger 2 | 3 | logger = get_logger(__name__) 4 | 5 | try: 6 | from trackers.core.reid.dataset.base import TripletsDataset 7 | from trackers.core.reid.dataset.market_1501 import get_market1501_dataset 8 | from trackers.core.reid.model import ReIDModel 9 | 10 | __all__ = ["ReIDModel", "TripletsDataset", "get_market1501_dataset"] 11 | except ImportError: 12 | logger.warning( 13 | "ReIDModel dependencies not installed. ReIDModel will not be available. " 14 | "Please run `pip install trackers[reid]` and try again." 
15 | ) 16 | pass 17 | -------------------------------------------------------------------------------- /trackers/core/reid/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Optional 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class BaseCallback: 8 | def on_train_batch_start(self, logs: dict, idx: int): 9 | pass 10 | 11 | def on_train_batch_end(self, logs: dict, idx: int): 12 | pass 13 | 14 | def on_train_epoch_end(self, logs: dict, epoch: int): 15 | pass 16 | 17 | def on_validation_batch_start(self, logs: dict, idx: int): 18 | pass 19 | 20 | def on_validation_batch_end(self, logs: dict, idx: int): 21 | pass 22 | 23 | def on_validation_epoch_end(self, logs: dict, epoch: int): 24 | pass 25 | 26 | def on_checkpoint_save(self, checkpoint_path: str, epoch: int): 27 | pass 28 | 29 | def on_end(self): 30 | pass 31 | 32 | 33 | class TensorboardCallback(BaseCallback): 34 | def __init__( 35 | self, 36 | log_dir: Optional[str] = None, 37 | comment: str = "", 38 | purge_step: Optional[Any] = None, 39 | max_queue: int = 10, 40 | flush_secs: int = 120, 41 | filename_suffix: str = "", 42 | ): 43 | from torch.utils.tensorboard import SummaryWriter 44 | 45 | self.writer = SummaryWriter( 46 | log_dir, 47 | comment=comment, 48 | filename_suffix=filename_suffix, 49 | purge_step=purge_step, 50 | max_queue=max_queue, 51 | flush_secs=flush_secs, 52 | ) 53 | 54 | def on_train_batch_end(self, logs: dict, idx: int): 55 | for key, value in logs.items(): 56 | self.writer.add_scalar(key, value, idx) 57 | 58 | def on_train_epoch_end(self, logs: dict, epoch: int): 59 | for key, value in logs.items(): 60 | self.writer.add_scalar(key, value, epoch) 61 | 62 | def on_validation_epoch_end(self, logs: dict, epoch: int): 63 | for key, value in logs.items(): 64 | self.writer.add_scalar(key, value, epoch) 65 | 66 | def on_end(self): 67 | self.writer.flush() 68 | self.writer.close() 69 | 70 | 71 | class WandbCallback(BaseCallback): 72 | def __init__(self, config: dict[str, Any]) -> None: 73 | import wandb 74 | 75 | self.run = wandb.init(config=config) if not wandb.run else wandb.run # type: ignore 76 | 77 | self.run.define_metric("batch/step") 78 | self.run.define_metric("batch/train/loss", step_metric="batch/step") 79 | 80 | self.run.define_metric("epoch") 81 | self.run.define_metric("train/loss", step_metric="epoch") 82 | self.run.define_metric("validation/loss", step_metric="epoch") 83 | 84 | def on_train_batch_end(self, logs: dict, idx: int): 85 | logs["batch/step"] = idx 86 | self.run.log(logs) 87 | 88 | def on_train_epoch_end(self, logs: dict, epoch: int): 89 | logs["epoch"] = epoch 90 | self.run.log(logs) 91 | 92 | def on_validation_epoch_end(self, logs: dict, epoch: int): 93 | logs["epoch"] = epoch 94 | self.run.log(logs) 95 | 96 | def on_checkpoint_save(self, checkpoint_path: str, epoch: int): 97 | self.run.log_model( 98 | path=checkpoint_path, 99 | name=f"checkpoint_{self.run.id}", 100 | aliases=[f"epoch-{epoch}", "latest"], 101 | ) 102 | 103 | def on_end(self): 104 | self.run.finish() 105 | 106 | 107 | class MatplotlibCallback(BaseCallback): 108 | def __init__(self, log_dir: str): 109 | self.log_dir = log_dir 110 | self.train_history: dict[str, list[tuple[int, float]]] = {} 111 | self.validation_history: dict[str, list[tuple[int, float]]] = {} 112 | 113 | def on_train_batch_end(self, logs: dict, idx: int): 114 | for key, value in logs.items(): 115 | self.train_history.setdefault(key, []).append((idx, value)) 116 | 117 | def 
on_train_epoch_end(self, logs: dict, epoch: int): 118 | for key, value in logs.items(): 119 | self.train_history.setdefault(key, []).append((epoch, value)) 120 | 121 | def on_validation_epoch_end(self, logs: dict, epoch: int): 122 | for key, value in logs.items(): 123 | self.validation_history.setdefault(key, []).append((epoch, value)) 124 | 125 | def _plot_subplot( 126 | self, 127 | ax, 128 | title_prefix: str, 129 | base_metric_name: str, 130 | xlabel: str, 131 | train_keys: list[str], 132 | val_keys: Optional[list[str]] = None, 133 | ): 134 | train_data_points = [] 135 | for key in train_keys: 136 | data = self.train_history.get(key) 137 | if data: # Checks for None and non-empty list 138 | train_data_points = data 139 | break 140 | 141 | val_data_points = [] 142 | if val_keys: 143 | for key in val_keys: 144 | data = self.validation_history.get(key) 145 | if data: # Checks for None and non-empty list 146 | val_data_points = data 147 | break 148 | 149 | plotted_anything = False 150 | if train_data_points: 151 | x, y = zip(*train_data_points) 152 | ax.plot(x, y, label="train", marker=".", markersize=5, linewidth=1) 153 | plotted_anything = True 154 | 155 | if val_data_points: 156 | x_val, y_val = zip(*val_data_points) 157 | ax.plot( 158 | x_val, 159 | y_val, 160 | label="validation", 161 | marker=".", 162 | markersize=5, 163 | linewidth=1, 164 | linestyle="--", 165 | ) 166 | plotted_anything = True 167 | 168 | formatted_base_name = " ".join( 169 | [item.capitalize() for item in base_metric_name.split("_")] 170 | ) 171 | ax.set_title(f"{title_prefix} {formatted_base_name}") 172 | ax.set_xlabel(xlabel) 173 | ax.set_ylabel(formatted_base_name) 174 | 175 | if plotted_anything: 176 | ax.legend() 177 | ax.grid(True, linestyle="--", alpha=0.7) 178 | else: 179 | ax.text( 180 | 0.5, 181 | 0.5, 182 | "No data", 183 | ha="center", 184 | va="center", 185 | transform=ax.transAxes, 186 | fontsize=10, 187 | color="gray", 188 | ) 189 | ax.set_xticks([]) 190 | ax.set_yticks([]) 191 | for spine in ax.spines.values(): 192 | spine.set_edgecolor("lightgray") 193 | 194 | def on_end(self): 195 | if not self.train_history and not self.validation_history: 196 | return 197 | 198 | fig, axes = plt.subplots(2, 2, figsize=(12, 8), squeeze=False) 199 | 200 | # Plot 1: Top-left - Batch Triplet Accuracy 201 | self._plot_subplot( 202 | axes[0, 0], 203 | title_prefix="Batch", 204 | base_metric_name="triplet_accuracy", 205 | xlabel="Batch", 206 | train_keys=["batch/triplet_accuracy", "batch/train/triplet_accuracy"], 207 | val_keys=None, 208 | ) 209 | 210 | # Plot 2: Top-right - Epoch Triplet Accuracy 211 | self._plot_subplot( 212 | axes[0, 1], 213 | title_prefix="Epoch", 214 | base_metric_name="triplet_accuracy", 215 | xlabel="Epoch", 216 | train_keys=["train/triplet_accuracy", "triplet_accuracy"], 217 | val_keys=["validation/triplet_accuracy", "triplet_accuracy"], 218 | ) 219 | 220 | # Plot 3: Bottom-left - Batch Loss 221 | self._plot_subplot( 222 | axes[1, 0], 223 | title_prefix="Batch", 224 | base_metric_name="loss", 225 | xlabel="Batch", 226 | train_keys=["batch/loss", "batch/train/loss"], 227 | val_keys=None, 228 | ) 229 | 230 | # Plot 4: Bottom-right - Epoch Loss 231 | self._plot_subplot( 232 | axes[1, 1], 233 | title_prefix="Epoch", 234 | base_metric_name="loss", 235 | xlabel="Epoch", 236 | train_keys=["train/loss", "loss"], 237 | val_keys=["validation/loss", "loss"], 238 | ) 239 | 240 | os.makedirs(self.log_dir, exist_ok=True) 241 | 242 | plt.tight_layout(pad=2.0) 243 | fig.savefig(os.path.join(self.log_dir, 
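
# Illustrative sketch (not library code): a minimal custom callback built on the `BaseCallback`
# hook interface defined above. `CsvLoggerCallback` is a hypothetical name; `ReIDModel.train`
# currently constructs its callbacks internally from the `log_to_*` flags, so this only
# demonstrates the hook contract, not a supported extension point.
import csv


class CsvLoggerCallback(BaseCallback):
    def __init__(self, path: str):
        self.path = path
        self.rows: list[dict] = []

    def on_train_epoch_end(self, logs: dict, epoch: int):
        self.rows.append({"epoch": epoch, **logs})

    def on_validation_epoch_end(self, logs: dict, epoch: int):
        self.rows.append({"epoch": epoch, **logs})

    def on_end(self):
        if not self.rows:
            return
        fieldnames = sorted({key for row in self.rows for key in row})
        with open(self.path, "w", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.rows)
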
"metrics_plot.png"), dpi=150) 244 | plt.show() 245 | plt.close(fig) 246 | -------------------------------------------------------------------------------- /trackers/core/reid/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/trackers/core/reid/dataset/__init__.py -------------------------------------------------------------------------------- /trackers/core/reid/dataset/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import random 4 | from pathlib import Path 5 | from typing import Optional, Tuple, Union 6 | 7 | import torch 8 | from PIL import Image 9 | from supervision.dataset.utils import train_test_split 10 | from torch.utils.data import Dataset 11 | from torchvision.transforms import Compose, ToTensor 12 | 13 | from trackers.core.reid.dataset.utils import validate_tracker_id_to_images 14 | 15 | 16 | class TripletsDataset(Dataset): 17 | """A dataset that provides triplets of images for training ReID models. 18 | 19 | This dataset is designed for training models with triplet loss, where each sample 20 | consists of an anchor image, a positive image (same identity as anchor), 21 | and a negative image (different identity from anchor). 22 | 23 | Args: 24 | tracker_id_to_images (dict[str, list[str]]): Dictionary mapping tracker IDs 25 | to lists of image paths 26 | transforms (Optional[Compose]): Optional image transformations to apply 27 | 28 | Attributes: 29 | tracker_id_to_images (dict[str, list[str]]): Dictionary mapping tracker IDs 30 | to lists of image paths 31 | transforms (Optional[Compose]): Optional image transformations to apply 32 | tracker_ids (list[str]): List of all unique tracker IDs in the dataset 33 | """ 34 | 35 | def __init__( 36 | self, 37 | tracker_id_to_images: dict[str, list[str]], 38 | transforms: Optional[Compose] = None, 39 | ): 40 | self.tracker_id_to_images = validate_tracker_id_to_images(tracker_id_to_images) 41 | self.transforms = transforms or ToTensor() 42 | self.tracker_ids = list(self.tracker_id_to_images.keys()) 43 | 44 | @classmethod 45 | def from_image_directories( 46 | cls, 47 | root_directory: str, 48 | transforms: Optional[Compose] = None, 49 | image_extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"), 50 | ) -> TripletsDataset: 51 | """ 52 | Create TripletsDataset from a directory structured by tracker IDs. 53 | 54 | Args: 55 | root_directory (str): Root directory with tracker folders. 56 | transforms (Optional[Compose]): Optional image transformations. 57 | image_extensions (Tuple[str, ...]): Valid image extensions to load. 58 | 59 | Returns: 60 | TripletsDataset: An initialized dataset. 61 | """ 62 | root_path = Path(root_directory) 63 | tracker_id_to_images = {} 64 | 65 | for tracker_path in sorted(root_path.iterdir()): 66 | if not tracker_path.is_dir(): 67 | continue 68 | 69 | image_paths = sorted( 70 | [ 71 | str(image_path) 72 | for image_path in tracker_path.glob("*") 73 | if image_path.suffix.lower() in image_extensions 74 | and image_path.is_file() 75 | ] 76 | ) 77 | 78 | if image_paths: 79 | tracker_id_to_images[tracker_path.name] = image_paths 80 | 81 | return cls( 82 | tracker_id_to_images=tracker_id_to_images, 83 | transforms=transforms, 84 | ) 85 | 86 | def __len__(self) -> int: 87 | """ 88 | Return the number of unique tracker IDs (identities) in the dataset. 
89 | 90 | Returns: 91 | int: The total number of unique identities (tracker IDs) available for 92 | sampling triplets. 93 | """ 94 | return len(self.tracker_ids) 95 | 96 | def _load_and_transform_image(self, image_path: str) -> torch.Tensor: 97 | image = Image.open(image_path).convert("RGB") 98 | if self.transforms: 99 | image = self.transforms(image) 100 | return image 101 | 102 | def _get_triplet_image_paths(self, tracker_id: str) -> Tuple[str, str, str]: 103 | tracker_id_image_paths = self.tracker_id_to_images[tracker_id] 104 | 105 | anchor_image_path, positive_image_path = random.sample( # nosec B311 106 | tracker_id_image_paths, 2 107 | ) 108 | 109 | negative_candidates = [tid for tid in self.tracker_ids if tid != tracker_id] 110 | negative_tracker_id = random.choice(negative_candidates) # nosec B311 111 | 112 | negative_image_path = random.choice( # nosec B311 113 | self.tracker_id_to_images[negative_tracker_id] 114 | ) 115 | 116 | return anchor_image_path, positive_image_path, negative_image_path 117 | 118 | def __getitem__( 119 | self, index: int 120 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 121 | """ 122 | Retrieve a random triplet (anchor, positive, negative) of images for a given 123 | identity. 124 | 125 | For the tracker ID at the given index, samples two different images as the 126 | anchor and positive (same identity), and one image from a different tracker ID 127 | as the negative (different identity). 128 | 129 | Args: 130 | index (int): Index of the tracker ID (identity) to sample the triplet from. 131 | 132 | Returns: 133 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 134 | A tuple containing the anchor, positive, and negative image tensors. 135 | """ 136 | tracker_id = self.tracker_ids[index] 137 | 138 | anchor_image_path, positive_image_path, negative_image_path = ( 139 | self._get_triplet_image_paths(tracker_id) 140 | ) 141 | 142 | anchor_image = self._load_and_transform_image(anchor_image_path) 143 | positive_image = self._load_and_transform_image(positive_image_path) 144 | negative_image = self._load_and_transform_image(negative_image_path) 145 | 146 | return anchor_image, positive_image, negative_image 147 | 148 | def split( 149 | self, 150 | split_ratio: float = 0.8, 151 | random_state: Optional[Union[int, float, str, bytes, bytearray]] = None, 152 | shuffle: bool = True, 153 | ) -> Tuple[TripletsDataset, TripletsDataset]: 154 | train_tracker_id_to_images, validation_tracker_id_to_images = train_test_split( 155 | list(self.tracker_id_to_images.keys()), 156 | train_ratio=split_ratio, 157 | random_state=random_state, 158 | shuffle=shuffle, 159 | ) 160 | train_tracker_id_to_images = { 161 | tracker_id: self.tracker_id_to_images[tracker_id] 162 | for tracker_id in train_tracker_id_to_images 163 | } 164 | validation_tracker_id_to_images = { 165 | tracker_id: self.tracker_id_to_images[tracker_id] 166 | for tracker_id in validation_tracker_id_to_images 167 | } 168 | return ( 169 | TripletsDataset(train_tracker_id_to_images, self.transforms), 170 | TripletsDataset(validation_tracker_id_to_images, self.transforms), 171 | ) 172 | -------------------------------------------------------------------------------- /trackers/core/reid/dataset/market_1501.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | from torchvision.transforms import Compose 7 | 8 | from trackers.core.reid.dataset.base 
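
# Illustrative usage sketch (not library code): building triplet loaders from a directory of
# crops grouped per identity, using the `TripletsDataset` API defined above. The directory
# path and batch size are placeholders; pass transforms that resize crops to a fixed size if
# they vary, so the default collate can stack them.
from torch.utils.data import DataLoader

from trackers.core.reid.dataset.base import TripletsDataset

dataset = TripletsDataset.from_image_directories("crops/")  # expects crops/<tracker_id>/*.jpg
train_dataset, validation_dataset = dataset.split(split_ratio=0.8, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32)
anchor, positive, negative = train_dataset[0]  # three transformed image tensors
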
import TripletsDataset 9 | 10 | 11 | def parse_market1501_dataset(data_dir: str) -> Dict[str, List[str]]: 12 | """Parse the [Market1501 dataset](https://paperswithcode.com/dataset/market-1501) 13 | to create a dictionary mapping tracker IDs to lists of image paths. 14 | 15 | Args: 16 | data_dir (str): The path to the Market1501 dataset. 17 | 18 | Returns: 19 | Dict[str, List[str]]: A dictionary mapping tracker IDs to lists of image paths. 20 | """ 21 | image_files = glob.glob(os.path.join(data_dir, "*.jpg")) 22 | tracker_id_to_images = defaultdict(list) 23 | for image_file in image_files: 24 | tracker_id = os.path.basename(image_file).split("_")[0] 25 | tracker_id_to_images[tracker_id].append(image_file) 26 | return dict(tracker_id_to_images) 27 | 28 | 29 | def get_market1501_dataset( 30 | data_dir: str, 31 | split_ratio: Optional[float] = None, 32 | random_state: Optional[Union[int, float, str, bytes, bytearray]] = None, 33 | shuffle: bool = True, 34 | transforms: Optional[Compose] = None, 35 | ) -> Union[TripletsDataset, Tuple[TripletsDataset, TripletsDataset]]: 36 | """Get the [Market1501 dataset](https://paperswithcode.com/dataset/market-1501). 37 | 38 | Args: 39 | data_dir (str): The path to the bounding box train/test directory of the 40 | [Market1501 dataset](https://paperswithcode.com/dataset/market-1501). 41 | split_ratio (Optional[float]): The ratio of the dataset to split into training 42 | and validation sets. If `None`, the dataset is returned as a single 43 | `TripletsDataset` object, otherwise the dataset is split into a tuple of 44 | training and validation `TripletsDataset` objects. 45 | random_state (Optional[Union[int, float, str, bytes, bytearray]]): The random 46 | state to use for the split. 47 | shuffle (bool): Whether to shuffle the dataset. 48 | transforms (Optional[Compose]): The transforms to apply to the dataset. 49 | 50 | Returns: 51 | Tuple[TripletsDataset, TripletsDataset]: A tuple of training and validation 52 | `TripletsDataset` objects. 53 | """ 54 | tracker_id_to_images = parse_market1501_dataset(data_dir) 55 | dataset = TripletsDataset(tracker_id_to_images, transforms) 56 | if split_ratio is not None: 57 | train_dataset, validation_dataset = dataset.split( 58 | split_ratio=split_ratio, random_state=random_state, shuffle=shuffle 59 | ) 60 | return train_dataset, validation_dataset 61 | return dataset 62 | -------------------------------------------------------------------------------- /trackers/core/reid/dataset/utils.py: -------------------------------------------------------------------------------- 1 | from trackers.log import get_logger 2 | 3 | logger = get_logger(__name__) 4 | 5 | 6 | def validate_tracker_id_to_images( 7 | tracker_id_to_images: dict[str, list[str]], 8 | ) -> dict[str, list[str]]: 9 | """Validates a dictionary that maps tracker IDs to lists of image paths for the 10 | `TripletsDataset` for training ReID models using triplet loss. 11 | 12 | Args: 13 | tracker_id_to_images (dict[str, list[str]]): The tracker ID to images 14 | dictionary. 15 | 16 | Returns: 17 | dict[str, list[str]]: The validated tracker ID to images dictionary. 18 | """ 19 | valid_tracker_ids = {} 20 | for tracker_id, image_paths in tracker_id_to_images.items(): 21 | if len(image_paths) < 2: 22 | logger.warning( 23 | f"Tracker ID '{tracker_id}' has less than 2 images. " 24 | f"Skipping this tracker ID." 
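
# Illustrative usage sketch (not library code): loading Market-1501 bounding-box crops with
# the helper defined above. The dataset path is a placeholder.
from trackers.core.reid.dataset.market_1501 import get_market1501_dataset

train_dataset, validation_dataset = get_market1501_dataset(
    data_dir="Market-1501-v15.09.15/bounding_box_train",
    split_ratio=0.8,
    random_state=42,
)
# With split_ratio=None a single TripletsDataset covering every identity is returned instead.
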
25 | ) 26 | else: 27 | valid_tracker_ids[tracker_id] = image_paths 28 | 29 | if len(valid_tracker_ids) < 2: 30 | raise ValueError( 31 | "Tracker ID to images dictionary must contain at least 2 items " 32 | "to select negative samples." 33 | ) 34 | 35 | return valid_tracker_ids 36 | -------------------------------------------------------------------------------- /trackers/core/reid/metrics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class TripletMetric(ABC): 8 | @abstractmethod 9 | def update( 10 | self, 11 | anchor_embed: torch.Tensor, 12 | positive_embed: torch.Tensor, 13 | negative_embed: torch.Tensor, 14 | ) -> None: 15 | pass 16 | 17 | @abstractmethod 18 | def compute(self) -> float: 19 | pass 20 | 21 | @abstractmethod 22 | def reset(self) -> None: 23 | pass 24 | 25 | 26 | class TripletAccuracyMetric(TripletMetric): 27 | """ 28 | Calculates the triplet accuracy using pairwise distance. 29 | Accuracy is defined as the proportion of triplets where the distance 30 | between the anchor and positive embedding is less than the distance 31 | between the anchor and negative embedding. 32 | """ 33 | 34 | def __init__(self): 35 | self.correct = 0 36 | self.total = 0 37 | 38 | def __str__(self): 39 | return "triplet_accuracy" 40 | 41 | def update( 42 | self, 43 | anchor_embed: torch.Tensor, 44 | positive_embed: torch.Tensor, 45 | negative_embed: torch.Tensor, 46 | ) -> None: 47 | """ 48 | Update the metric with a batch of embeddings. 49 | 50 | Args: 51 | anchor_embed (torch.Tensor): Embeddings of the anchor samples. 52 | positive_embed (torch.Tensor): Embeddings of the positive samples. 53 | negative_embed (torch.Tensor): Embeddings of the negative samples. 
54 | """ 55 | dist_ap = F.pairwise_distance(anchor_embed, positive_embed, p=2) 56 | dist_an = F.pairwise_distance(anchor_embed, negative_embed, p=2) 57 | self.correct += torch.sum(dist_ap < dist_an).item() 58 | self.total += anchor_embed.size(0) 59 | 60 | def compute(self) -> float: 61 | if self.total == 0: 62 | return 0.0 63 | return self.correct / self.total 64 | 65 | def reset(self) -> None: 66 | self.correct = 0 67 | self.total = 0 68 | -------------------------------------------------------------------------------- /trackers/core/reid/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | from typing import Any, Callable, Optional, Union 6 | 7 | import numpy as np 8 | import PIL 9 | import supervision as sv 10 | import timm 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from safetensors.torch import save_file 15 | from timm.data import resolve_data_config 16 | from timm.data.transforms_factory import create_transform 17 | from torch.utils.data import DataLoader 18 | from torchvision.transforms import Compose, ToPILImage 19 | from tqdm.auto import tqdm 20 | 21 | from trackers.core.reid.callbacks import BaseCallback 22 | from trackers.core.reid.metrics import ( 23 | TripletAccuracyMetric, 24 | TripletMetric, 25 | ) 26 | from trackers.log import get_logger 27 | from trackers.utils.torch_utils import load_safetensors_checkpoint, parse_device_spec 28 | 29 | logger = get_logger(__name__) 30 | 31 | 32 | def _initialize_reid_model_from_timm( 33 | cls, 34 | model_name_or_checkpoint_path: str, 35 | device: Optional[str] = "auto", 36 | get_pooled_features: bool = True, 37 | **kwargs, 38 | ): 39 | if model_name_or_checkpoint_path not in timm.list_models( 40 | filter=model_name_or_checkpoint_path, pretrained=True 41 | ): 42 | probable_model_name_list = timm.list_models( 43 | f"*{model_name_or_checkpoint_path}*", pretrained=True 44 | ) 45 | if len(probable_model_name_list) == 0: 46 | raise ValueError( 47 | f"Model {model_name_or_checkpoint_path} not found in timm. " 48 | + "Please check the model name and try again." 49 | ) 50 | logger.warning( 51 | f"Model {model_name_or_checkpoint_path} not found in timm. " 52 | + f"Using {probable_model_name_list[0]} instead." 
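
# Illustrative usage sketch (not library code): exercising `TripletAccuracyMetric` on random
# embeddings; accuracy is the fraction of triplets whose anchor lies closer to the positive
# than to the negative.
import torch

from trackers.core.reid.metrics import TripletAccuracyMetric

metric = TripletAccuracyMetric()
anchor = torch.randn(8, 128)
positive = anchor + 0.05 * torch.randn(8, 128)  # close to the anchor
negative = torch.randn(8, 128)                  # unrelated embedding
metric.update(anchor, positive, negative)
print(metric.compute())  # close to 1.0 for this synthetic batch
metric.reset()
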
53 | ) 54 | model_name_or_checkpoint_path = probable_model_name_list[0] 55 | if not get_pooled_features: 56 | kwargs["global_pool"] = "" 57 | model = timm.create_model( 58 | model_name_or_checkpoint_path, pretrained=True, num_classes=0, **kwargs 59 | ) 60 | config = resolve_data_config(model.pretrained_cfg) 61 | transforms = create_transform(**config) 62 | model_metadata = { 63 | "model_name_or_checkpoint_path": model_name_or_checkpoint_path, 64 | "get_pooled_features": get_pooled_features, 65 | "kwargs": kwargs, 66 | } 67 | return cls(model, device, transforms, model_metadata) 68 | 69 | 70 | def _initialize_reid_model_from_checkpoint(cls, checkpoint_path: str): 71 | state_dict, config = load_safetensors_checkpoint(checkpoint_path) 72 | reid_model_instance = _initialize_reid_model_from_timm( 73 | cls, **config["model_metadata"] 74 | ) 75 | if config["projection_dimension"]: 76 | reid_model_instance._add_projection_layer( 77 | projection_dimension=config["projection_dimension"] 78 | ) 79 | for k, v in state_dict.items(): 80 | state_dict[k].to(reid_model_instance.device) 81 | reid_model_instance.backbone_model.load_state_dict(state_dict) 82 | return reid_model_instance 83 | 84 | 85 | class ReIDModel: 86 | """ 87 | A ReID model that is used to extract features from detection crops for trackers 88 | that utilize appearance features. 89 | 90 | Args: 91 | backbone_model (nn.Module): The torch model to use as the backbone. 92 | device (Optional[str]): The device to run the model on. 93 | transforms (Optional[Union[Callable, list[Callable]]]): The transforms to 94 | apply to the input images. 95 | model_metadata (dict[str, Any]): Metadata about the model architecture. 96 | """ 97 | 98 | def __init__( 99 | self, 100 | backbone_model: nn.Module, 101 | device: Optional[str] = "auto", 102 | transforms: Optional[Union[Callable, list[Callable]]] = None, 103 | model_metadata: dict[str, Any] = {}, 104 | ): 105 | self.backbone_model = backbone_model 106 | self.device = parse_device_spec(device or "auto") 107 | self.backbone_model.to(self.device) 108 | self.backbone_model.eval() 109 | self.train_transforms = ( 110 | (Compose(*transforms) if isinstance(transforms, list) else transforms) 111 | if transforms is not None 112 | else None 113 | ) 114 | self.inference_transforms = Compose( 115 | [ToPILImage(), *transforms] 116 | if isinstance(transforms, list) 117 | else [ToPILImage(), transforms] 118 | ) 119 | self.model_metadata = model_metadata 120 | 121 | @classmethod 122 | def from_timm( 123 | cls, 124 | model_name_or_checkpoint_path: str, 125 | device: Optional[str] = "auto", 126 | get_pooled_features: bool = True, 127 | **kwargs, 128 | ) -> ReIDModel: 129 | """ 130 | Create a `ReIDModel` with a [timm](https://huggingface.co/docs/timm) 131 | model as the backbone. 132 | 133 | Args: 134 | model_name_or_checkpoint_path (str): Name of the timm model to use or 135 | path to a safetensors checkpoint. If the exact model name is not 136 | found, the closest match from `timm.list_models` will be used. 137 | device (str): Device to run the model on. 138 | get_pooled_features (bool): Whether to get the pooled features from the 139 | model or not. 140 | **kwargs: Additional keyword arguments to pass to 141 | [`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model). 142 | 143 | Returns: 144 | ReIDModel: A new instance of `ReIDModel`. 
145 | """ 146 | if os.path.exists(model_name_or_checkpoint_path): 147 | return _initialize_reid_model_from_checkpoint( 148 | cls, model_name_or_checkpoint_path 149 | ) 150 | else: 151 | return _initialize_reid_model_from_timm( 152 | cls, 153 | model_name_or_checkpoint_path, 154 | device, 155 | get_pooled_features, 156 | **kwargs, 157 | ) 158 | 159 | def extract_features( 160 | self, detections: sv.Detections, frame: Union[np.ndarray, PIL.Image.Image] 161 | ) -> np.ndarray: 162 | """ 163 | Extract features from detection crops in the frame. 164 | 165 | Args: 166 | detections (sv.Detections): Detections from which to extract features. 167 | frame (np.ndarray or PIL.Image.Image): The input frame. 168 | 169 | Returns: 170 | np.ndarray: Extracted features for each detection. 171 | """ 172 | if len(detections) == 0: 173 | return np.array([]) 174 | 175 | if isinstance(frame, PIL.Image.Image): 176 | frame = np.array(frame) 177 | 178 | features = [] 179 | with torch.inference_mode(): 180 | for box in detections.xyxy: 181 | crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)]) 182 | tensor = self.inference_transforms(crop).unsqueeze(0).to(self.device) 183 | feature = ( 184 | torch.squeeze(self.backbone_model(tensor)).cpu().numpy().flatten() 185 | ) 186 | features.append(feature) 187 | 188 | return np.array(features) 189 | 190 | def _add_projection_layer( 191 | self, projection_dimension: Optional[int] = None, freeze_backbone: bool = False 192 | ): 193 | """ 194 | Perform model surgery to add a projection layer to the model and freeze the 195 | backbone if specified. The backbone is only frozen if `projection_dimension` 196 | is specified. 197 | 198 | Args: 199 | projection_dimension (Optional[int]): The dimension of the projection layer. 200 | freeze_backbone (bool): Whether to freeze the backbone of the model during 201 | training. 202 | """ 203 | if projection_dimension is not None: 204 | # Freeze backbone only if specified and projection_dimension is mentioned 205 | if freeze_backbone: 206 | for param in self.backbone_model.parameters(): 207 | param.requires_grad = False 208 | 209 | # Add projection layer if projection_dimension is specified 210 | self.backbone_model = nn.Sequential( 211 | self.backbone_model, 212 | nn.Linear(self.backbone_model.num_features, projection_dimension), 213 | ) 214 | self.backbone_model.to(self.device) 215 | 216 | def _train_step( 217 | self, 218 | anchor_image: torch.Tensor, 219 | positive_image: torch.Tensor, 220 | negative_image: torch.Tensor, 221 | metrics_list: list[TripletMetric], 222 | ) -> dict[str, float]: 223 | """ 224 | Perform a single training step. 225 | 226 | Args: 227 | anchor_image (torch.Tensor): The anchor image. 228 | positive_image (torch.Tensor): The positive image. 229 | negative_image (torch.Tensor): The negative image. 230 | metrics_list (list[Metric]): The list of metrics to update. 
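
# Illustrative usage sketch (not library code): building a ReID backbone from timm and
# extracting appearance features for a set of detections. The model name and the dummy
# frame/detections are placeholders; `from_timm` falls back to the closest matching timm
# model name if the exact one is not found.
import numpy as np
import supervision as sv

from trackers.core.reid.model import ReIDModel

reid_model = ReIDModel.from_timm("resnet50", device="auto")
frame = np.zeros((480, 640, 3), dtype=np.uint8)
detections = sv.Detections(xyxy=np.array([[50.0, 60.0, 120.0, 200.0]]))
features = reid_model.extract_features(detections, frame)  # shape: (num_detections, feature_dim)
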
231 | """ 232 | self.optimizer.zero_grad() 233 | anchor_image_features = self.backbone_model(anchor_image) 234 | positive_image_features = self.backbone_model(positive_image) 235 | negative_image_features = self.backbone_model(negative_image) 236 | 237 | loss = self.criterion( 238 | anchor_image_features, 239 | positive_image_features, 240 | negative_image_features, 241 | ) 242 | loss.backward() 243 | self.optimizer.step() 244 | 245 | # Update metrics 246 | for metric in metrics_list: 247 | metric.update( 248 | anchor_image_features.detach(), 249 | positive_image_features.detach(), 250 | negative_image_features.detach(), 251 | ) 252 | 253 | train_logs = {"train/loss": loss.item()} 254 | for metric in metrics_list: 255 | train_logs[f"train/{metric!s}"] = metric.compute() 256 | 257 | return train_logs 258 | 259 | def _validation_step( 260 | self, 261 | anchor_image: torch.Tensor, 262 | positive_image: torch.Tensor, 263 | negative_image: torch.Tensor, 264 | metrics_list: list[TripletMetric], 265 | ) -> dict[str, float]: 266 | """ 267 | Perform a single validation step. 268 | 269 | Args: 270 | anchor_image (torch.Tensor): The anchor image. 271 | positive_image (torch.Tensor): The positive image. 272 | negative_image (torch.Tensor): The negative image. 273 | metrics_list (list[Metric]): The list of metrics to update. 274 | """ 275 | with torch.inference_mode(): 276 | anchor_image_features = self.backbone_model(anchor_image) 277 | positive_image_features = self.backbone_model(positive_image) 278 | negative_image_features = self.backbone_model(negative_image) 279 | 280 | loss = self.criterion( 281 | anchor_image_features, 282 | positive_image_features, 283 | negative_image_features, 284 | ) 285 | 286 | # Update metrics 287 | for metric in metrics_list: 288 | metric.update( 289 | anchor_image_features.detach(), 290 | positive_image_features.detach(), 291 | negative_image_features.detach(), 292 | ) 293 | 294 | validation_logs = {"validation/loss": loss.item()} 295 | for metric in metrics_list: 296 | validation_logs[f"validation/{metric!s}"] = metric.compute() 297 | 298 | return validation_logs 299 | 300 | def train( 301 | self, 302 | train_loader: DataLoader, 303 | epochs: int, 304 | validation_loader: Optional[DataLoader] = None, 305 | projection_dimension: Optional[int] = None, 306 | freeze_backbone: bool = False, 307 | learning_rate: float = 5e-5, 308 | weight_decay: float = 0.0, 309 | triplet_margin: float = 1.0, 310 | random_state: Optional[Union[int, float, str, bytes, bytearray]] = None, 311 | checkpoint_interval: Optional[int] = None, 312 | log_dir: str = "logs", 313 | log_to_matplotlib: bool = False, 314 | log_to_tensorboard: bool = False, 315 | log_to_wandb: bool = False, 316 | ) -> None: 317 | """ 318 | Train/fine-tune the ReID model. 319 | 320 | Args: 321 | train_loader (DataLoader): The training data loader. 322 | epochs (int): The number of epochs to train the model. 323 | validation_loader (Optional[DataLoader]): The validation data loader. 324 | projection_dimension (Optional[int]): The dimension of the projection layer. 325 | freeze_backbone (bool): Whether to freeze the backbone of the model. The 326 | backbone is only frozen if `projection_dimension` is specified. 327 | learning_rate (float): The learning rate to use for the optimizer. 328 | weight_decay (float): The weight decay to use for the optimizer. 329 | triplet_margin (float): The margin to use for the triplet loss. 
330 | random_state (Optional[Union[int, float, str, bytes, bytearray]]): The 331 | random state to use for the training. 332 | checkpoint_interval (Optional[int]): The interval to save checkpoints. 333 | log_dir (str): The directory to save logs. 334 | log_to_matplotlib (bool): Whether to log to matplotlib. 335 | log_to_tensorboard (bool): Whether to log to tensorboard. 336 | log_to_wandb (bool): Whether to log to wandb. If `checkpoint_interval` is 337 | specified, the model will be logged to wandb as well. 338 | Project and entity name should be set using the environment variables 339 | `WANDB_PROJECT` and `WANDB_ENTITY`. For more details, refer to 340 | [wandb environment variables](https://docs.wandb.ai/guides/track/environment-variables). 341 | """ 342 | os.makedirs(log_dir, exist_ok=True) 343 | os.makedirs(os.path.join(log_dir, "checkpoints"), exist_ok=True) 344 | os.makedirs(os.path.join(log_dir, "tensorboard_logs"), exist_ok=True) 345 | 346 | if random_state is not None: 347 | torch.manual_seed(random_state) 348 | 349 | self._add_projection_layer(projection_dimension, freeze_backbone) 350 | 351 | # Initialize optimizer, criterion and metrics 352 | self.optimizer = optim.Adam( 353 | self.backbone_model.parameters(), 354 | lr=learning_rate, 355 | weight_decay=weight_decay, 356 | ) 357 | self.criterion = nn.TripletMarginLoss(margin=triplet_margin) 358 | metrics_list: list[TripletMetric] = [TripletAccuracyMetric()] 359 | 360 | config = { 361 | "epochs": epochs, 362 | "learning_rate": learning_rate, 363 | "weight_decay": weight_decay, 364 | "random_state": random_state, 365 | "projection_dimension": projection_dimension, 366 | "freeze_backbone": freeze_backbone, 367 | "triplet_margin": triplet_margin, 368 | "model_metadata": self.model_metadata, 369 | } 370 | 371 | # Initialize callbacks 372 | callbacks: list[BaseCallback] = [] 373 | if log_to_matplotlib: 374 | try: 375 | from trackers.core.reid.callbacks import MatplotlibCallback 376 | 377 | callbacks.append(MatplotlibCallback(log_dir=log_dir)) 378 | except (ImportError, AttributeError) as e: 379 | logger.error( 380 | "Metric logging dependencies are not installed. " 381 | "Please install it using `pip install trackers[metrics]`.", 382 | ) 383 | raise e 384 | if log_to_tensorboard: 385 | try: 386 | from trackers.core.reid.callbacks import TensorboardCallback 387 | 388 | callbacks.append( 389 | TensorboardCallback( 390 | log_dir=os.path.join(log_dir, "tensorboard_logs") 391 | ) 392 | ) 393 | except (ImportError, AttributeError) as e: 394 | logger.error( 395 | "Metric logging dependencies are not installed. " 396 | "Please install it using `pip install trackers[metrics]`." 397 | ) 398 | raise e 399 | 400 | if log_to_wandb: 401 | try: 402 | from trackers.core.reid.callbacks import WandbCallback 403 | 404 | callbacks.append(WandbCallback(config=config)) 405 | except (ImportError, AttributeError) as e: 406 | logger.error( 407 | "Metric logging dependencies are not installed. " 408 | "Please install it using `pip install trackers[metrics]`." 
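
# Illustrative usage sketch (not library code): fine-tuning the ReID backbone with triplet
# loss via the `train` method documented above. `reid_model`, `train_loader`, and
# `validation_loader` are assumed to come from the earlier sketches; all hyperparameter
# values are placeholders.
reid_model.train(
    train_loader,
    epochs=10,
    validation_loader=validation_loader,
    projection_dimension=128,   # adds a projection head on top of the backbone
    freeze_backbone=True,       # only takes effect together with projection_dimension
    learning_rate=5e-5,
    triplet_margin=1.0,
    checkpoint_interval=5,      # writes .safetensors checkpoints under logs/checkpoints
    log_dir="logs",
    log_to_tensorboard=True,
)
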
409 | ) 410 | raise e 411 | 412 | # Training loop over epochs 413 | for epoch in tqdm(range(epochs), desc="Training"): 414 | # Reset metrics at the start of each epoch 415 | for metric in metrics_list: 416 | metric.reset() 417 | 418 | # Training loop over batches 419 | accumulated_train_logs: dict[str, Union[float, int]] = {} 420 | for idx, data in tqdm( 421 | enumerate(train_loader), 422 | total=len(train_loader), 423 | desc=f"Training Epoch {epoch + 1}/{epochs}", 424 | leave=False, 425 | ): 426 | anchor_image, positive_image, negative_image = data 427 | if self.train_transforms is not None: 428 | anchor_image = self.train_transforms(anchor_image) 429 | positive_image = self.train_transforms(positive_image) 430 | negative_image = self.train_transforms(negative_image) 431 | 432 | anchor_image = anchor_image.to(self.device) 433 | positive_image = positive_image.to(self.device) 434 | negative_image = negative_image.to(self.device) 435 | 436 | if callbacks: 437 | for callback in callbacks: 438 | callback.on_train_batch_start( 439 | {}, epoch * len(train_loader) + idx 440 | ) 441 | 442 | train_logs = self._train_step( 443 | anchor_image, positive_image, negative_image, metrics_list 444 | ) 445 | 446 | for key, value in train_logs.items(): 447 | accumulated_train_logs[key] = ( 448 | accumulated_train_logs.get(key, 0) + value 449 | ) 450 | 451 | if callbacks: 452 | for callback in callbacks: 453 | for key, value in train_logs.items(): 454 | callback.on_train_batch_end( 455 | {f"batch/{key}": value}, epoch * len(train_loader) + idx 456 | ) 457 | 458 | for key, value in accumulated_train_logs.items(): 459 | accumulated_train_logs[key] = value / len(train_loader) 460 | 461 | # Compute and add training metrics to logs 462 | for metric in metrics_list: 463 | accumulated_train_logs[f"train/{metric!s}"] = metric.compute() 464 | # Metrics are reset at the start of the next epoch or before validation 465 | 466 | if callbacks: 467 | for callback in callbacks: 468 | callback.on_train_epoch_end(accumulated_train_logs, epoch) 469 | 470 | # Validation loop over batches 471 | accumulated_validation_logs: dict[str, Union[float, int]] = {} 472 | if validation_loader is not None: 473 | # Reset metrics for validation 474 | for metric in metrics_list: 475 | metric.reset() 476 | for idx, data in tqdm( 477 | enumerate(validation_loader), 478 | total=len(validation_loader), 479 | desc=f"Validation Epoch {epoch + 1}/{epochs}", 480 | leave=False, 481 | ): 482 | if callbacks: 483 | for callback in callbacks: 484 | callback.on_validation_batch_start( 485 | {}, epoch * len(train_loader) + idx 486 | ) 487 | 488 | anchor_image, positive_image, negative_image = data 489 | if self.train_transforms is not None: 490 | anchor_image = self.train_transforms(anchor_image) 491 | positive_image = self.train_transforms(positive_image) 492 | negative_image = self.train_transforms(negative_image) 493 | 494 | anchor_image = anchor_image.to(self.device) 495 | positive_image = positive_image.to(self.device) 496 | negative_image = negative_image.to(self.device) 497 | 498 | validation_logs = self._validation_step( 499 | anchor_image, positive_image, negative_image, metrics_list 500 | ) 501 | 502 | for key, value in validation_logs.items(): 503 | accumulated_validation_logs[key] = ( 504 | accumulated_validation_logs.get(key, 0) + value 505 | ) 506 | 507 | if callbacks: 508 | for callback in callbacks: 509 | for key, value in validation_logs.items(): 510 | callback.on_validation_batch_end( 511 | {f"batch/{key}": value}, 512 | epoch * 
len(train_loader) + idx, 513 | ) 514 | 515 | for key, value in accumulated_validation_logs.items(): 516 | accumulated_validation_logs[key] = value / len(validation_loader) 517 | 518 | # Compute and add validation metrics to logs 519 | for metric in metrics_list: 520 | accumulated_validation_logs[f"validation/{metric!s}"] = ( 521 | metric.compute() 522 | ) 523 | # Metrics will be reset at the start of the next training epoch loop 524 | 525 | if callbacks: 526 | for callback in callbacks: 527 | callback.on_validation_epoch_end(accumulated_validation_logs, epoch) 528 | 529 | # Save checkpoint 530 | if ( 531 | checkpoint_interval is not None 532 | and (epoch + 1) % checkpoint_interval == 0 533 | ): 534 | state_dict = self.backbone_model.state_dict() 535 | checkpoint_path = os.path.join( 536 | log_dir, "checkpoints", f"reid_model_{epoch + 1}.safetensors" 537 | ) 538 | save_file( 539 | state_dict, 540 | checkpoint_path, 541 | metadata={"config": json.dumps(config), "format": "pt"}, 542 | ) 543 | if callbacks: 544 | for callback in callbacks: 545 | callback.on_checkpoint_save(checkpoint_path, epoch + 1) 546 | 547 | if callbacks: 548 | for callback in callbacks: 549 | callback.on_end() 550 | -------------------------------------------------------------------------------- /trackers/core/sort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roboflow/trackers/f42a7a79f7869e90a11cc4237ccc719fdf028b33/trackers/core/sort/__init__.py -------------------------------------------------------------------------------- /trackers/core/sort/kalman_box_tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.typing import NDArray 3 | 4 | 5 | class SORTKalmanBoxTracker: 6 | """ 7 | The `SORTKalmanBoxTracker` class represents the internals of a single 8 | tracked object (bounding box), with a Kalman filter to predict and update 9 | its position. 10 | 11 | Attributes: 12 | tracker_id (int): Unique identifier for the tracker. 13 | number_of_successful_updates (int): Number of times the object has been 14 | updated successfully. 15 | time_since_update (int): Number of frames since the last update. 16 | state (np.ndarray): State vector of the bounding box. 17 | F (np.ndarray): State transition matrix. 18 | H (np.ndarray): Measurement matrix. 19 | Q (np.ndarray): Process noise covariance matrix. 20 | R (np.ndarray): Measurement noise covariance matrix. 21 | P (np.ndarray): Error covariance matrix. 22 | count_id (int): Class variable to assign unique IDs to each tracker. 23 | 24 | Args: 25 | bbox (np.ndarray): Initial bounding box in the form [x1, y1, x2, y2]. 
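
# Illustrative usage sketch (not library code): reloading a fine-tuned ReID checkpoint.
# `from_timm` detects that the argument is an existing file path and restores the backbone,
# any projection layer, and the original timm configuration from the safetensors metadata.
# The checkpoint filename below is a placeholder matching the naming scheme used above.
from trackers.core.reid.model import ReIDModel

reid_model = ReIDModel.from_timm("logs/checkpoints/reid_model_10.safetensors")
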
26 | """ 27 | 28 | count_id: int = 0 29 | state: NDArray[np.float32] 30 | F: NDArray[np.float32] 31 | H: NDArray[np.float32] 32 | Q: NDArray[np.float32] 33 | R: NDArray[np.float32] 34 | P: NDArray[np.float32] 35 | 36 | @classmethod 37 | def get_next_tracker_id(cls) -> int: 38 | next_id = cls.count_id 39 | cls.count_id += 1 40 | return next_id 41 | 42 | def __init__(self, bbox: NDArray[np.float64]) -> None: 43 | # Initialize with a temporary ID of -1 44 | # Will be assigned a real ID when the track is considered mature 45 | self.tracker_id = -1 46 | 47 | # Number of hits indicates how many times the object has been 48 | # updated successfully 49 | self.number_of_successful_updates = 1 50 | # Number of frames since the last update 51 | self.time_since_update = 0 52 | 53 | # For simplicity, we keep a small state vector: 54 | # (x, y, x2, y2, vx, vy, vx2, vy2). 55 | # We'll store the bounding box in "self.state" 56 | self.state = np.zeros((8, 1), dtype=np.float32) 57 | 58 | # Initialize state directly from the first detection 59 | bbox_float: NDArray[np.float32] = bbox.astype(np.float32) 60 | self.state[0, 0] = bbox_float[0] 61 | self.state[1, 0] = bbox_float[1] 62 | self.state[2, 0] = bbox_float[2] 63 | self.state[3, 0] = bbox_float[3] 64 | 65 | # Basic constant velocity model 66 | self._initialize_kalman_filter() 67 | 68 | def _initialize_kalman_filter(self) -> None: 69 | """ 70 | Sets up the matrices for the Kalman filter. 71 | """ 72 | # State transition matrix (F): 8x8 73 | # We assume a constant velocity model. Positions are incremented by 74 | # velocity each step. 75 | self.F = np.eye(8, dtype=np.float32) 76 | for i in range(4): 77 | self.F[i, i + 4] = 1.0 78 | 79 | # Measurement matrix (H): we directly measure x1, y1, x2, y2 80 | self.H = np.eye(4, 8, dtype=np.float32) # 4x8 81 | 82 | # Process covariance matrix (Q) 83 | self.Q = np.eye(8, dtype=np.float32) * 0.01 84 | 85 | # Measurement covariance (R): noise in detection 86 | self.R = np.eye(4, dtype=np.float32) * 0.1 87 | 88 | # Error covariance matrix (P) 89 | self.P = np.eye(8, dtype=np.float32) 90 | 91 | def predict(self) -> None: 92 | """ 93 | Predict the next state of the bounding box (applies the state transition). 94 | """ 95 | # Predict state 96 | self.state = (self.F @ self.state).astype(np.float32) 97 | # Predict error covariance 98 | self.P = (self.F @ self.P @ self.F.T + self.Q).astype(np.float32) 99 | 100 | # Increase time since update 101 | self.time_since_update += 1 102 | 103 | def update(self, bbox: NDArray[np.float64]) -> None: 104 | """ 105 | Updates the state with a new detected bounding box. 106 | 107 | Args: 108 | bbox (np.ndarray): Detected bounding box in the form [x1, y1, x2, y2]. 
109 | """ 110 | self.time_since_update = 0 111 | self.number_of_successful_updates += 1 112 | 113 | # Kalman Gain 114 | S: NDArray[np.float32] = self.H @ self.P @ self.H.T + self.R 115 | K: NDArray[np.float32] = (self.P @ self.H.T @ np.linalg.inv(S)).astype( 116 | np.float32 117 | ) 118 | 119 | # Residual 120 | measurement: NDArray[np.float32] = bbox.reshape((4, 1)).astype(np.float32) 121 | y: NDArray[np.float32] = ( 122 | measurement - self.H @ self.state 123 | ) # y should be float32 (4,1) 124 | 125 | # Update state 126 | self.state = (self.state + K @ y).astype(np.float32) 127 | 128 | # Update covariance 129 | identity_matrix: NDArray[np.float32] = np.eye(8, dtype=np.float32) 130 | self.P = ((identity_matrix - K @ self.H) @ self.P).astype(np.float32) 131 | 132 | def get_state_bbox(self) -> NDArray[np.float32]: 133 | """ 134 | Returns the current bounding box estimate from the state vector. 135 | 136 | Returns: 137 | np.ndarray: The bounding box [x1, y1, x2, y2] 138 | """ 139 | return self.state[:4, 0].flatten().astype(np.float32) 140 | -------------------------------------------------------------------------------- /trackers/core/sort/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import supervision as sv 3 | from scipy.optimize import linear_sum_assignment 4 | 5 | from trackers.core.base import BaseTracker 6 | from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker 7 | from trackers.utils.sort_utils import ( 8 | get_alive_trackers, 9 | get_iou_matrix, 10 | update_detections_with_track_ids, 11 | ) 12 | 13 | 14 | class SORTTracker(BaseTracker): 15 | """Implements SORT (Simple Online and Realtime Tracking). 16 | 17 | SORT is a pragmatic approach to multiple object tracking with a focus on 18 | simplicity and speed. It uses a Kalman filter for motion prediction and the 19 | Hungarian algorithm or simple IOU matching for data association. 20 | 21 | Args: 22 | lost_track_buffer (int): Number of frames to buffer when a track is lost. 23 | Increasing lost_track_buffer enhances occlusion handling, significantly 24 | improving tracking through occlusions, but may increase the possibility 25 | of ID switching for objects with similar appearance. 26 | frame_rate (float): Frame rate of the video (frames per second). 27 | Used to calculate the maximum time a track can be lost. 28 | track_activation_threshold (float): Detection confidence threshold 29 | for track activation. Only detections with confidence above this 30 | threshold will create new tracks. Increasing this threshold 31 | reduces false positives but may miss real objects with low confidence. 32 | minimum_consecutive_frames (int): Number of consecutive frames that an object 33 | must be tracked before it is considered a 'valid' track. Increasing 34 | `minimum_consecutive_frames` prevents the creation of accidental tracks 35 | from false detection or double detection, but risks missing shorter 36 | tracks. Before the tracker is considered valid, it will be assigned 37 | `-1` as its `tracker_id`. 38 | minimum_iou_threshold (float): IOU threshold for associating detections to 39 | existing tracks. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | lost_track_buffer: int = 30, 45 | frame_rate: float = 30.0, 46 | track_activation_threshold: float = 0.25, 47 | minimum_consecutive_frames: int = 3, 48 | minimum_iou_threshold: float = 0.3, 49 | ) -> None: 50 | # Calculate maximum frames without update based on lost_track_buffer and 51 | # frame_rate. 
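
# Illustrative usage sketch (not library code): stepping a single `SORTKalmanBoxTracker`
# through one predict/update cycle with the constant-velocity model defined above.
import numpy as np

from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker

tracker = SORTKalmanBoxTracker(np.array([10.0, 10.0, 50.0, 60.0]))
tracker.predict()                                   # propagate state, increments time_since_update
tracker.update(np.array([12.0, 11.0, 52.0, 61.0]))  # correct the state with a new detection
print(tracker.get_state_bbox())                     # current [x1, y1, x2, y2] estimate
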
This scales the buffer based on the frame rate to ensure 52 | # consistent time-based tracking across different frame rates. 53 | self.maximum_frames_without_update = int(frame_rate / 30.0 * lost_track_buffer) 54 | self.minimum_consecutive_frames = minimum_consecutive_frames 55 | self.minimum_iou_threshold = minimum_iou_threshold 56 | self.track_activation_threshold = track_activation_threshold 57 | 58 | # Active trackers 59 | self.trackers: list[SORTKalmanBoxTracker] = [] 60 | 61 | def _get_associated_indices( 62 | self, iou_matrix: np.ndarray, detection_boxes: np.ndarray 63 | ) -> tuple[list[tuple[int, int]], set[int], set[int]]: 64 | """ 65 | Associate detections to trackers based on IOU 66 | 67 | Args: 68 | iou_matrix (np.ndarray): IOU cost matrix. 69 | detection_boxes (np.ndarray): Detected bounding boxes in the 70 | form [x1, y1, x2, y2]. 71 | 72 | Returns: 73 | tuple[list[tuple[int, int]], set[int], set[int]]: Matched indices, 74 | unmatched trackers, unmatched detections. 75 | """ 76 | matched_indices = [] 77 | unmatched_trackers = set(range(len(self.trackers))) 78 | unmatched_detections = set(range(len(detection_boxes))) 79 | 80 | if len(self.trackers) > 0 and len(detection_boxes) > 0: 81 | # Find optimal assignment using scipy.optimize.linear_sum_assignment. 82 | # Note that it uses a a modified Jonker-Volgenant algorithm with no 83 | # initialization instead of the Hungarian algorithm as mentioned in the 84 | # SORT paper. 85 | row_indices, col_indices = linear_sum_assignment(iou_matrix, maximize=True) 86 | for row, col in zip(row_indices, col_indices): 87 | if iou_matrix[row, col] >= self.minimum_iou_threshold: 88 | matched_indices.append((row, col)) 89 | unmatched_trackers.remove(row) 90 | unmatched_detections.remove(col) 91 | 92 | return matched_indices, unmatched_trackers, unmatched_detections 93 | 94 | def _spawn_new_trackers( 95 | self, 96 | detections: sv.Detections, 97 | detection_boxes: np.ndarray, 98 | unmatched_detections: set[int], 99 | ) -> None: 100 | """ 101 | Create new trackers only for unmatched detections with confidence 102 | above threshold. 103 | 104 | Args: 105 | detections (sv.Detections): The latest set of object detections. 106 | detection_boxes (np.ndarray): Detected bounding boxes in the 107 | form [x1, y1, x2, y2]. 108 | """ 109 | for detection_idx in unmatched_detections: 110 | if ( 111 | detections.confidence is None 112 | or detection_idx >= len(detections.confidence) 113 | or detections.confidence[detection_idx] 114 | >= self.track_activation_threshold 115 | ): 116 | new_tracker = SORTKalmanBoxTracker(detection_boxes[detection_idx]) 117 | self.trackers.append(new_tracker) 118 | 119 | def update(self, detections: sv.Detections) -> sv.Detections: 120 | """Updates the tracker state with new detections. 121 | 122 | Performs Kalman filter prediction, associates detections with existing 123 | trackers based on IOU, updates matched trackers, and initializes new 124 | trackers for unmatched high-confidence detections. 125 | 126 | Args: 127 | detections (sv.Detections): The latest set of object detections from a frame. 128 | 129 | Returns: 130 | sv.Detections: A copy of the input detections, augmented with assigned 131 | `tracker_id` for each successfully tracked object. Detections not 132 | associated with a track will not have a `tracker_id`. 
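
# Illustrative sketch (not library code): the association step above in isolation. Rows are
# trackers, columns are detections; `linear_sum_assignment(..., maximize=True)` picks the
# IOU-maximising one-to-one assignment, and matches below `minimum_iou_threshold` are dropped.
import numpy as np
from scipy.optimize import linear_sum_assignment

iou_matrix = np.array(
    [
        [0.85, 0.10, 0.00],
        [0.05, 0.02, 0.60],
    ]
)
rows, cols = linear_sum_assignment(iou_matrix, maximize=True)
matches = [(int(r), int(c)) for r, c in zip(rows, cols) if iou_matrix[r, c] >= 0.3]
print(matches)  # [(0, 0), (1, 2)]; detection 1 stays unmatched and may spawn a new track
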
133 | """ # noqa: E501 134 | 135 | if len(self.trackers) == 0 and len(detections) == 0: 136 | detections.tracker_id = np.array([], dtype=int) 137 | return detections 138 | 139 | # Convert detections to a (N x 4) array (x1, y1, x2, y2) 140 | detection_boxes = ( 141 | detections.xyxy if len(detections) > 0 else np.array([]).reshape(0, 4) 142 | ) 143 | 144 | # Predict new locations for existing trackers 145 | for tracker in self.trackers: 146 | tracker.predict() 147 | 148 | # Build IOU cost matrix between detections and predicted bounding boxes 149 | iou_matrix = get_iou_matrix(self.trackers, detection_boxes) 150 | 151 | # Associate detections to trackers based on IOU 152 | matched_indices, _, unmatched_detections = self._get_associated_indices( 153 | iou_matrix, detection_boxes 154 | ) 155 | 156 | # Update matched trackers with assigned detections 157 | for row, col in matched_indices: 158 | self.trackers[row].update(detection_boxes[col]) 159 | 160 | self._spawn_new_trackers(detections, detection_boxes, unmatched_detections) 161 | 162 | # Remove dead trackers 163 | self.trackers = get_alive_trackers( 164 | self.trackers, 165 | self.minimum_consecutive_frames, 166 | self.maximum_frames_without_update, 167 | ) 168 | 169 | updated_detections = update_detections_with_track_ids( 170 | self.trackers, 171 | detections, 172 | detection_boxes, 173 | self.minimum_iou_threshold, 174 | self.minimum_consecutive_frames, 175 | ) 176 | 177 | return updated_detections 178 | 179 | def reset(self) -> None: 180 | """Resets the tracker's internal state. 181 | 182 | Clears all active tracks and resets the track ID counter. 183 | """ 184 | self.trackers = [] 185 | SORTKalmanBoxTracker.count_id = 0 186 | -------------------------------------------------------------------------------- /trackers/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from typing import Any, Dict, Final, Literal, Optional 5 | 6 | _LOG_LEVELS: Final[dict[str, int]] = { 7 | "DEBUG": logging.DEBUG, 8 | "INFO": logging.INFO, 9 | "WARNING": logging.WARNING, 10 | "ERROR": logging.ERROR, 11 | "CRITICAL": logging.CRITICAL, 12 | } 13 | 14 | _LOG_FILENAME: Final[str] = os.environ.get("TRACKERS_LOG_FILENAME", "trackers.log") 15 | _LOG_LEVEL_NAME: Final[str] = os.environ.get("TRACKERS_LOG_LEVEL", "ERROR").upper() 16 | _LOG_OUTPUT_TYPE: Final[str] = os.environ.get("TRACKERS_LOG_OUTPUT", "stderr").lower() 17 | _LOG_LEVEL: Final[int] = _LOG_LEVELS.get(_LOG_LEVEL_NAME, logging.ERROR) 18 | _LOG_FORMAT: Final[str] = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 19 | 20 | 21 | class LogFormatter(logging.Formatter): 22 | """ 23 | Custom log formatter that adds ANSI color codes to log messages based on 24 | the log level for terminal output. Does not add color codes if the output 25 | is redirected to a file. This formatter is designed to work with Python 3.10+. 26 | It uses ANSI escape sequences to colorize log messages for better visibility 27 | in terminal environments. 
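
# Illustrative usage sketch (not library code): running the `SORTTracker` defined above over a
# stream of frames. The detections below are a stand-in; any source of `sv.Detections` works.
import numpy as np
import supervision as sv

from trackers.core.sort.tracker import SORTTracker

tracker = SORTTracker(lost_track_buffer=30, minimum_consecutive_frames=3)
for frame in [np.zeros((480, 640, 3), dtype=np.uint8)] * 5:
    detections = sv.Detections(xyxy=np.array([[50.0, 60.0, 120.0, 200.0]]))  # stand-in detector output
    detections = tracker.update(detections)
    print(detections.tracker_id)  # -1 until a track survives minimum_consecutive_frames
tracker.reset()
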
28 | """ 29 | 30 | def __init__( 31 | self, 32 | fmt: Optional[str] = None, 33 | datefmt: Optional[str] = None, 34 | style: Literal["%", "{", "$"] = "%", 35 | validate: bool = True, 36 | *, 37 | defaults: Optional[Dict[str, Any]] = None, 38 | ) -> None: 39 | if sys.version_info >= (3, 10): 40 | super().__init__(fmt, datefmt, style, validate, defaults=defaults) 41 | else: 42 | super().__init__(fmt, datefmt, style, validate) 43 | 44 | self._RESET: Final[str] = "\x1b[0m" 45 | 46 | self._COLOURS: Final[dict[int, str]] = { 47 | logging.DEBUG: "\x1b[38;21m", 48 | logging.INFO: "\x1b[34;1m", 49 | logging.WARNING: "\x1b[33;1m", 50 | logging.ERROR: "\x1b[31;1m", 51 | logging.CRITICAL: "\x1b[35;1m", 52 | } 53 | 54 | self._BASE_FORMAT: Final[str] = "%(asctime)s - %(name)s - " 55 | self._LEVEL_MSG_FORMAT: Final[str] = "%(levelname)s: %(message)s" 56 | 57 | self._FORMATS: dict[int, str] = { 58 | level: color + self._BASE_FORMAT + self._LEVEL_MSG_FORMAT + self._RESET 59 | for level, color in self._COLOURS.items() 60 | } 61 | 62 | def format(self, record: logging.LogRecord) -> str: 63 | """ 64 | Formats the log record with color based on the log level. 65 | Args: 66 | record (logging.LogRecord): The log record to format. 67 | Returns: 68 | str: The formatted log message with color. 69 | """ 70 | 71 | log_fmt = self._FORMATS.get(record.levelno) 72 | formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S") 73 | return formatter.format(record) 74 | 75 | 76 | if _LOG_OUTPUT_TYPE == "file": 77 | logging.basicConfig( 78 | level=_LOG_LEVEL, 79 | format=_LOG_FORMAT, 80 | filename=_LOG_FILENAME, 81 | filemode="a", 82 | ) 83 | else: 84 | root_logger = logging.getLogger() 85 | root_logger.setLevel(_LOG_LEVEL) 86 | handler = logging.StreamHandler(sys.stderr) 87 | handler.setFormatter(LogFormatter()) 88 | root_logger.addHandler(handler) 89 | 90 | 91 | def get_logger(name: Optional[str]) -> logging.Logger: 92 | """ 93 | Retrieves a logger instance with the specified name. 94 | 95 | Args: 96 | name (str): The name for the logger, typically __name__. 97 | 98 | Returns: 99 | logging.Logger: Configured logger instance. 100 | """ 101 | return logging.getLogger(name) 102 | -------------------------------------------------------------------------------- /trackers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from trackers.utils.sort_utils import ( 2 | get_alive_trackers, 3 | get_iou_matrix, 4 | update_detections_with_track_ids, 5 | ) 6 | 7 | __all__ = [ 8 | "get_alive_trackers", 9 | "get_iou_matrix", 10 | "update_detections_with_track_ids", 11 | ] 12 | -------------------------------------------------------------------------------- /trackers/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | from trackers.log import get_logger 4 | 5 | logger = get_logger(__name__) 6 | 7 | 8 | def unzip_file(source_zip_path: str, target_dir_path: str) -> None: 9 | """ 10 | Extracts all files from a zip archive. 11 | 12 | Args: 13 | source_zip_path (str): The path to the zip file. 14 | target_dir_path (str): The directory to extract the contents to. 15 | If the directory doesn't exist, it will be created. 16 | 17 | Raises: 18 | FileNotFoundError: If the zip file doesn't exist. 19 | zipfile.BadZipFile: If the file is not a valid zip file or is corrupted. 20 | Exception: If any other error occurs during extraction. 
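
# Illustrative sketch (not library code): the logging configuration above is driven by
# environment variables that are read once at import time, so they have to be set before
# `trackers` (or `trackers.log`) is imported.
import os

os.environ["TRACKERS_LOG_LEVEL"] = "DEBUG"     # one of DEBUG/INFO/WARNING/ERROR/CRITICAL
os.environ["TRACKERS_LOG_OUTPUT"] = "file"     # "stderr" (default, colourised) or "file"
os.environ["TRACKERS_LOG_FILENAME"] = "trackers.log"

from trackers.log import get_logger

logger = get_logger(__name__)
logger.debug("tracker debugging enabled")
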
21 | """ 22 | with zipfile.ZipFile(source_zip_path, "r") as zip_ref: 23 | zip_ref.extractall(target_dir_path) 24 | logger.info(f"Successfully extracted '{source_zip_path}' to '{target_dir_path}'") 25 | -------------------------------------------------------------------------------- /trackers/utils/downloader.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os.path 3 | import shutil 4 | from tempfile import TemporaryDirectory 5 | from typing import Optional 6 | from urllib.parse import urlparse 7 | 8 | import aiofiles 9 | import aiohttp 10 | from tqdm.asyncio import tqdm as async_tqdm 11 | 12 | 13 | class _AsyncFileDownloader: 14 | """Asynchronously downloads files with 15 | support for multipart downloading and progress bars. 16 | 17 | This class handles downloading files from URLs, automatically determining whether 18 | to use multipart downloading based on server support for content length. 19 | It displays progress using tqdm. 20 | """ 21 | 22 | def __init__(self, part_size_mb: int = 10, default_chunk_size: int = 8192): 23 | """Initializes the AsyncFileDownloader. 24 | 25 | Args: 26 | part_size_mb (int): The size of each part in megabytes for multipart downloads. 27 | default_chunk_size (int): The default chunk size in bytes for reading content. 28 | """ # noqa: E501 29 | self.part_size = part_size_mb * 1024**2 30 | self.default_chunk_size = default_chunk_size 31 | 32 | async def _get_content_length(self, url: str) -> Optional[int]: 33 | """Retrieves the content length of a file from a URL. 34 | 35 | Args: 36 | url (str): The URL of the file. 37 | 38 | Returns: 39 | Optional[int]: The content length in bytes if available, otherwise None. 40 | """ 41 | async with aiohttp.ClientSession() as session: 42 | async with session.head(url) as request: 43 | return request.content_length 44 | 45 | def _parts_generator(self, size: int, start: int = 0): 46 | """Generates byte ranges for multipart downloading. 47 | 48 | Args: 49 | size (int): The total size of the file in bytes. 50 | start (int): The starting byte offset. Defaults to 0. 51 | 52 | Yields: 53 | Tuple[int, int]: A tuple representing the start and end byte of a part. 54 | """ 55 | while size - start > self.part_size: 56 | yield start, start + self.part_size 57 | start += self.part_size 58 | yield start, size 59 | 60 | async def _download_part( 61 | self, url: str, headers: dict, save_path: str, progress_bar: async_tqdm 62 | ): 63 | """Downloads a single part of a file. 64 | 65 | Args: 66 | url (str): The URL to download from. 67 | headers (dict): HTTP headers to use for the request (e.g., for Range). 68 | save_path (str): The local path to save the downloaded part. 69 | progress_bar (async_tqdm): An instance of tqdm to update download progress. 70 | """ 71 | async with aiohttp.ClientSession(headers=headers) as session: 72 | async with session.get(url) as request: 73 | async with aiofiles.open(save_path, "wb") as file: 74 | async for chunk in request.content.iter_chunked( 75 | self.default_chunk_size 76 | ): 77 | await file.write(chunk) 78 | progress_bar.update(len(chunk)) 79 | 80 | async def process_url( 81 | self, 82 | url: str, 83 | save_dir: Optional[str] = None, 84 | output_filename: Optional[str] = None, 85 | ) -> str: 86 | """Downloads a file from a URL, handling multipart downloads and progress. 87 | 88 | If the server provides content length, the file is downloaded in parts. 89 | Otherwise, a direct download is attempted. Progress is displayed using tqdm. 
90 | 91 | Args: 92 | url (str): The URL of the file to download. 93 | save_dir (Optional[str]): The directory to save the downloaded file. 94 | Defaults to the current working directory. 95 | output_filename (Optional[str]): The desired filename for the downloaded file. 96 | If None, it's inferred from the URL. 97 | 98 | Returns: 99 | str: The full path to the downloaded file. 100 | """ # noqa: E501 101 | if output_filename is None: 102 | output_filename = os.path.basename(urlparse(url).path) 103 | 104 | if save_dir is None: 105 | save_dir = os.path.abspath(".") 106 | final_save_path = os.path.join(save_dir, output_filename) 107 | os.makedirs(save_dir, exist_ok=True) 108 | tmp_dir = TemporaryDirectory(prefix=output_filename, dir=save_dir) 109 | try: 110 | size = await self._get_content_length(url) 111 | if size is None: 112 | async with aiohttp.ClientSession() as session: 113 | async with session.get(url) as request: 114 | content_length = request.content_length 115 | with async_tqdm( 116 | total=content_length, 117 | unit="B", 118 | unit_scale=True, 119 | desc=f"Downloading {output_filename}", 120 | leave=True, 121 | ) as pbar: 122 | async with aiofiles.open(final_save_path, "wb") as file: 123 | async for chunk in request.content.iter_chunked( 124 | self.default_chunk_size 125 | ): 126 | await file.write(chunk) 127 | pbar.update(len(chunk)) 128 | return final_save_path 129 | 130 | tasks = [] 131 | file_parts = [] 132 | with async_tqdm( 133 | total=size, 134 | unit="B", 135 | unit_scale=True, 136 | desc=f"Downloading {output_filename}", 137 | leave=True, 138 | ) as pbar: 139 | for number, sizes in enumerate(self._parts_generator(size)): 140 | part_file_name = os.path.join( 141 | tmp_dir.name, f"{output_filename}.part{number}" 142 | ) 143 | file_parts.append(part_file_name) 144 | tasks.append( 145 | self._download_part( 146 | url, 147 | {"Range": f"bytes={sizes[0]}-{sizes[1] - 1}"}, 148 | part_file_name, 149 | pbar, 150 | ) 151 | ) 152 | 153 | await asyncio.gather(*tasks) 154 | 155 | with open(final_save_path, "wb") as wfd: 156 | for f_part_path in file_parts: 157 | with open(f_part_path, "rb") as fd: 158 | shutil.copyfileobj(fd, wfd) 159 | return final_save_path 160 | finally: 161 | tmp_dir.cleanup() 162 | 163 | 164 | def download_file( 165 | url: str, part_size_mb: int = 10, default_chunk_size: int = 8192 166 | ) -> str: 167 | """Asynchronously downloads files with support for multipart downloading and progress bars. 168 | 169 | This class handles downloading files from URLs, automatically determining whether 170 | to use multipart downloading based on server support for content length. 171 | It displays progress using tqdm. 172 | 173 | Args: 174 | url (str): The URL to download the model file from. 175 | part_size_mb (int): The size of each part in megabytes for multipart downloads. 176 | default_chunk_size (int): The default chunk size in bytes for reading content. 177 | 178 | Returns: 179 | str: The local path to the downloaded file. 180 | """ # noqa: E501 181 | 182 | downloader = _AsyncFileDownloader( 183 | part_size_mb=part_size_mb, default_chunk_size=default_chunk_size 184 | ) 185 | if not url: 186 | raise ValueError("URL cannot be empty.") 187 | if not urlparse(url).scheme: 188 | raise ValueError("Invalid URL. Please provide a valid URL.") 189 | if not urlparse(url).netloc: 190 | raise ValueError("Invalid URL. Please provide a valid URL.") 191 | if not urlparse(url).path: 192 | raise ValueError("Invalid URL. 
Please provide a valid URL.") 193 | 194 | try: 195 | loop = asyncio.get_event_loop() 196 | if loop.is_running(): 197 | future = asyncio.ensure_future(downloader.process_url(url)) 198 | file_path = loop.run_until_complete(future) 199 | else: 200 | file_path = loop.run_until_complete(downloader.process_url(url)) 201 | except RuntimeError: 202 | file_path = asyncio.run(downloader.process_url(url)) 203 | print(f"File downloaded to {file_path}.") 204 | return file_path 205 | -------------------------------------------------------------------------------- /trackers/utils/sort_utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import List, Sequence, Set, TypeVar, Union 3 | 4 | import numpy as np 5 | import supervision as sv 6 | from supervision.detection.utils import box_iou_batch 7 | 8 | from trackers.core.deepsort.kalman_box_tracker import DeepSORTKalmanBoxTracker 9 | from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker 10 | 11 | KalmanBoxTrackerType = TypeVar( 12 | "KalmanBoxTrackerType", bound=Union[SORTKalmanBoxTracker, DeepSORTKalmanBoxTracker] 13 | ) 14 | 15 | 16 | def get_alive_trackers( 17 | trackers: Sequence[KalmanBoxTrackerType], 18 | minimum_consecutive_frames: int, 19 | maximum_frames_without_update: int, 20 | ) -> List[KalmanBoxTrackerType]: 21 | """ 22 | Remove dead or immature lost tracklets and get alive trackers 23 | that are within `maximum_frames_without_update` AND (it's mature OR 24 | it was just updated). 25 | 26 | Args: 27 | trackers (Sequence[KalmanBoxTrackerType]): List of KalmanBoxTracker objects. 28 | minimum_consecutive_frames (int): Number of consecutive frames that an object 29 | must be tracked before it is considered a 'valid' track. 30 | maximum_frames_without_update (int): Maximum number of frames without update 31 | before a track is considered dead. 32 | 33 | Returns: 34 | List[KalmanBoxTrackerType]: List of alive trackers. 35 | """ 36 | alive_trackers = [] 37 | for tracker in trackers: 38 | is_mature = tracker.number_of_successful_updates >= minimum_consecutive_frames 39 | is_active = tracker.time_since_update == 0 40 | if tracker.time_since_update < maximum_frames_without_update and ( 41 | is_mature or is_active 42 | ): 43 | alive_trackers.append(tracker) 44 | return alive_trackers 45 | 46 | 47 | def get_iou_matrix( 48 | trackers: Sequence[KalmanBoxTrackerType], detection_boxes: np.ndarray 49 | ) -> np.ndarray: 50 | """ 51 | Build IOU cost matrix between detections and predicted bounding boxes 52 | 53 | Args: 54 | detection_boxes (np.ndarray): Detected bounding boxes in the 55 | form [x1, y1, x2, y2]. 56 | 57 | Returns: 58 | np.ndarray: IOU cost matrix. 
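    Example:
        A minimal, illustrative sketch (assumes trackers and detection_boxes
        come from the surrounding SORT/DeepSORT update step):

            iou_matrix = get_iou_matrix(trackers, detection_boxes)
            # iou_matrix[i, j] is the IoU between the predicted box of
            # trackers[i] and detection_boxes[j]; the matrix has shape
            # (len(trackers), len(detection_boxes)).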
59 | """ 60 | predicted_boxes = np.array([t.get_state_bbox() for t in trackers]) 61 | if len(predicted_boxes) == 0 and len(trackers) > 0: 62 | # Handle case where get_state_bbox might return empty array 63 | predicted_boxes = np.zeros((len(trackers), 4), dtype=np.float32) 64 | 65 | if len(trackers) > 0 and len(detection_boxes) > 0: 66 | iou_matrix = box_iou_batch(predicted_boxes, detection_boxes) 67 | else: 68 | iou_matrix = np.zeros((len(trackers), len(detection_boxes)), dtype=np.float32) 69 | 70 | return iou_matrix 71 | 72 | 73 | def update_detections_with_track_ids( 74 | trackers: Sequence[KalmanBoxTrackerType], 75 | detections: sv.Detections, 76 | detection_boxes: np.ndarray, 77 | minimum_iou_threshold: float, 78 | minimum_consecutive_frames: int, 79 | ) -> sv.Detections: 80 | """ 81 | The function prepares the updated Detections with track IDs. 82 | If a tracker is "mature" (>= `minimum_consecutive_frames`) or recently updated, 83 | it is assigned an ID to the detection that just updated it. 84 | 85 | Args: 86 | trackers (Sequence[SORTKalmanBoxTracker]): List of SORTKalmanBoxTracker objects. 87 | detections (sv.Detections): The latest set of object detections. 88 | detection_boxes (np.ndarray): Detected bounding boxes in the 89 | form [x1, y1, x2, y2]. 90 | minimum_iou_threshold (float): IOU threshold for associating detections to 91 | existing tracks. 92 | minimum_consecutive_frames (int): Number of consecutive frames that an object 93 | must be tracked before it is considered a 'valid' track. 94 | 95 | Returns: 96 | sv.Detections: A copy of the detections with `tracker_id` set 97 | for each detection that is tracked. 98 | """ 99 | # Re-run association in the same way (could also store direct mapping) 100 | final_tracker_ids = [-1] * len(detection_boxes) 101 | 102 | # Recalculate predicted_boxes based on current trackers after some may have 103 | # been removed 104 | predicted_boxes = np.array([t.get_state_bbox() for t in trackers]) 105 | iou_matrix_final = np.zeros((len(trackers), len(detection_boxes)), dtype=np.float32) 106 | 107 | # Ensure predicted_boxes is properly shaped before the second iou calculation 108 | if len(predicted_boxes) == 0 and len(trackers) > 0: 109 | predicted_boxes = np.zeros((len(trackers), 4), dtype=np.float32) 110 | 111 | if len(trackers) > 0 and len(detection_boxes) > 0: 112 | iou_matrix_final = box_iou_batch(predicted_boxes, detection_boxes) 113 | 114 | row_indices, col_indices = np.where(iou_matrix_final > minimum_iou_threshold) 115 | sorted_pairs = sorted( 116 | zip(row_indices, col_indices), 117 | key=lambda x: iou_matrix_final[x[0], x[1]], 118 | reverse=True, 119 | ) 120 | used_rows: Set[int] = set() 121 | used_cols: Set[int] = set() 122 | for row, col in sorted_pairs: 123 | # Double check index is in range 124 | if row < len(trackers): 125 | tracker_obj = trackers[int(row)] 126 | # Only assign if the track is "mature" or is new but has enough hits 127 | if (int(row) not in used_rows) and (int(col) not in used_cols): 128 | if ( 129 | tracker_obj.number_of_successful_updates 130 | >= minimum_consecutive_frames 131 | ): 132 | # If tracker is mature but still has ID -1, assign a new ID 133 | if tracker_obj.tracker_id == -1: 134 | tracker_obj.tracker_id = ( 135 | SORTKalmanBoxTracker.get_next_tracker_id() 136 | ) 137 | final_tracker_ids[int(col)] = tracker_obj.tracker_id 138 | used_rows.add(int(row)) 139 | used_cols.add(int(col)) 140 | 141 | # Assign tracker IDs to the returned Detections 142 | updated_detections = deepcopy(detections) 143 | 
updated_detections.tracker_id = np.array(final_tracker_ids) 144 | 145 | return updated_detections 146 | -------------------------------------------------------------------------------- /trackers/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Any, Tuple, Union 4 | 5 | import torch 6 | from safetensors import safe_open 7 | 8 | 9 | def parse_device_spec(device_spec: Union[str, torch.device]) -> torch.device: 10 | """ 11 | Convert a string or torch.device into a valid torch.device. Allowed strings: 12 | `'auto'`, `'cpu'`, `'cuda'`, `'cuda:N'` (e.g. `'cuda:0'`), or `'mps'`. 13 | This function raises ValueError if the input is unrecognized or the GPU 14 | index is out of range. 15 | 16 | Args: 17 | device_spec (Union[str, torch.device]): A specification for the device. 18 | This can be a valid `torch.device` object or one of the recognized 19 | strings described above. 20 | 21 | Returns: 22 | torch.device: The corresponding `torch.device` object. 23 | 24 | Raises: 25 | ValueError: If the device specification is unrecognized or the provided GPU 26 | index exceeds the available devices. 27 | """ 28 | if isinstance(device_spec, torch.device): 29 | return device_spec 30 | 31 | device_str = device_spec.lower() 32 | if device_str == "auto": 33 | if torch.cuda.is_available(): 34 | return torch.device("cuda") 35 | elif torch.backends.mps.is_available(): 36 | return torch.device("mps") 37 | else: 38 | return torch.device("cpu") 39 | elif device_str == "cpu": 40 | return torch.device("cpu") 41 | elif device_str == "cuda": 42 | return torch.device("cuda") 43 | elif device_str == "mps": 44 | return torch.device("mps") 45 | else: 46 | match = re.match(r"^cuda:(\d+)$", device_str) 47 | if match: 48 | index = int(match.group(1)) 49 | if index < 0: 50 | raise ValueError(f"GPU index must be non-negative, got {index}.") 51 | if index >= torch.cuda.device_count(): 52 | raise ValueError( 53 | f"Requested cuda:{index} but only {torch.cuda.device_count()}" 54 | + " GPU(s) are available." 55 | ) 56 | return torch.device(f"cuda:{index}") 57 | 58 | raise ValueError(f"Unrecognized device spec: {device_spec}") 59 | 60 | 61 | def load_safetensors_checkpoint( 62 | checkpoint_path: str, device: str = "cpu" 63 | ) -> Tuple[dict[str, torch.Tensor], dict[str, Any]]: 64 | """ 65 | Load a safetensors checkpoint into a dictionary of tensors and a dictionary 66 | of metadata. 67 | 68 | Args: 69 | checkpoint_path (str): The path to the safetensors checkpoint. 70 | device (str): The device to load the checkpoint on. 71 | 72 | Returns: 73 | Tuple[dict[str, torch.Tensor], dict[str, Any]]: A tuple containing the 74 | state_dict and the config. 75 | """ 76 | state_dict = {} 77 | with safe_open(checkpoint_path, framework="pt", device=device) as f: 78 | for key in f.keys(): 79 | state_dict[key] = f.get_tensor(key) 80 | metadata = f.metadata() 81 | config = json.loads(metadata["config"]) if "config" in metadata else {} 82 | model_metadata = config.pop("model_metadata") if "model_metadata" in config else {} 83 | if "kwargs" in model_metadata: 84 | kwargs = model_metadata.pop("kwargs") 85 | model_metadata = {**kwargs, **model_metadata} 86 | config["model_metadata"] = model_metadata 87 | return state_dict, config 88 | --------------------------------------------------------------------------------
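A minimal, illustrative sketch (not part of the repository) of how the helpers in
trackers/utils/torch_utils.py might be used together; the checkpoint filename below
is a placeholder, not a file shipped with the project.

    from trackers.utils.torch_utils import (
        load_safetensors_checkpoint,
        parse_device_spec,
    )

    device = parse_device_spec("auto")  # resolves to cuda, mps or cpu
    state_dict, config = load_safetensors_checkpoint(
        "model.safetensors", device="cpu"
    )
    print(device, len(state_dict), config.get("model_metadata", {}))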