├── .coveragerc ├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── labeler.yml ├── release-drafter.yml └── workflows │ ├── ci.yml │ ├── docker.yml │ ├── labeler.yml │ ├── matchers │ ├── flake8.json │ ├── isort.json │ ├── mypy.json │ ├── pylint.json │ └── pytest.json │ ├── release-drafter.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── tasks.json ├── Dockerfile ├── LICENSE ├── MAINTAINERS.md ├── MANIFEST.in ├── README.rst ├── aioesphomeapi ├── __init__.py ├── _frame_helper │ ├── __init__.py │ ├── base.pxd │ ├── base.py │ ├── noise.pxd │ ├── noise.py │ ├── noise_encryption.pxd │ ├── noise_encryption.py │ ├── pack.pxd │ ├── pack.pyx │ ├── packets.pxd │ ├── packets.py │ ├── plain_text.pxd │ └── plain_text.py ├── api.proto ├── api_options.proto ├── api_options_pb2.py ├── api_pb2.py ├── ble_defs.py ├── client.py ├── client_base.pxd ├── client_base.py ├── connection.pxd ├── connection.py ├── core.py ├── discover.py ├── host_resolver.py ├── log_parser.py ├── log_reader.py ├── log_runner.py ├── model.py ├── model_conversions.py ├── py.typed ├── reconnect_logic.py ├── util.py └── zeroconf.py ├── bench ├── raw_ble_plain_text.py └── raw_ble_plain_text_with_callback.py ├── mypy.ini ├── pyproject.toml ├── requirements ├── base.txt └── test.txt ├── script ├── gen-protoc └── lint ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── benchmarks ├── __init__.py ├── conftest.py ├── test_bluetooth.py ├── test_noise.py └── test_requests.py ├── common.py ├── conftest.py ├── test__frame_helper.py ├── test_client.py ├── test_connection.py ├── test_core.py ├── test_host_resolver.py ├── test_log_parser.py ├── test_log_runner.py ├── test_model.py ├── test_reconnect_logic.py ├── test_util.py └── test_zeroconf.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = aioesphomeapi 3 | 4 | omit = 5 | 
aioesphomeapi/api_options_pb2.py 6 | aioesphomeapi/api_pb2.py 7 | aioesphomeapi/log_reader.py 8 | aioesphomeapi/discover.py 9 | bench/*.py 10 | 11 | [report] 12 | # Regexes for lines to exclude from consideration 13 | exclude_lines = 14 | # Have to re-enable the standard pragma 15 | pragma: no cover 16 | 17 | # Don't complain about missing debug-only code: 18 | def __repr__ 19 | 20 | # Don't complain if tests don't hit defensive assertion code: 21 | raise AssertionError 22 | raise NotImplementedError 23 | raise exceptions.NotSupportedError 24 | 25 | # TYPE_CHECKING and @overload blocks are never executed during pytest run 26 | if TYPE_CHECKING: 27 | @overload 28 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ESPHome API Client Dev", 3 | "image": "ghcr.io/esphome/aioesphomeapi-proto-builder:latest", 4 | "features": { 5 | "ghcr.io/devcontainers/features/github-cli:1": {} 6 | }, 7 | "customizations": { 8 | "vscode": { 9 | "extensions": [ 10 | "ms-python.python", 11 | "visualstudioexptteam.vscodeintellicode", 12 | // yaml 13 | "redhat.vscode-yaml", 14 | // editorconfig 15 | "editorconfig.editorconfig", 16 | // protobuf 17 | "pbkit.vscode-pbkit" 18 | ], 19 | "settings": { 20 | "python.languageServer": "Pylance", 21 | "python.pythonPath": "/usr/bin/python3", 22 | "python.formatting.provider": "black", 23 | "editor.formatOnPaste": false, 24 | "editor.formatOnSave": true, 25 | "editor.formatOnType": true, 26 | "files.trimTrailingWhitespace": true, 27 | "terminal.integrated.defaultProfile.linux": "bash", 28 | "files.exclude": { 29 | "**/.git": true, 30 | "**/.DS_Store": true, 31 | "**/*.pyc": { 32 | "when": "$(basename).py" 33 | }, 34 | "**/__pycache__": true 35 | }, 36 | "files.associations": { 37 | "**/.vscode/*.json": "jsonc" 38 | } 39 | } 40 | } 41 | }, 42 | "postCreateCommand": "pip3 install -e ." 
43 | } 44 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Hide sublime text stuff 2 | *.sublime-project 3 | *.sublime-workspace 4 | 5 | # Hide some OS X stuff 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | Icon 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # IntelliJ IDEA 15 | .idea 16 | *.iml 17 | 18 | # pytest 19 | .pytest_cache 20 | .cache 21 | 22 | # GITHUB Proposed Python stuff: 23 | *.py[cod] 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Packages 29 | *.egg 30 | *.egg-info 31 | dist 32 | build 33 | eggs 34 | .eggs 35 | parts 36 | bin 37 | var 38 | sdist 39 | develop-eggs 40 | .installed.cfg 41 | lib 42 | lib64 43 | 44 | # Logs 45 | *.log 46 | pip-log.txt 47 | 48 | # Unit test / coverage reports 49 | .coverage 50 | .tox 51 | nosetests.xml 52 | htmlcov/ 53 | 54 | # Translations 55 | *.mo 56 | 57 | # Mr Developer 58 | .mr.developer.cfg 59 | .project 60 | .pydevproject 61 | 62 | .python-version 63 | 64 | # emacs auto backups 65 | *~ 66 | *# 67 | *.orig 68 | 69 | # venv stuff 70 | pyvenv.cfg 71 | pip-selfcheck.json 72 | venv 73 | .venv 74 | Pipfile* 75 | share/* 76 | 77 | # vimmy stuff 78 | *.swp 79 | *.swo 80 | 81 | ctags.tmp 82 | 83 | # Visual Studio Code 84 | .vscode 85 | 86 | # Built docs 87 | docs/build 88 | 89 | # Windows Explorer 90 | desktop.ini 91 | /.vs/* 92 | 93 | # mypy 94 | /.mypy_cache/* 95 | 96 | .git* 97 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this implement/fix? 
2 | 3 | 4 | 5 | ## Types of changes 6 | 7 | 14 | 15 | - [ ] Bugfix (non-breaking change which fixes an issue) 16 | - [ ] New feature (non-breaking change which adds functionality) 17 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 18 | - [ ] Code quality improvements to existing code or addition of tests 19 | - [ ] Other 20 | 21 | **Related issue or feature (if applicable):** 22 | 23 | - fixes 24 | 25 | **Pull request in [esphome](https://github.com/esphome/esphome) (if applicable):** 26 | 27 | - esphome/esphome# 28 | 29 | ## Checklist: 30 | - [ ] The code change is tested and works locally. 31 | - [ ] If api.proto was modified, a linked pull request has been made to [esphome](https://github.com/esphome/esphome) with the same changes. 32 | - [ ] Tests have been added to verify that the new code works (under `tests/` folder). 33 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # dependabot config 2 | version: 2 3 | updates: 4 | - package-ecosystem: "pip" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | dependencies: 2 | - changed-files: 3 | - any-glob-to-any-file: 4 | - '.pre-commit-config.yaml' 5 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | categories: 4 | - title: "Breaking Changes" 5 | label: "breaking-change" 6 | - title: "Dependencies" 7 | 
collapse-after: 1 8 | labels: 9 | - "dependencies" 10 | 11 | version-resolver: 12 | major: 13 | labels: 14 | - "major" 15 | - "breaking-change" 16 | minor: 17 | labels: 18 | - "minor" 19 | - "new-feature" 20 | patch: 21 | labels: 22 | - "bugfix" 23 | - "dependencies" 24 | - "documentation" 25 | - "enhancement" 26 | default: patch 27 | 28 | template: | 29 | ## What's Changed 30 | 31 | $CHANGES 32 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | permissions: 9 | contents: read 10 | 11 | concurrency: 12 | # yamllint disable-line rule:line-length 13 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | ci: 18 | name: ${{ matrix.name }} py ${{ matrix.python-version }} on ${{ matrix.os }} (${{ matrix.extension }}) 19 | runs-on: ${{ matrix.os }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | python-version: 24 | - "3.9" 25 | - "3.10" 26 | - "3.11" 27 | - "3.12" 28 | - "3.13" 29 | os: 30 | - ubuntu-latest 31 | - windows-latest 32 | - macos-latest 33 | extension: 34 | - "skip_cython" 35 | - "use_cython" 36 | exclude: 37 | - python-version: "3.9" 38 | os: windows-latest 39 | - python-version: "3.10" 40 | os: windows-latest 41 | - python-version: "3.11" 42 | os: windows-latest 43 | - python-version: "3.13" 44 | os: windows-latest 45 | - python-version: "3.9" 46 | os: macos-latest 47 | - python-version: "3.10" 48 | os: macos-latest 49 | - python-version: "3.11" 50 | os: macos-latest 51 | - python-version: "3.13" 52 | os: macos-latest 53 | - extension: "use_cython" 54 | os: windows-latest 55 | - extension: "use_cython" 56 | os: macos-latest 57 | steps: 58 | - uses: actions/checkout@v4 59 | - name: Set up Python 60 | uses: actions/setup-python@v5 61 | id: python 62 |
with: 63 | python-version: ${{ matrix.python-version }} 64 | allow-prereleases: true 65 | 66 | - name: Get pip cache dir 67 | id: pip-cache 68 | shell: bash 69 | run: | 70 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 71 | - name: Restore PIP cache 72 | uses: actions/cache@v4 73 | with: 74 | path: ${{ steps.pip-cache.outputs.dir }} 75 | key: pip-${{ steps.python.outputs.python-version }}-${{ matrix.extension }}-${{ hashFiles('requirements/base.txt', 'requirements/test.txt') }} 76 | restore-keys: | 77 | pip-${{ steps.python.outputs.python-version }}-${{ matrix.extension }}- 78 | - name: Set up Python environment (no cython) 79 | if: ${{ matrix.extension == 'skip_cython' }} 80 | env: 81 | SKIP_CYTHON: 1 82 | shell: bash 83 | run: | 84 | pip3 install -r requirements/base.txt -r requirements/test.txt 85 | pip3 install -e . 86 | - name: Set up Python environment (cython) 87 | if: ${{ matrix.extension == 'use_cython' }} 88 | env: 89 | REQUIRE_CYTHON: 1 90 | shell: bash 91 | run: | 92 | pip3 install -r requirements/base.txt -r requirements/test.txt 93 | pip3 install -e . 
94 | - name: Register problem matchers 95 | shell: bash 96 | run: | 97 | echo "::add-matcher::.github/workflows/matchers/flake8.json" 98 | echo "::add-matcher::.github/workflows/matchers/pylint.json" 99 | echo "::add-matcher::.github/workflows/matchers/mypy.json" 100 | echo "::add-matcher::.github/workflows/matchers/pytest.json" 101 | 102 | - run: flake8 aioesphomeapi 103 | name: Lint with flake8 104 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 105 | - run: ruff format --check aioesphomeapi tests 106 | name: Check ruff formatting 107 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 108 | - run: ruff check aioesphomeapi tests 109 | name: Check with ruff 110 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 111 | - run: mypy aioesphomeapi 112 | name: Check typing with mypy 113 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 114 | - run: | 115 | docker run \ 116 | -v "$PWD":/aioesphomeapi \ 117 | -u "$(id -u):$(id -g)" \ 118 | ghcr.io/esphome/aioesphomeapi-proto-builder:latest 119 | if ! git diff --quiet; then 120 | echo "You have altered the generated proto files but they do not match what is expected." 
121 | echo "Please run the following to update the generated files:" 122 | echo 'docker run -v "$PWD":/aioesphomeapi ghcr.io/esphome/aioesphomeapi-proto-builder:latest' 123 | exit 1 124 | fi 125 | name: Check protobuf files match 126 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 127 | - name: Show changes 128 | run: git diff 129 | if: ${{ failure() && matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 130 | - name: Archive artifacts 131 | uses: actions/upload-artifact@v4 132 | with: 133 | name: generated-proto-files 134 | path: aioesphomeapi/*pb2.py 135 | if: ${{ failure() && matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }} 136 | - run: pytest -vv --cov=aioesphomeapi --cov-report=xml --timeout=4 --tb=native tests 137 | name: Run tests with pytest 138 | - name: Upload coverage to Codecov 139 | uses: codecov/codecov-action@v5 140 | env: 141 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 142 | 143 | benchmarks: 144 | name: Run benchmarks 145 | runs-on: ubuntu-22.04 146 | steps: 147 | - uses: actions/checkout@v4 148 | - uses: actions/setup-python@v5 149 | with: 150 | python-version: "3.13" 151 | cache: 'pip' # caching pip dependencies 152 | - name: Set up Python environment (cython) 153 | env: 154 | REQUIRE_CYTHON: 1 155 | shell: bash 156 | run: | 157 | pip3 install -r requirements/base.txt -r requirements/test.txt 158 | pip3 install -e .
159 | - name: Run benchmarks 160 | uses: CodSpeedHQ/action@v3 161 | with: 162 | token: ${{ secrets.CODSPEED_TOKEN }} 163 | run: pytest tests/benchmarks --codspeed 164 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Build docker image on changes 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - Dockerfile 9 | - requirements/test.txt 10 | - .github/workflows/docker.yml 11 | 12 | permissions: 13 | contents: read 14 | packages: write 15 | 16 | concurrency: 17 | # yamllint disable-line rule:line-length 18 | group: ${{ github.workflow }}-${{ github.ref }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | build-image: 23 | runs-on: ubuntu-latest 24 | name: Build and push Docker image 25 | steps: 26 | - 27 | name: Checkout source code 28 | uses: actions/checkout@v4 29 | - 30 | name: Log in to docker hub 31 | uses: docker/login-action@v3.4.0 32 | with: 33 | username: ${{ secrets.DOCKER_USER }} 34 | password: ${{ secrets.DOCKER_PASSWORD }} 35 | - 36 | name: Log in to the GitHub container registry 37 | uses: docker/login-action@v3.4.0 38 | with: 39 | registry: ghcr.io 40 | username: ${{ github.actor }} 41 | password: ${{ secrets.GITHUB_TOKEN }} 42 | - 43 | name: Set up QEMU 44 | uses: docker/setup-qemu-action@v3.4.0 45 | - 46 | name: Set up Docker Buildx 47 | uses: docker/setup-buildx-action@v3.10.0 48 | - 49 | name: Build and Push 50 | uses: docker/build-push-action@v6.18.0 51 | with: 52 | context: . 
53 | tags: | 54 | ghcr.io/esphome/aioesphomeapi-proto-builder:latest 55 | esphome/aioesphomeapi-proto-builder:latest 56 | push: true 57 | pull: true 58 | cache-to: type=inline 59 | cache-from: ghcr.io/esphome/aioesphomeapi-proto-builder:latest 60 | platforms: linux/amd64,linux/arm64 61 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: Add labels 2 | 3 | on: 4 | - pull_request_target 5 | 6 | jobs: 7 | labeler: 8 | permissions: 9 | contents: read 10 | pull-requests: write 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/labeler@v5.0.0 14 | -------------------------------------------------------------------------------- /.github/workflows/matchers/flake8.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "flake8-error", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^(.*):(\\d+):(\\d+):\\s([EF]\\d{3}\\s.*)$", 9 | "file": 1, 10 | "line": 2, 11 | "column": 3, 12 | "message": 4 13 | } 14 | ] 15 | }, 16 | { 17 | "owner": "flake8-warning", 18 | "severity": "warning", 19 | "pattern": [ 20 | { 21 | "regexp": "^(.*):(\\d+):(\\d+):\\s([CDNW]\\d{3}\\s.*)$", 22 | "file": 1, 23 | "line": 2, 24 | "column": 3, 25 | "message": 4 26 | } 27 | ] 28 | } 29 | ] 30 | } 31 | -------------------------------------------------------------------------------- /.github/workflows/matchers/isort.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "isort", 5 | "pattern": [ 6 | { 7 | "regexp": "^ERROR:\\s+(.+)\\s+(.+)$", 8 | "file": 1, 9 | "message": 2 10 | } 11 | ] 12 | } 13 | ] 14 | } 15 | -------------------------------------------------------------------------------- /.github/workflows/matchers/mypy.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "mypy", 5 | "pattern": [ 6 | { 7 | "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$", 8 | "file": 1, 9 | "line": 2, 10 | "severity": 3, 11 | "message": 4 12 | } 13 | ] 14 | } 15 | ] 16 | } 17 | -------------------------------------------------------------------------------- /.github/workflows/matchers/pylint.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "pylint-error", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^(.+):(\\d+):(\\d+):\\s(([EF]\\d{4}):\\s.+)$", 9 | "file": 1, 10 | "line": 2, 11 | "column": 3, 12 | "message": 4, 13 | "code": 5 14 | } 15 | ] 16 | }, 17 | { 18 | "owner": "pylint-warning", 19 | "severity": "warning", 20 | "pattern": [ 21 | { 22 | "regexp": "^(.+):(\\d+):(\\d+):\\s(([CRW]\\d{4}):\\s.+)$", 23 | "file": 1, 24 | "line": 2, 25 | "column": 3, 26 | "message": 4, 27 | "code": 5 28 | } 29 | ] 30 | } 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /.github/workflows/matchers/pytest.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "pytest", 5 | "fileLocation": "absolute", 6 | "pattern": [ 7 | { 8 | "regexp": "^\\s+File \"(.*)\", line (\\d+), in (.*)$", 9 | "file": 1, 10 | "line": 2 11 | }, 12 | { 13 | "regexp": "^\\s+(.*)$", 14 | "message": 1 15 | } 16 | ] 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Release Drafter 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | 9 | permissions: 10 | contents: write 11 | pull-requests: read 12 | 13 | jobs: 14 | update_release_draft: 15 | 
runs-on: ubuntu-latest 16 | steps: 17 | - uses: release-drafter/release-drafter@v6 18 | id: release-draft 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | 22 | # Update version number in setup.py 23 | - uses: actions/checkout@v4 24 | - name: Calculate version 25 | id: version 26 | run: | 27 | tag="${{ steps.release-draft.outputs.tag_name }}" 28 | echo "version=${tag:1}" >> $GITHUB_OUTPUT 29 | 30 | - name: Set setup.py version 31 | run: | 32 | sed -i "s/VERSION = .*/VERSION = \"${{ steps.version.outputs.version }}\"/g" setup.py 33 | # github actions email from here: https://github.community/t/github-actions-bot-email-address/17204 34 | - name: Commit changes 35 | run: | 36 | if ! git diff --quiet; then 37 | git config --global user.name "github-actions[bot]" 38 | git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" 39 | git commit -am "Bump version to ${{ steps.version.outputs.version }}" 40 | git push 41 | fi 42 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | build_wheels: 12 | name: Wheels for ${{ matrix.os }} (${{ matrix.musl == 'musllinux' && 'musllinux' || 'manylinux' }}) ${{ matrix.qemu }} ${{ matrix.pyver }} 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: 17 | [ 18 | ubuntu-24.04-arm, 19 | ubuntu-latest, 20 | macos-13, 21 | macos-latest, 22 | ] 23 | qemu: [""] 24 | musl: [""] 25 | pyver: [""] 26 | include: 27 | - os: ubuntu-latest 28 | musl: "musllinux" 29 | - os: ubuntu-24.04-arm 30 | musl: "musllinux" 31 | # qemu is slow, make a single 32 | # runner per Python version 33 | - os: ubuntu-latest 34 | qemu: armv7l 35 | musl: "musllinux" 36 | pyver: cp39 37 | - os: ubuntu-latest 38 | qemu: armv7l 39 | musl: "musllinux" 
40 | pyver: cp310 41 | - os: ubuntu-latest 42 | qemu: armv7l 43 | musl: "musllinux" 44 | pyver: cp311 45 | - os: ubuntu-latest 46 | qemu: armv7l 47 | musl: "musllinux" 48 | pyver: cp312 49 | - os: ubuntu-latest 50 | qemu: armv7l 51 | musl: "musllinux" 52 | pyver: cp313 53 | # qemu is slow, make a single 54 | # runner per Python version 55 | - os: ubuntu-latest 56 | qemu: armv7l 57 | musl: "" 58 | pyver: cp39 59 | - os: ubuntu-latest 60 | qemu: armv7l 61 | musl: "" 62 | pyver: cp310 63 | - os: ubuntu-latest 64 | qemu: armv7l 65 | musl: "" 66 | pyver: cp311 67 | - os: ubuntu-latest 68 | qemu: armv7l 69 | musl: "" 70 | pyver: cp312 71 | - os: ubuntu-latest 72 | qemu: armv7l 73 | musl: "" 74 | pyver: cp313 75 | steps: 76 | - uses: actions/checkout@v4 77 | with: 78 | fetch-depth: 0 79 | # Used to host cibuildwheel 80 | - name: Set up Python 81 | uses: actions/setup-python@v5 82 | with: 83 | python-version: "3.12" 84 | - name: Set up QEMU 85 | if: ${{ matrix.qemu }} 86 | uses: docker/setup-qemu-action@v3 87 | with: 88 | platforms: all 89 | # This should be temporary 90 | # xref https://github.com/docker/setup-qemu-action/issues/188 91 | # xref https://github.com/tonistiigi/binfmt/issues/215 92 | image: tonistiigi/binfmt:qemu-v8.1.5 93 | id: qemu 94 | - name: Prepare emulation 95 | if: ${{ matrix.qemu }} 96 | run: | 97 | if [[ -n "${{ matrix.qemu }}" ]]; then 98 | # Build emulated architectures only if QEMU is set, 99 | # use default "auto" otherwise 100 | echo "CIBW_ARCHS_LINUX=${{ matrix.qemu }}" >> $GITHUB_ENV 101 | fi 102 | - name: Limit to a specific Python version on slow QEMU 103 | if: ${{ matrix.pyver }} 104 | run: | 105 | if [[ -n "${{ matrix.pyver }}" ]]; then 106 | echo "CIBW_BUILD=${{ matrix.pyver }}*" >> $GITHUB_ENV 107 | fi 108 | - name: Build wheels 109 | uses: pypa/cibuildwheel@v2.23.3 110 | env: 111 | CIBW_SKIP: cp36-* cp37-* cp38-* pp* ${{ matrix.musl == 'musllinux' && '*manylinux*' || '*musllinux*' }} 112 | CIBW_BEFORE_ALL_LINUX: apt-get install -y gcc 
|| yum install -y gcc || apk add gcc 113 | REQUIRE_CYTHON: 1 114 | 115 | - uses: actions/upload-artifact@v4 116 | with: 117 | name: wheels-${{ matrix.os }}-${{ matrix.musl }}-${{ matrix.pyver }}-${{ matrix.qemu }} 118 | path: ./wheelhouse/*.whl 119 | 120 | build_sdist: 121 | name: Build source distribution 122 | runs-on: ubuntu-latest 123 | steps: 124 | - uses: actions/checkout@v4 125 | 126 | - name: Build sdist 127 | run: pipx run build --sdist 128 | 129 | - uses: actions/upload-artifact@v4 130 | with: 131 | name: sdist 132 | path: dist/*.tar.gz 133 | 134 | upload_pypi: 135 | needs: [build_wheels, build_sdist] 136 | runs-on: ubuntu-latest 137 | environment: pypi 138 | permissions: 139 | id-token: write 140 | if: github.event_name == 'release' && github.event.action == 'published' 141 | steps: 142 | - name: Download artifacts 143 | uses: actions/download-artifact@v4 144 | with: 145 | path: dist 146 | merge-multiple: true 147 | 148 | - name: Publish to PyPI 149 | uses: pypa/gh-action-pypi-publish@v1.12.4 150 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Hide sublime text stuff 2 | *.sublime-project 3 | *.sublime-workspace 4 | 5 | # Hide some OS X stuff 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | Icon 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # IntelliJ IDEA 15 | .idea 16 | *.iml 17 | 18 | # pytest 19 | .pytest_cache 20 | .cache 21 | 22 | # GITHUB Proposed Python stuff: 23 | *.py[cod] 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Packages 29 | *.egg 30 | *.egg-info 31 | dist 32 | build 33 | eggs 34 | .eggs 35 | parts 36 | bin 37 | var 38 | sdist 39 | develop-eggs 40 | .installed.cfg 41 | lib 42 | lib64 43 | 44 | # Logs 45 | *.log 46 | pip-log.txt 47 | 48 | # Unit test / coverage reports 49 | .coverage 50 | .tox 51 | nosetests.xml 52 | htmlcov/ 53 | 54 | # Translations 55 | *.mo 56 | 57 | # Mr Developer 58 | .mr.developer.cfg 
59 | .project 60 | .pydevproject 61 | 62 | .python-version 63 | 64 | # emacs auto backups 65 | *~ 66 | *# 67 | *.orig 68 | 69 | # venv stuff 70 | pyvenv.cfg 71 | pip-selfcheck.json 72 | venv 73 | .venv 74 | Pipfile* 75 | share/* 76 | 77 | # vimmy stuff 78 | *.swp 79 | *.swo 80 | 81 | ctags.tmp 82 | 83 | # Visual Studio Code 84 | .vscode/ 85 | !.vscode/tasks.json 86 | 87 | # Built docs 88 | docs/build 89 | 90 | # Windows Explorer 91 | desktop.ini 92 | /.vs/* 93 | 94 | # mypy 95 | /.mypy_cache/* 96 | 97 | # cython generated source 98 | *.c 99 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | exclude: '^aioesphomeapi/api.*$' 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-added-large-files 10 | - repo: https://github.com/asottile/pyupgrade 11 | rev: v3.20.0 12 | hooks: 13 | - id: pyupgrade 14 | args: [--py39-plus] 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | rev: v0.11.13 17 | hooks: 18 | - id: ruff 19 | args: [--fix] 20 | - id: ruff-format 21 | - repo: https://github.com/cdce8p/python-typing-update 22 | rev: v0.7.2 23 | hooks: 24 | - id: python-typing-update 25 | stages: [manual] 26 | args: 27 | - --py39-plus 28 | - --force 29 | - --keep-updates 30 | files: ^(aioesphomeapi)/.+\.py$ 31 | - repo: https://github.com/MarcoGorelli/cython-lint 32 | rev: v0.16.6 33 | hooks: 34 | - id: cython-lint 35 | - id: double-quote-cython-strings 36 | - repo: https://github.com/pre-commit/mirrors-mypy 37 | rev: v1.16.0 38 | hooks: 39 | - id: mypy 40 | additional_dependencies: ["aiohappyeyeballs>=2.3.0", "noiseprotocol>=0.3.1,<1.0", "cryptography>=43.0.0", "zeroconf>=0.143.0,<1.0"] 41 | files: ^((aioesphomeapi)/.+)?[^/]+\.(py)$ 42 | 
-------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "label": "Generate proto files", 6 | "type": "shell", 7 | "command": "script/gen-protoc", 8 | "group": { 9 | "kind": "build", 10 | "isDefault": true 11 | }, 12 | "presentation": { 13 | "reveal": "never", 14 | "close": true, 15 | "panel": "new" 16 | }, 17 | "problemMatcher": [] 18 | } 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | ENV PROTOC_VERSION=25.5 4 | 5 | ARG TARGETARCH 6 | 7 | RUN if [ "$TARGETARCH" = "amd64" ]; then \ 8 | arch="x86_64"; \ 9 | elif [ "$TARGETARCH" = "arm64" ]; then \ 10 | arch="aarch_64"; \ 11 | else \ 12 | exit 1; \ 13 | fi \ 14 | && curl -o /tmp/protoc.zip -L \ 15 | https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${arch}.zip \ 16 | && unzip /tmp/protoc.zip -d /usr -x readme.txt \ 17 | && rm /tmp/protoc.zip 18 | 19 | RUN useradd -ms /bin/bash esphome 20 | 21 | USER esphome 22 | 23 | WORKDIR /aioesphomeapi 24 | 25 | COPY requirements/test.txt ./requirements_test.txt 26 | 27 | RUN pip3 install -r requirements_test.txt 28 | 29 | CMD ["script/gen-protoc"] 30 | 31 | LABEL \ 32 | org.opencontainers.image.title="aioesphomeapi protobuf generator" \ 33 | org.opencontainers.image.description="An image to help with ESPHome's aioesphomeapi protobuf generation" \ 34 | org.opencontainers.image.vendor="ESPHome" \ 35 | org.opencontainers.image.licenses="MIT" \ 36 | org.opencontainers.image.url="https://esphome.io" \ 37 | org.opencontainers.image.source="https://github.com/esphome/aioesphomeapi" \ 38 |
org.opencontainers.image.documentation="https://github.com/esphome/aioesphomeapi/blob/main/README.rst" 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Otto Winter 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # Maintainer Notes 2 | 3 | Releases are automatically drafted by [release-drafter](https://github.com/release-drafter/release-drafter); the next version number is computed automatically from the labels of the PRs included in the release.
4 | 5 | See also .github/release-drafter.yml: if any PR in the release carries one of the labels below, the version is bumped accordingly: 6 | 7 | - major release (+1.0.0): breaking-change, major 8 | - minor release (+0.1.0): minor, new-feature 9 | - patch (+0.0.1): this is the default release type 10 | 11 | Before creating a release, check that the latest commit passes continuous integration. 12 | 13 | When the release button on the draft is clicked, GitHub Actions will publish the release to PyPI. 14 | 15 | On every push to the main branch and on every pull request, the CI workflow regenerates the python protobuf files with the `aioesphomeapi-proto-builder` Docker image (which pins the protoc version) and fails if they differ from the committed files. This is to ensure that if a contributor has a newer protoc version installed than the protobuf python package supports, we won't run into any issues. 16 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include requirements/base.txt 3 | include aioesphomeapi/py.typed 4 | global-exclude *.c 5 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | aioesphomeapi 2 | ============= 3 | 4 | .. image:: https://github.com/esphome/aioesphomeapi/workflows/CI/badge.svg 5 | :target: https://github.com/esphome/aioesphomeapi/actions?query=workflow%3ACI+branch%3Amain 6 | 7 | .. image:: https://img.shields.io/pypi/v/aioesphomeapi.svg 8 | :target: https://pypi.python.org/pypi/aioesphomeapi 9 | 10 | .. image:: https://codecov.io/gh/esphome/aioesphomeapi/branch/main/graph/badge.svg 11 | :target: https://app.codecov.io/gh/esphome/aioesphomeapi/tree/main 12 | 13 | .. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json 14 | :target: https://codspeed.io/esphome/aioesphomeapi 15 | 16 | ``aioesphomeapi`` allows you to interact with devices flashed with `ESPHome `_.
17 | 18 | Installation 19 | ------------ 20 | 21 | The module is available from the `Python Package Index <https://pypi.python.org/pypi/aioesphomeapi>`_. 22 | 23 | .. code:: bash 24 | 25 | $ pip3 install aioesphomeapi 26 | 27 | An optional Cython extension is available for better performance, and the module will try to build it automatically. 28 | 29 | The extension requires a C compiler and Python development headers. The module will fall back to the pure Python implementation if they are unavailable. 30 | 31 | Building the extension can be forcefully disabled by setting the environment variable ``SKIP_CYTHON`` to ``1``. 32 | 33 | Usage 34 | ----- 35 | 36 | You must enable the `Native API <https://esphome.io/components/api.html>`_ component on the device. 37 | 38 | .. code:: yaml 39 | 40 | # Example configuration entry 41 | api: 42 | password: 'MyPassword' 43 | 44 | Check the output to get the local address of the device, or use the ``name:`` under ``esphome:`` from the device configuration. 45 | 46 | .. code:: bash 47 | 48 | [17:56:38][C][api:095]: API Server: 49 | [17:56:38][C][api:096]: Address: api_test.local:6053 50 | 51 | 52 | The sample code below will connect to the device and retrieve details. 53 | 54 | .. code:: python 55 | 56 | import aioesphomeapi 57 | import asyncio 58 | 59 | async def main(): 60 | """Connect to an ESPHome device and get details.""" 61 | 62 | # Establish connection 63 | api = aioesphomeapi.APIClient("api_test.local", 6053, "MyPassword") 64 | await api.connect(login=True) 65 | 66 | # Get API version of the device's firmware 67 | print(api.api_version) 68 | 69 | # Show device details 70 | device_info = await api.device_info() 71 | print(device_info) 72 | 73 | # List all entities of the device 74 | entities = await api.list_entities_services() 75 | print(entities) 76 | 77 | loop = asyncio.get_event_loop() 78 | loop.run_until_complete(main()) 79 | 80 | Subscribe to state changes of an ESPHome device. 81 | 82 | ..
code:: python 83 | 84 | import aioesphomeapi 85 | import asyncio 86 | 87 | async def main(): 88 | """Connect to an ESPHome device and wait for state changes.""" 89 | cli = aioesphomeapi.APIClient("api_test.local", 6053, "MyPassword") 90 | 91 | await cli.connect(login=True) 92 | 93 | def change_callback(state): 94 | """Print the state changes of the device.""" 95 | print(state) 96 | 97 | # Subscribe to the state changes 98 | cli.subscribe_states(change_callback) 99 | 100 | loop = asyncio.get_event_loop() 101 | try: 102 | asyncio.ensure_future(main()) 103 | loop.run_forever() 104 | except KeyboardInterrupt: 105 | pass 106 | finally: 107 | loop.close() 108 | 109 | Other examples: 110 | 111 | - `Camera `_ 112 | - `Async print `_ 113 | - `Simple print `_ 114 | - `InfluxDB `_ 115 | 116 | Development 117 | ----------- 118 | 119 | For development, it is recommended to use a Python virtual environment (``venv``). 120 | 121 | .. code:: bash 122 | 123 | # Setup virtualenv (optional) 124 | $ python3 -m venv . 125 | $ source bin/activate 126 | # Install aioesphomeapi and development dependencies 127 | $ pip3 install -e . 128 | $ pip3 install -r requirements/test.txt 129 | 130 | # Run linters & test 131 | $ script/lint 132 | # Update protobuf _pb2.py definitions (requires a protobuf compiler installation) 133 | $ script/gen-protoc 134 | 135 | A CLI tool is also available for watching logs: 136 | 137 | .. code:: bash 138 | 139 | aioesphomeapi-logs --help 140 | 141 | Another CLI tool is available to discover devices: 142 | 143 | .. code:: bash 144 | 145 | aioesphomeapi-discover --help 146 | 147 | 148 | License 149 | ------- 150 | 151 | ``aioesphomeapi`` is licensed under the MIT License; see LICENSE for details.
152 | -------------------------------------------------------------------------------- /aioesphomeapi/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .api_pb2 import ( # type: ignore[attr-defined] # noqa: F401 3 | BluetoothLERawAdvertisement, 4 | BluetoothLERawAdvertisementsResponse, 5 | ) 6 | from .ble_defs import ESP_CONNECTION_ERROR_DESCRIPTION, BLEConnectionError 7 | from .client import APIClient 8 | from .connection import APIConnection, ConnectionParams 9 | from .core import ( 10 | ESPHOME_GATT_ERRORS, 11 | MESSAGE_TYPE_TO_PROTO, 12 | APIConnectionError, 13 | EncryptionPlaintextAPIError, 14 | BadNameAPIError, 15 | BluetoothConnectionDroppedError, 16 | HandshakeAPIError, 17 | InvalidAuthAPIError, 18 | InvalidEncryptionKeyAPIError, 19 | ProtocolAPIError, 20 | RequiresEncryptionAPIError, 21 | ResolveAPIError, 22 | EncryptionHelloAPIError, 23 | SocketAPIError, 24 | BadMACAddressAPIError, 25 | ) 26 | from .model import * 27 | from .reconnect_logic import ReconnectLogic 28 | from .log_parser import parse_log_message, LogParser 29 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/base.pxd: -------------------------------------------------------------------------------- 1 | 2 | import cython 3 | 4 | from ..connection cimport APIConnection 5 | 6 | 7 | cdef bint TYPE_CHECKING 8 | 9 | cdef class APIFrameHelper: 10 | 11 | cdef object _loop 12 | cdef APIConnection _connection 13 | cdef object _transport 14 | cdef public object _writelines 15 | cdef public object ready_future 16 | cdef bytes _buffer 17 | cdef unsigned int _buffer_len 18 | cdef unsigned int _pos 19 | cdef object _client_info 20 | 
cdef str _log_name 21 | 22 | cpdef set_log_name(self, str log_name) 23 | 24 | @cython.locals( 25 | original_pos="unsigned int", 26 | new_pos="unsigned int", 27 | cstr="const unsigned char *" 28 | ) 29 | cdef bytes _read(self, int length) 30 | 31 | @cython.locals(bytes_data=bytes) 32 | cdef void _add_to_buffer(self, object data) except * 33 | 34 | @cython.locals(end_of_frame_pos="unsigned int", cstr="const unsigned char *") 35 | cdef void _remove_from_buffer(self) except * 36 | 37 | cpdef void write_packets(self, list packets, bint debug_enabled) except * 38 | 39 | cdef void _write_bytes(self, object data, bint debug_enabled) except * 40 | 41 | cdef void _handle_error_and_close(self, Exception exc) except * 42 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import abstractmethod 4 | import asyncio 5 | from collections.abc import Iterable 6 | import logging 7 | from typing import TYPE_CHECKING, Callable, cast 8 | 9 | from ..core import SocketClosedAPIError 10 | 11 | if TYPE_CHECKING: 12 | from ..connection import APIConnection 13 | 14 | _LOGGER = logging.getLogger(__name__) 15 | 16 | SOCKET_ERRORS = ( 17 | ConnectionResetError, 18 | asyncio.IncompleteReadError, 19 | OSError, 20 | TimeoutError, 21 | ) 22 | 23 | 24 | _int = int 25 | _bytes = bytes 26 | 27 | 28 | class APIFrameHelper: 29 | """Helper class to handle the API frame protocol.""" 30 | 31 | __slots__ = ( 32 | "_buffer", 33 | "_buffer_len", 34 | "_client_info", 35 | "_connection", 36 | "_log_name", 37 | "_loop", 38 | "_pos", 39 | "_transport", 40 | "_writelines", 41 | "ready_future", 42 | ) 43 | 44 | def __init__( 45 | self, 46 | connection: APIConnection, 47 | client_info: str, 48 | log_name: str, 49 | ) -> None: 50 | """Initialize the API frame helper.""" 51 | loop = asyncio.get_running_loop() 52 | 
self._loop = loop 53 | self._connection = connection 54 | self._transport: asyncio.Transport | None = None 55 | self._writelines: ( 56 | None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None]) 57 | ) = None 58 | self.ready_future = self._loop.create_future() 59 | self._buffer: bytes | None = None 60 | self._buffer_len = 0 61 | self._pos = 0 62 | self._client_info = client_info 63 | self._log_name = log_name 64 | 65 | def set_log_name(self, log_name: str) -> None: 66 | """Set the log name.""" 67 | self._log_name = log_name 68 | 69 | def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None: 70 | if not self.ready_future.done(): 71 | self.ready_future.set_exception(exc) 72 | 73 | def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None: 74 | """Add data to the buffer.""" 75 | # Protractor sends a bytearray, so we need to convert it to bytes 76 | # https://github.com/esphome/issues/issues/5117 77 | # type(data) should not be isinstance(data, bytes) because we want to 78 | # to explicitly check for bytes and not for subclasses of bytes 79 | bytes_data = bytes(data) if type(data) is not bytes else data 80 | if self._buffer_len == 0: 81 | # This is the best case scenario, we don't have to copy the data 82 | # and can just use the buffer directly. This is the most common 83 | # case as well. 84 | self._buffer = bytes_data 85 | else: 86 | if TYPE_CHECKING: 87 | assert self._buffer is not None, "Buffer should be set" 88 | # This is the worst case scenario, we have to copy the bytes_data 89 | # and can't just use the buffer directly. This is also very 90 | # uncommon since we usually read the entire frame at once. 
91 | self._buffer += bytes_data 92 | self._buffer_len += len(bytes_data) 93 | 94 | def _remove_from_buffer(self) -> None: 95 | """Remove data from the buffer.""" 96 | end_of_frame_pos = self._pos 97 | self._buffer_len -= end_of_frame_pos 98 | if self._buffer_len == 0: 99 | # This is the best case scenario, we can just set the buffer to None 100 | # and don't have to copy the data. This is the most common case as well. 101 | self._buffer = None 102 | return 103 | if TYPE_CHECKING: 104 | assert self._buffer is not None, "Buffer should be set" 105 | # This is the worst case scenario, we have to copy the data 106 | # and can't just use the buffer directly. This should only happen 107 | # when we read multiple frames at once because the event loop 108 | # is blocked and we cannot pull the data out of the buffer fast enough. 109 | cstr = self._buffer 110 | # Important: we must use the explicit length for the slice 111 | # since Cython will stop at any '\0' character if we don't 112 | self._buffer = cstr[end_of_frame_pos : self._buffer_len + end_of_frame_pos] 113 | 114 | def _read(self, length: _int) -> bytes | None: 115 | """Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" 116 | new_pos = self._pos + length 117 | if self._buffer_len < new_pos: 118 | return None 119 | original_pos = self._pos 120 | self._pos = new_pos 121 | if TYPE_CHECKING: 122 | assert self._buffer is not None, "Buffer should be set" 123 | cstr = self._buffer 124 | # Important: we must keep the bounds check (self._buffer_len < new_pos) 125 | # above to verify we never try to read past the end of the buffer 126 | return cstr[original_pos:new_pos] 127 | 128 | @abstractmethod 129 | def write_packets( 130 | self, packets: list[tuple[int, bytes]], debug_enabled: bool 131 | ) -> None: 132 | """Write a packets to the socket. 
133 | 134 | Packets are in the format of tuple[protobuf_type, protobuf_data] 135 | """ 136 | 137 | def connection_made(self, transport: asyncio.BaseTransport) -> None: 138 | """Handle a new connection.""" 139 | self._transport = cast(asyncio.Transport, transport) 140 | self._writelines = self._transport.writelines 141 | 142 | def _handle_error_and_close(self, exc: Exception) -> None: 143 | """Handle an error and close the connection. 144 | 145 | May not be overridden by subclasses. 146 | """ 147 | self._handle_error(exc) 148 | self.close() 149 | 150 | def _handle_error(self, exc: Exception) -> None: 151 | """Handle an error. 152 | 153 | May be overridden by subclasses. 154 | """ 155 | self._set_ready_future_exception(exc) 156 | self._connection.report_fatal_error(exc) 157 | 158 | def connection_lost(self, exc: Exception | None) -> None: 159 | """Handle the connection being lost.""" 160 | self._handle_error( 161 | exc or SocketClosedAPIError(f"{self._log_name}: Connection lost") 162 | ) 163 | 164 | def eof_received(self) -> bool | None: 165 | """Handle EOF received.""" 166 | self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received")) 167 | return False 168 | 169 | def close(self) -> None: 170 | """Close the connection.""" 171 | if self._transport: 172 | self._transport.close() 173 | self._transport = None 174 | self._writelines = None 175 | 176 | def pause_writing(self) -> None: 177 | """Stub.""" 178 | 179 | def resume_writing(self) -> None: 180 | """Stub.""" 181 | 182 | def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None: 183 | """Write bytes to the socket.""" 184 | if debug_enabled: 185 | _LOGGER.debug( 186 | "%s: Sending frame: [%s]", self._log_name, b"".join(data).hex() 187 | ) 188 | 189 | if TYPE_CHECKING: 190 | assert self._writelines is not None, "Writer is not set" 191 | 192 | self._writelines(data) 193 | -------------------------------------------------------------------------------- 
/aioesphomeapi/_frame_helper/noise.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | 3 | from ..connection cimport APIConnection 4 | from .base cimport APIFrameHelper 5 | from .noise_encryption cimport EncryptCipher, DecryptCipher 6 | from .packets cimport make_noise_packets 7 | 8 | cdef bint TYPE_CHECKING 9 | 10 | cdef unsigned int NOISE_STATE_HELLO 11 | cdef unsigned int NOISE_STATE_HANDSHAKE 12 | cdef unsigned int NOISE_STATE_READY 13 | cdef unsigned int NOISE_STATE_CLOSED 14 | 15 | cdef bytes NOISE_HELLO 16 | cdef object InvalidTag 17 | cdef object ESPHOME_NOISE_BACKEND 18 | 19 | cdef class APINoiseFrameHelper(APIFrameHelper): 20 | 21 | cdef object _noise_psk 22 | cdef str _expected_name 23 | cdef str _expected_mac 24 | cdef unsigned int _state 25 | cdef str _server_mac 26 | cdef str _server_name 27 | cdef object _proto 28 | cdef EncryptCipher _encrypt_cipher 29 | cdef DecryptCipher _decrypt_cipher 30 | 31 | @cython.locals( 32 | header=bytes, 33 | preamble="unsigned char", 34 | header="const unsigned char *" 35 | ) 36 | cpdef void data_received(self, object data) except * 37 | 38 | @cython.locals( 39 | msg=bytes, 40 | msg_type="unsigned int", 41 | payload=bytes, 42 | msg_length=Py_ssize_t, 43 | msg_cstr="const unsigned char *", 44 | ) 45 | cdef void _handle_frame(self, bytes frame) except * 46 | 47 | @cython.locals( 48 | chosen_proto=char, 49 | server_name_i=int, 50 | mac_address_i=int, 51 | mac_address=str, 52 | server_name=str, 53 | ) 54 | cdef void _handle_hello(self, bytes server_hello) except * 55 | 56 | cdef void _handle_handshake(self, bytes msg) except * 57 | 58 | cdef void _handle_closed(self, bytes frame) except * 59 | 60 | @cython.locals(handshake_frame=bytearray, frame_len="unsigned int") 61 | cdef void _send_hello_handshake(self) except * 62 | 63 | cdef void _setup_proto(self) except * 64 | 65 | @cython.locals(psk_bytes=bytes) 66 | cdef _decode_noise_psk(self) 67 | 68 | cpdef void 
write_packets(self, list packets, bint debug_enabled) except * 69 | 70 | cdef _error_on_incorrect_preamble(self, bytes msg) 71 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/noise.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import binascii 5 | import logging 6 | from typing import TYPE_CHECKING 7 | 8 | from cryptography.exceptions import InvalidTag 9 | from noise.connection import NoiseConnection 10 | 11 | from ..core import ( 12 | APIConnectionError, 13 | BadMACAddressAPIError, 14 | BadNameAPIError, 15 | EncryptionErrorAPIError, 16 | EncryptionHelloAPIError, 17 | EncryptionPlaintextAPIError, 18 | HandshakeAPIError, 19 | InvalidEncryptionKeyAPIError, 20 | ProtocolAPIError, 21 | ) 22 | from .base import _LOGGER, APIFrameHelper 23 | from .noise_encryption import ESPHOME_NOISE_BACKEND, DecryptCipher, EncryptCipher 24 | from .packets import make_noise_packets 25 | 26 | if TYPE_CHECKING: 27 | from ..connection import APIConnection 28 | 29 | 30 | # This is effectively an enum but we don't want to use an enum 31 | # because we have a simple dispatch in the data_received method 32 | # that would be more complicated with an enum and we want to add 33 | # cdefs for each different state so we have a good test for each 34 | # state receiving data since we found that the protractor event 35 | # loop will send use a bytearray instead of bytes was not handled 36 | # correctly. 
37 | NOISE_STATE_HELLO = 1 38 | NOISE_STATE_HANDSHAKE = 2 39 | NOISE_STATE_READY = 3 40 | NOISE_STATE_CLOSED = 4 41 | 42 | 43 | NOISE_HELLO = b"\x01\x00\x00" 44 | 45 | int_ = int 46 | 47 | 48 | class APINoiseFrameHelper(APIFrameHelper): 49 | """Frame helper for noise encrypted connections.""" 50 | 51 | __slots__ = ( 52 | "_decrypt_cipher", 53 | "_encrypt_cipher", 54 | "_expected_mac", 55 | "_expected_name", 56 | "_noise_psk", 57 | "_proto", 58 | "_server_mac", 59 | "_server_name", 60 | "_state", 61 | ) 62 | 63 | def __init__( 64 | self, 65 | connection: APIConnection, 66 | noise_psk: str, 67 | expected_name: str | None, 68 | expected_mac: str | None, 69 | client_info: str, 70 | log_name: str, 71 | ) -> None: 72 | """Initialize the API frame helper.""" 73 | super().__init__(connection, client_info, log_name) 74 | self._noise_psk = noise_psk 75 | self._expected_mac = expected_mac 76 | self._expected_name = expected_name 77 | self._state = NOISE_STATE_HELLO 78 | self._server_name: str | None = None 79 | self._server_mac: str | None = None 80 | self._encrypt_cipher: EncryptCipher | None = None 81 | self._decrypt_cipher: DecryptCipher | None = None 82 | self._setup_proto() 83 | 84 | def close(self) -> None: 85 | """Close the connection.""" 86 | # Make sure we set the ready event if its not already set 87 | # so that we don't block forever on the ready event if we 88 | # are waiting for the handshake to complete. 
89 | self._set_ready_future_exception( 90 | APIConnectionError(f"{self._log_name}: Connection closed") 91 | ) 92 | self._state = NOISE_STATE_CLOSED 93 | super().close() 94 | 95 | def _handle_error(self, exc: Exception) -> None: 96 | """Handle an error, and provide a good message when during hello.""" 97 | if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError): 98 | original_exc: Exception = exc 99 | exc = EncryptionHelloAPIError( 100 | f"{self._log_name}: The connection dropped immediately after encrypted hello; " 101 | "Try enabling encryption on the device or turning off " 102 | f"encryption on the client ({self._client_info})" 103 | ) 104 | exc.__cause__ = original_exc 105 | super()._handle_error(exc) 106 | 107 | def connection_made(self, transport: asyncio.BaseTransport) -> None: 108 | """Handle a new connection.""" 109 | super().connection_made(transport) 110 | self._send_hello_handshake() 111 | 112 | def data_received(self, data: bytes | bytearray | memoryview) -> None: 113 | self._add_to_buffer(data) 114 | # Message header is 3 bytes 115 | while self._buffer_len >= 3: 116 | if TYPE_CHECKING: 117 | assert self._buffer is not None, "Buffer should be set" 118 | self._pos = 3 119 | header = self._buffer 120 | preamble = header[0] 121 | if preamble != 0x01: 122 | if preamble == 0x00: 123 | self._handle_error_and_close( 124 | EncryptionPlaintextAPIError( 125 | f"{self._log_name}: The device is using plaintext protocol; " 126 | "Try enabling encryption on the device or turning off " 127 | f"encryption on the client ({self._client_info})" 128 | ) 129 | ) 130 | else: 131 | self._handle_error_and_close( 132 | ProtocolAPIError( 133 | f"{self._log_name}: Marker byte invalid: {preamble}" 134 | ) 135 | ) 136 | return 137 | if (frame := self._read((header[1] << 8) | header[2])) is None: 138 | # The complete frame is not yet available, wait for more data 139 | # to arrive before continuing, since callback_packet has not 140 | # been called yet the 
buffer will not be cleared and the next 141 | # call to data_received will continue processing the packet 142 | # at the start of the frame. 143 | return 144 | 145 | # asyncio already runs data_received in a try block 146 | # which will call connection_lost if an exception is raised 147 | if self._state == NOISE_STATE_READY: 148 | self._handle_frame(frame) 149 | elif self._state == NOISE_STATE_HELLO: 150 | self._handle_hello(frame) 151 | elif self._state == NOISE_STATE_HANDSHAKE: 152 | self._handle_handshake(frame) 153 | else: 154 | self._handle_closed(frame) 155 | 156 | self._remove_from_buffer() 157 | 158 | def _send_hello_handshake(self) -> None: 159 | """Send a ClientHello to the server.""" 160 | handshake_frame = self._proto.write_message() 161 | frame_len = len(handshake_frame) + 1 162 | header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) 163 | self._write_bytes( 164 | (NOISE_HELLO, header, b"\x00", handshake_frame), 165 | _LOGGER.isEnabledFor(logging.DEBUG), 166 | ) 167 | 168 | def _handle_hello(self, server_hello: bytes) -> None: 169 | """Perform the handshake with the server.""" 170 | if not server_hello: 171 | self._handle_error_and_close( 172 | HandshakeAPIError(f"{self._log_name}: ServerHello is empty") 173 | ) 174 | return 175 | 176 | # First byte of server hello is the protocol the server chose 177 | # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256) 178 | # exists. 
179 | chosen_proto = server_hello[0] 180 | if chosen_proto != 0x01: 181 | self._handle_error_and_close( 182 | HandshakeAPIError( 183 | f"{self._log_name}: Unknown protocol selected by client {chosen_proto}" 184 | ) 185 | ) 186 | return 187 | 188 | # Check name matches expected name (for noise sessions, this is done 189 | # during hello phase before a connection is set up) 190 | # Server name is encoded as a string followed by a zero byte after the chosen proto byte 191 | server_name_i = server_hello.find(b"\0", 1) 192 | if server_name_i != -1: 193 | # server name found, this extension was added in 2022.2 194 | server_name = server_hello[1:server_name_i].decode() 195 | self._server_name = server_name 196 | 197 | if self._expected_name is not None and self._expected_name != server_name: 198 | self._handle_error_and_close( 199 | BadNameAPIError( 200 | f"{self._log_name}: Server sent a different name '{server_name}'", 201 | server_name, 202 | ) 203 | ) 204 | return 205 | 206 | mac_address_i = server_hello.find(b"\0", server_name_i + 1) 207 | if mac_address_i != -1: 208 | # mac address found, this extension was added in 2025.4 209 | mac_address = server_hello[server_name_i + 1 : mac_address_i].decode() 210 | self._server_mac = mac_address 211 | if self._expected_mac is not None and self._expected_mac != mac_address: 212 | self._handle_error_and_close( 213 | BadMACAddressAPIError( 214 | f"{self._log_name}: Server sent a different mac '{mac_address}'", 215 | server_name, 216 | mac_address, 217 | ) 218 | ) 219 | return 220 | 221 | self._state = NOISE_STATE_HANDSHAKE 222 | 223 | def _decode_noise_psk(self) -> bytes: 224 | """Decode the given noise psk from base64 format to raw bytes.""" 225 | psk = self._noise_psk 226 | server_name = self._server_name 227 | server_mac = self._server_mac 228 | try: 229 | psk_bytes = binascii.a2b_base64(psk) 230 | except ValueError: 231 | raise InvalidEncryptionKeyAPIError( 232 | f"{self._log_name}: Malformed PSK `{psk}`, expected " 233 | 
"base64-encoded value", 234 | server_name, 235 | server_mac, 236 | ) 237 | if len(psk_bytes) != 32: 238 | raise InvalidEncryptionKeyAPIError( 239 | f"{self._log_name}:Malformed PSK `{psk}`, expected" 240 | f" 32-bytes of base64 data", 241 | server_name, 242 | server_mac, 243 | ) 244 | return psk_bytes 245 | 246 | def _setup_proto(self) -> None: 247 | """Set up the noise protocol.""" 248 | proto = NoiseConnection.from_name( 249 | b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND 250 | ) 251 | proto.set_as_initiator() 252 | proto.set_psks(self._decode_noise_psk()) 253 | proto.set_prologue(b"NoiseAPIInit\x00\x00") 254 | proto.start_handshake() 255 | self._proto = proto 256 | 257 | def _error_on_incorrect_preamble(self, msg: bytes) -> None: 258 | """Handle an incorrect preamble.""" 259 | explanation = msg[1:].decode() 260 | if explanation != "Handshake MAC failure": 261 | exc = HandshakeAPIError( 262 | f"{self._log_name}: Handshake failure: {explanation}" 263 | ) 264 | else: 265 | exc = InvalidEncryptionKeyAPIError( 266 | f"{self._log_name}: Invalid encryption key", 267 | self._server_name, 268 | self._server_mac, 269 | ) 270 | self._handle_error_and_close(exc) 271 | 272 | def _handle_handshake(self, msg: bytes) -> None: 273 | if msg[0] != 0: 274 | self._error_on_incorrect_preamble(msg) 275 | return 276 | self._proto.read_message(msg[1:]) 277 | self._state = NOISE_STATE_READY 278 | noise_protocol = self._proto.noise_protocol 279 | self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member 280 | self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt) # pylint: disable=no-member 281 | self.ready_future.set_result(None) 282 | 283 | def write_packets( 284 | self, packets: list[tuple[int, bytes]], debug_enabled: bool 285 | ) -> None: 286 | """Write a packets to the socket. 
287 | 288 | Packets are in the format of tuple[protobuf_type, protobuf_data] 289 | """ 290 | if TYPE_CHECKING: 291 | assert self._encrypt_cipher is not None, "Handshake should be complete" 292 | self._write_bytes( 293 | make_noise_packets(packets, self._encrypt_cipher), debug_enabled 294 | ) 295 | 296 | def _handle_frame(self, frame: bytes) -> None: 297 | """Handle an incoming frame.""" 298 | if TYPE_CHECKING: 299 | assert self._decrypt_cipher is not None, "Handshake should be complete" 300 | try: 301 | msg = self._decrypt_cipher.decrypt(frame) 302 | except InvalidTag: 303 | # This shouldn't happen since we already checked the tag during handshake 304 | # but it could happen if the server sends a bad frame see 305 | # issue https://github.com/esphome/aioesphomeapi/issues/1044 306 | self._handle_error_and_close( 307 | EncryptionErrorAPIError( 308 | f"{self._log_name}: Encryption error", self._server_name 309 | ) 310 | ) 311 | return 312 | msg_length = len(msg) 313 | msg_cstr = msg 314 | if msg_length < 4: 315 | # Important: we must bound check msg_length to ensure we 316 | # do not read past the end of the message in the payload 317 | # slicing below 318 | self._handle_error_and_close( 319 | ProtocolAPIError( 320 | f"{self._log_name}: Decrypted message too short: {msg_length} bytes" 321 | ) 322 | ) 323 | return 324 | # Message layout is 325 | # 2 bytes: message type (0:type_high, 1:type_low) 326 | # 2 bytes: message length (2:length_high, 3:length_low) 327 | # - We ignore the message length field because we do not 328 | # trust the remote end to send the correct length 329 | # N bytes: message data (4:...) 
330 | msg_type = (msg_cstr[0] << 8) | msg_cstr[1] 331 | # Important: we must explicitly use msg_length here since msg_cstr 332 | # is a cstring and Cython will stop at the first null byte if we 333 | # do not use msg_length 334 | payload = msg_cstr[4:msg_length] 335 | self._connection.process_packet(msg_type, payload) 336 | 337 | def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument 338 | """Handle a closed frame.""" 339 | self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed")) 340 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/noise_encryption.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | from libc.stdint cimport uint64_t 3 | from .pack cimport fast_pack_nonce 4 | 5 | cdef object PACK_NONCE 6 | 7 | cdef class EncryptCipher: 8 | 9 | cdef uint64_t _nonce 10 | cdef object _encrypt 11 | 12 | cpdef bytes encrypt(self, object frame) 13 | 14 | cdef class DecryptCipher: 15 | 16 | cdef uint64_t _nonce 17 | cdef object _decrypt 18 | 19 | cdef bytes decrypt(self, object frame) 20 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/noise_encryption.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | from struct import Struct 5 | from typing import Any 6 | 7 | from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable 8 | from noise.backends.default import DefaultNoiseBackend 9 | from noise.backends.default.ciphers import ChaCha20Cipher, CryptographyCipher 10 | from noise.state import CipherState 11 | 12 | _bytes = bytes 13 | _int = int 14 | 15 | PACK_NONCE = partial(Struct(" type[ChaCha20Poly1305Reusable]: 30 | return ChaCha20Poly1305Reusable # type: ignore[no-any-return, unused-ignore] 31 | 32 | 33 | class 
ESPHomeNoiseBackend(DefaultNoiseBackend): # type: ignore[misc] 34 | def __init__(self, *args: Any, **kwargs: Any) -> None: 35 | super().__init__(*args, **kwargs) 36 | self.ciphers["ChaChaPoly"] = ChaCha20CipherReuseable 37 | 38 | 39 | ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend() 40 | 41 | 42 | class EncryptCipher: 43 | """Wrapper around the ChaCha20Poly1305 cipher for encryption.""" 44 | 45 | __slots__ = ("_encrypt", "_nonce") 46 | 47 | def __init__(self, cipher_state: CipherState) -> None: 48 | """Initialize the cipher wrapper.""" 49 | crypto_cipher: CryptographyCipher = cipher_state.cipher 50 | cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher 51 | self._nonce: _int = cipher_state.n 52 | self._encrypt = cipher.encrypt 53 | 54 | def encrypt(self, data: _bytes) -> bytes: 55 | """Encrypt a frame.""" 56 | ciphertext = self._encrypt(fast_pack_nonce(self._nonce), data, None) 57 | self._nonce += 1 58 | return ciphertext # type: ignore[no-any-return, unused-ignore] 59 | 60 | 61 | class DecryptCipher: 62 | """Wrapper around the ChaCha20Poly1305 cipher for decryption.""" 63 | 64 | __slots__ = ("_decrypt", "_nonce") 65 | 66 | def __init__(self, cipher_state: CipherState) -> None: 67 | """Initialize the cipher wrapper.""" 68 | crypto_cipher: CryptographyCipher = cipher_state.cipher 69 | cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher 70 | self._nonce: _int = cipher_state.n 71 | self._decrypt = cipher.decrypt 72 | 73 | def decrypt(self, data: _bytes) -> bytes: 74 | """Decrypt a frame.""" 75 | plaintext = self._decrypt(fast_pack_nonce(self._nonce), data, None) 76 | self._nonce += 1 77 | return plaintext # type: ignore[no-any-return, unused-ignore] 78 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/pack.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | from libc.stdint cimport uint64_t 3 | 4 | cpdef bytes fast_pack_nonce(uint64_t q) 5 | 
-------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/pack.pyx: -------------------------------------------------------------------------------- 1 | from cpython.bytes cimport PyBytes_FromStringAndSize 2 | from libc.stdint cimport uint64_t 3 | 4 | cpdef bytes fast_pack_nonce(uint64_t q): 5 | cdef: 6 | char buf[12] 7 | char *p = buf 8 | 9 | # First 4 bytes are zero 10 | p[0] = p[1] = p[2] = p[3] = 0 11 | 12 | # q (uint64_t) in little-endian 13 | p[4] = (q & 0xFF) 14 | p[5] = ((q >> 8) & 0xFF) 15 | p[6] = ((q >> 16) & 0xFF) 16 | p[7] = ((q >> 24) & 0xFF) 17 | p[8] = ((q >> 32) & 0xFF) 18 | p[9] = ((q >> 40) & 0xFF) 19 | p[10] = ((q >> 48) & 0xFF) 20 | p[11] = ((q >> 56) & 0xFF) 21 | 22 | return PyBytes_FromStringAndSize(buf, 12) 23 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/packets.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | 3 | from .noise_encryption cimport EncryptCipher 4 | 5 | cdef object varuint_to_bytes 6 | 7 | cpdef _varuint_to_bytes(int value) 8 | 9 | 10 | @cython.locals( 11 | type_="unsigned int", 12 | data=bytes, 13 | packet=tuple, 14 | type_=object 15 | ) 16 | cpdef list make_plain_text_packets(list packets) except * 17 | 18 | 19 | @cython.locals( 20 | type_="unsigned int", 21 | data=bytes, 22 | data_header=bytes, 23 | packet=tuple, 24 | data_len=Py_ssize_t, 25 | frame=bytes, 26 | frame_len=Py_ssize_t, 27 | ) 28 | cpdef list make_noise_packets(list packets, EncryptCipher encrypt_cipher) except * 29 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/packets.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from .noise_encryption import EncryptCipher 4 | 5 | _int = int 6 | 7 | 8 | def _varuint_to_bytes(value: _int) -> bytes: 9 | 
"""Convert a varuint to bytes.""" 10 | if value <= 0x7F: 11 | return bytes((value,)) 12 | 13 | result = bytearray() 14 | while value: 15 | temp = value & 0x7F 16 | value >>= 7 17 | if value: 18 | result.append(temp | 0x80) 19 | else: 20 | result.append(temp) 21 | 22 | return bytes(result) 23 | 24 | 25 | _cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes) 26 | varuint_to_bytes = _cached_varuint_to_bytes 27 | 28 | 29 | def make_plain_text_packets(packets: list[tuple[int, bytes]]) -> list[bytes]: 30 | """Make a list of plain text packet.""" 31 | out: list[bytes] = [] 32 | for packet in packets: 33 | type_: int = packet[0] 34 | data: bytes = packet[1] 35 | out.append(b"\0") 36 | out.append(varuint_to_bytes(len(data))) 37 | out.append(varuint_to_bytes(type_)) 38 | if data: 39 | out.append(data) 40 | return out 41 | 42 | 43 | def make_noise_packets( 44 | packets: list[tuple[int, bytes]], encrypt_cipher: EncryptCipher 45 | ) -> list[bytes]: 46 | """Make a list of noise packet.""" 47 | out: list[bytes] = [] 48 | for packet in packets: 49 | type_: int = packet[0] 50 | data: bytes = packet[1] 51 | data_len = len(data) 52 | data_header = bytes( 53 | ( 54 | (type_ >> 8) & 0xFF, 55 | type_ & 0xFF, 56 | (data_len >> 8) & 0xFF, 57 | data_len & 0xFF, 58 | ) 59 | ) 60 | frame = encrypt_cipher.encrypt(data_header + data) 61 | frame_len = len(frame) 62 | header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) 63 | out.append(header) 64 | out.append(frame) 65 | return out 66 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/plain_text.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | 3 | from ..connection cimport APIConnection 4 | from .base cimport APIFrameHelper 5 | from .packets cimport make_plain_text_packets 6 | 7 | cdef bint TYPE_CHECKING 8 | 9 | 10 | cdef class APIPlaintextFrameHelper(APIFrameHelper): 11 | 12 | cpdef void 
data_received(self, object data) except * 13 | 14 | cdef void _error_on_incorrect_preamble(self, int preamble) except * 15 | 16 | @cython.locals( 17 | result="unsigned int", 18 | bitpos="unsigned int", 19 | cstr="const unsigned char *", 20 | val="unsigned char", 21 | current_pos="unsigned int" 22 | ) 23 | cdef int _read_varuint(self) 24 | 25 | cpdef void write_packets(self, list packets, bint debug_enabled) except * 26 | -------------------------------------------------------------------------------- /aioesphomeapi/_frame_helper/plain_text.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from typing import TYPE_CHECKING 5 | 6 | from ..core import ProtocolAPIError, RequiresEncryptionAPIError 7 | from .base import APIFrameHelper 8 | from .packets import make_plain_text_packets 9 | 10 | _int = int 11 | 12 | 13 | class APIPlaintextFrameHelper(APIFrameHelper): 14 | """Frame helper for plaintext API connections.""" 15 | 16 | def connection_made(self, transport: asyncio.BaseTransport) -> None: 17 | """Handle a new connection.""" 18 | super().connection_made(transport) 19 | self.ready_future.set_result(None) 20 | 21 | def write_packets( 22 | self, packets: list[tuple[int, bytes]], debug_enabled: bool 23 | ) -> None: 24 | """Write a packets to the socket. 25 | 26 | Packets are in the format of tuple[protobuf_type, protobuf_data] 27 | 28 | The entire packet must be written in a single call. 
29 | """ 30 | self._write_bytes(make_plain_text_packets(packets), debug_enabled) 31 | 32 | def _read_varuint(self) -> _int: 33 | """Read a varuint from the buffer or -1 if the buffer runs out of bytes.""" 34 | if TYPE_CHECKING: 35 | assert self._buffer is not None, "Buffer should be set" 36 | result = 0 37 | bitpos = 0 38 | cstr = self._buffer 39 | while self._buffer_len > self._pos: 40 | val = cstr[self._pos] 41 | self._pos += 1 42 | result |= (val & 0x7F) << bitpos 43 | if (val & 0x80) == 0: 44 | return result 45 | bitpos += 7 46 | return -1 47 | 48 | def data_received(self, data: bytes | bytearray | memoryview) -> None: 49 | self._add_to_buffer(data) 50 | # Message header is at least 3 bytes, empty length allowed 51 | while self._buffer_len >= 3: 52 | self._pos = 0 53 | # Read preamble, which should always 0x00 54 | if (preamble := self._read_varuint()) != 0x00: 55 | self._error_on_incorrect_preamble(preamble) 56 | return 57 | if (length := self._read_varuint()) == -1: 58 | return 59 | if (msg_type := self._read_varuint()) == -1: 60 | return 61 | 62 | if length == 0: 63 | self._remove_from_buffer() 64 | self._connection.process_packet(msg_type, b"") 65 | continue 66 | 67 | # The packet data is not yet available, wait for more data 68 | # to arrive before continuing, since callback_packet has not 69 | # been called yet the buffer will not be cleared and the next 70 | # call to data_received will continue processing the packet 71 | # at the start of the frame. 
72 | if (packet_data := self._read(length)) is None: 73 | return 74 | self._remove_from_buffer() 75 | self._connection.process_packet(msg_type, packet_data) 76 | # If we have more data, continue processing 77 | 78 | def _error_on_incorrect_preamble(self, preamble: _int) -> None: 79 | """Handle an incorrect preamble.""" 80 | if preamble == 0x01: 81 | self._handle_error_and_close( 82 | RequiresEncryptionAPIError( 83 | f"{self._log_name}: Connection requires encryption" 84 | ) 85 | ) 86 | return 87 | self._handle_error_and_close( 88 | ProtocolAPIError(f"{self._log_name}: Invalid preamble {preamble:02x}") 89 | ) 90 | -------------------------------------------------------------------------------- /aioesphomeapi/api_options.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | import "google/protobuf/descriptor.proto"; 3 | 4 | 5 | enum APISourceType { 6 | SOURCE_BOTH = 0; 7 | SOURCE_SERVER = 1; 8 | SOURCE_CLIENT = 2; 9 | } 10 | 11 | message void {} 12 | 13 | extend google.protobuf.MethodOptions { 14 | optional bool needs_setup_connection = 1038 [default=true]; 15 | optional bool needs_authentication = 1039 [default=true]; 16 | } 17 | 18 | extend google.protobuf.MessageOptions { 19 | optional uint32 id = 1036 [default=0]; 20 | optional APISourceType source = 1037 [default=SOURCE_BOTH]; 21 | optional string ifdef = 1038; 22 | optional bool log = 1039 [default=true]; 23 | optional bool no_delay = 1040 [default=false]; 24 | } 25 | -------------------------------------------------------------------------------- /aioesphomeapi/api_options_pb2.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | # -*- coding: utf-8 -*- 3 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
4 | # source: api_options.proto 5 | # Protobuf Python Version: 4.25.5 6 | """Generated protocol buffer code.""" 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import descriptor_pool as _descriptor_pool 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf.internal import builder as _builder 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse') 20 | 21 | _globals = globals() 22 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 23 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'api_options_pb2', _globals) 24 | if _descriptor._USE_C_DESCRIPTORS == False: 25 | DESCRIPTOR._options = None 26 | _globals['_APISOURCETYPE']._serialized_start=63 27 | _globals['_APISOURCETYPE']._serialized_end=133 28 | _globals['_VOID']._serialized_start=55 29 | _globals['_VOID']._serialized_end=61 30 | # @@protoc_insertion_point(module_scope) 31 | 
-------------------------------------------------------------------------------- /aioesphomeapi/ble_defs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import IntEnum 4 | 5 | 6 | class BLEConnectionError(IntEnum): 7 | """BLE Connection Error.""" 8 | 9 | ESP_GATT_CONN_UNKNOWN = 0 10 | ESP_GATT_CONN_L2C_FAILURE = 1 11 | ESP_GATT_CONN_TIMEOUT = 0x08 12 | ESP_GATT_CONN_TERMINATE_PEER_USER = 0x13 13 | ESP_GATT_CONN_TERMINATE_LOCAL_HOST = 0x16 14 | ESP_GATT_CONN_FAIL_ESTABLISH = 0x3E 15 | ESP_GATT_CONN_LMP_TIMEOUT = 0x22 16 | ESP_GATT_ERROR = 0x85 17 | ESP_GATT_CONN_CONN_CANCEL = 0x0100 18 | ESP_GATT_CONN_NONE = 0x0101 19 | 20 | 21 | ESP_CONNECTION_ERROR_DESCRIPTION = { 22 | BLEConnectionError.ESP_GATT_CONN_UNKNOWN: "Connection failed for unknown reason", 23 | BLEConnectionError.ESP_GATT_CONN_L2C_FAILURE: "Connection failed due to L2CAP failure", 24 | BLEConnectionError.ESP_GATT_CONN_TIMEOUT: "Connection failed due to timeout", 25 | BLEConnectionError.ESP_GATT_CONN_TERMINATE_PEER_USER: "Connection terminated by peer user", 26 | BLEConnectionError.ESP_GATT_CONN_TERMINATE_LOCAL_HOST: "Connection terminated by local host", 27 | BLEConnectionError.ESP_GATT_CONN_FAIL_ESTABLISH: "Connection failed to establish", 28 | BLEConnectionError.ESP_GATT_CONN_LMP_TIMEOUT: "Connection failed due to LMP response timeout", 29 | BLEConnectionError.ESP_GATT_ERROR: "Connection failed due to GATT operation error", 30 | BLEConnectionError.ESP_GATT_CONN_CONN_CANCEL: "Connection cancelled", 31 | BLEConnectionError.ESP_GATT_CONN_NONE: "No connection to cancel", 32 | } 33 | -------------------------------------------------------------------------------- /aioesphomeapi/client_base.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | 3 | from ._frame_helper.base cimport APIFrameHelper 4 | from ._frame_helper.noise cimport APINoiseFrameHelper 5 | from 
._frame_helper.plain_text cimport APIPlaintextFrameHelper 6 | from .connection cimport APIConnection, ConnectionParams 7 | 8 | 9 | cdef object create_eager_task 10 | cdef object APIConnectionError 11 | 12 | cdef dict SUBSCRIBE_STATES_RESPONSE_TYPES 13 | 14 | cdef bint TYPE_CHECKING 15 | 16 | cdef object CameraImageResponse, CameraState 17 | 18 | cdef object HomeassistantServiceCall 19 | 20 | cdef object BluetoothLEAdvertisement 21 | 22 | cdef object BluetoothDeviceConnectionResponse 23 | 24 | cdef str _stringify_or_none(object value) 25 | 26 | cdef class APIClientBase: 27 | 28 | cdef public set _background_tasks 29 | cdef public APIConnection _connection 30 | cdef public bint _debug_enabled 31 | cdef public object _loop 32 | cdef public ConnectionParams _params 33 | cdef public str cached_name 34 | cdef public str log_name 35 | 36 | cpdef _set_log_name(self) 37 | 38 | cpdef APIConnection _get_connection(self) 39 | -------------------------------------------------------------------------------- /aioesphomeapi/client_base.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unidiomatic-typecheck 2 | from __future__ import annotations 3 | 4 | import asyncio 5 | from collections.abc import Coroutine 6 | import logging 7 | from typing import TYPE_CHECKING, Any, Callable 8 | 9 | from google.protobuf import message 10 | 11 | from ._frame_helper.base import APIFrameHelper # noqa: F401 12 | from ._frame_helper.noise import APINoiseFrameHelper # noqa: F401 13 | from ._frame_helper.plain_text import APIPlaintextFrameHelper # noqa: F401 14 | from .api_pb2 import ( # type: ignore 15 | BluetoothConnectionsFreeResponse, 16 | BluetoothDeviceConnectionResponse, 17 | BluetoothGATTErrorResponse, 18 | BluetoothGATTGetServicesDoneResponse, 19 | BluetoothGATTGetServicesResponse, 20 | BluetoothGATTNotifyDataResponse, 21 | BluetoothGATTNotifyResponse, 22 | BluetoothGATTReadResponse, 23 | BluetoothGATTWriteResponse, 24 | 
BluetoothLEAdvertisementResponse, 25 | BluetoothScannerStateResponse, 26 | CameraImageResponse, 27 | HomeassistantServiceResponse, 28 | SubscribeHomeAssistantStateResponse, 29 | ) 30 | from .connection import ConnectionParams 31 | from .core import APIConnectionError 32 | from .model import ( 33 | APIVersion, 34 | BluetoothLEAdvertisement, 35 | BluetoothScannerStateResponse as BluetoothScannerStateResponseModel, 36 | CameraState, 37 | EntityState, 38 | HomeassistantServiceCall, 39 | ) 40 | from .model_conversions import SUBSCRIBE_STATES_RESPONSE_TYPES 41 | from .util import build_log_name, create_eager_task 42 | from .zeroconf import ZeroconfInstanceType, ZeroconfManager 43 | 44 | if TYPE_CHECKING: 45 | from .connection import APIConnection 46 | 47 | _LOGGER = logging.getLogger(__name__) 48 | 49 | # We send a ping every 20 seconds, and the timeout ratio is 4.5x the 50 | # ping interval. This means that if we don't receive a ping for 90.0 51 | # seconds, we'll consider the connection dead and reconnect. 52 | # 53 | # This was chosen because the 20s is around the expected time for a 54 | # device to reboot and reconnect to wifi, and 90 seconds is the absolute 55 | # maximum time a device can take to respond when its behind + the WiFi 56 | # connection is poor. 
57 | KEEP_ALIVE_FREQUENCY = 20.0 58 | 59 | 60 | def on_state_msg( 61 | on_state: Callable[[EntityState], None], 62 | image_stream: dict[int, list[bytes]], 63 | msg: message.Message, 64 | ) -> None: 65 | """Handle a state message.""" 66 | msg_type = type(msg) 67 | if (cls := SUBSCRIBE_STATES_RESPONSE_TYPES.get(msg_type)) is not None: 68 | on_state(cls.from_pb(msg)) 69 | elif msg_type is CameraImageResponse: 70 | if TYPE_CHECKING: 71 | assert isinstance(msg, CameraImageResponse) 72 | msg_key = msg.key 73 | data_parts: list[bytes] | None = image_stream.get(msg_key) 74 | if not data_parts: 75 | data_parts = [] 76 | image_stream[msg_key] = data_parts 77 | 78 | data_parts.append(msg.data) 79 | if msg.done: 80 | # Return CameraState with the merged data 81 | image_data = b"".join(data_parts) 82 | del image_stream[msg_key] 83 | on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg] 84 | 85 | 86 | def on_home_assistant_service_response( 87 | on_service_call: Callable[[HomeassistantServiceCall], None], 88 | msg: HomeassistantServiceResponse, 89 | ) -> None: 90 | on_service_call(HomeassistantServiceCall.from_pb(msg)) 91 | 92 | 93 | def on_bluetooth_le_advertising_response( 94 | on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None], 95 | msg: BluetoothLEAdvertisementResponse, 96 | ) -> None: 97 | on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc] 98 | 99 | 100 | def on_bluetooth_connections_free_response( 101 | on_bluetooth_connections_free_update: Callable[[int, int, list[int]], None], 102 | msg: BluetoothConnectionsFreeResponse, 103 | ) -> None: 104 | on_bluetooth_connections_free_update(msg.free, msg.limit, list(msg.allocated)) 105 | 106 | 107 | def on_bluetooth_gatt_notify_data_response( 108 | address: int, 109 | handle: int, 110 | on_bluetooth_gatt_notify: Callable[[int, bytearray], None], 111 | msg: BluetoothGATTNotifyDataResponse, 112 | ) -> None: 113 | """Handle a 
BluetoothGATTNotifyDataResponse message.""" 114 | if address == msg.address and handle == msg.handle: 115 | try: 116 | on_bluetooth_gatt_notify(handle, bytearray(msg.data)) 117 | except Exception: 118 | _LOGGER.exception( 119 | "Unexpected error in Bluetooth GATT notify callback for address %s, handle %s", 120 | address, 121 | handle, 122 | ) 123 | 124 | 125 | def on_bluetooth_scanner_state_response( 126 | on_bluetooth_scanner_state: Callable[[BluetoothScannerStateResponseModel], None], 127 | msg: BluetoothScannerStateResponse, 128 | ) -> None: 129 | on_bluetooth_scanner_state(BluetoothScannerStateResponseModel.from_pb(msg)) 130 | 131 | 132 | def on_subscribe_home_assistant_state_response( 133 | on_state_sub: Callable[[str, str | None], None], 134 | on_state_request: Callable[[str, str | None], None] | None, 135 | msg: SubscribeHomeAssistantStateResponse, 136 | ) -> None: 137 | if on_state_request and msg.once: 138 | on_state_request(msg.entity_id, msg.attribute) 139 | else: 140 | on_state_sub(msg.entity_id, msg.attribute) 141 | 142 | 143 | def on_bluetooth_device_connection_response( 144 | connect_future: asyncio.Future[None], 145 | address: int, 146 | on_bluetooth_connection_state: Callable[[bool, int, int], None], 147 | msg: BluetoothDeviceConnectionResponse, 148 | ) -> None: 149 | """Handle a BluetoothDeviceConnectionResponse message.""" "" 150 | if address == msg.address: 151 | on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error) 152 | # Resolve on ANY connection state since we do not want 153 | # to wait the whole timeout if the device disconnects 154 | # or we get an error. 
155 | if not connect_future.done(): 156 | connect_future.set_result(None) 157 | 158 | 159 | def on_bluetooth_handle_message( 160 | address: int, 161 | handle: int, 162 | msg: ( 163 | BluetoothGATTErrorResponse 164 | | BluetoothGATTNotifyResponse 165 | | BluetoothGATTReadResponse 166 | | BluetoothGATTWriteResponse 167 | | BluetoothDeviceConnectionResponse 168 | ), 169 | ) -> bool: 170 | """Filter a Bluetooth message for an address and handle.""" 171 | if type(msg) is BluetoothDeviceConnectionResponse: 172 | return bool(msg.address == address) 173 | return bool(msg.address == address and msg.handle == handle) 174 | 175 | 176 | def on_bluetooth_message_types( 177 | address: int, 178 | msg_types: tuple[type[message.Message], ...], 179 | msg: ( 180 | BluetoothGATTErrorResponse 181 | | BluetoothGATTNotifyResponse 182 | | BluetoothGATTReadResponse 183 | | BluetoothGATTWriteResponse 184 | | BluetoothDeviceConnectionResponse 185 | | BluetoothGATTGetServicesResponse 186 | | BluetoothGATTGetServicesDoneResponse 187 | | BluetoothGATTErrorResponse 188 | ), 189 | ) -> bool: 190 | """Filter Bluetooth messages of a specific type and address.""" 191 | return type(msg) in msg_types and bool(msg.address == address) 192 | 193 | 194 | str_ = str 195 | 196 | 197 | def _stringify_or_none(value: str_ | None) -> str | None: 198 | """Convert a string like object to a str or None. 199 | 200 | The noise_psk is sometimes passed into 201 | the client as an Estr, but we want to pass it 202 | to the API as a string or None. 
203 | """ 204 | return None if value is None else str(value) 205 | 206 | 207 | class APIClientBase: 208 | """Base client for ESPHome API clients.""" 209 | 210 | __slots__ = ( 211 | "_background_tasks", 212 | "_connection", 213 | "_debug_enabled", 214 | "_loop", 215 | "_params", 216 | "cached_name", 217 | "log_name", 218 | ) 219 | 220 | def __init__( 221 | self, 222 | address: str_, # allow subclass str 223 | port: int, 224 | password: str_ | None, 225 | *, 226 | client_info: str_ = "aioesphomeapi", 227 | keepalive: float = KEEP_ALIVE_FREQUENCY, 228 | zeroconf_instance: ZeroconfInstanceType | None = None, 229 | noise_psk: str_ | None = None, 230 | expected_name: str_ | None = None, 231 | addresses: list[str_] | None = None, 232 | expected_mac: str_ | None = None, 233 | ) -> None: 234 | """Create a client, this object is shared across sessions. 235 | 236 | :param address: The address to connect to; for example an IP address 237 | or .local name for mDNS lookup. 238 | :param port: The port to connect to 239 | :param password: Optional password to send to the device for authentication 240 | :param client_info: User Agent string to send. 241 | :param keepalive: The keepalive time in seconds (ping interval) for detecting stale connections. 242 | Every keepalive seconds a ping is sent, if no pong is received the connection is closed. 243 | :param zeroconf_instance: Pass a zeroconf instance to use if an mDNS lookup is necessary. 244 | :param noise_psk: Encryption preshared key for noise transport encrypted sessions. 245 | :param expected_name: Require the devices name to match the given expected name. 246 | Can be used to prevent accidentally connecting to a different device if 247 | IP passed as address but DHCP reassigned IP. 248 | :param addresses: Optional list of IP addresses to connect to which takes 249 | precedence over the address parameter. 
This is most commonly used when 250 | the device has dual stack IPv4 and IPv6 addresses and you do not know 251 | which one to connect to. 252 | :param expected_mac: Optional MAC address to check against the device. 253 | The format should be lower case without : or - separators. 254 | Example: 00:aa:22:33:44:55 -> 00aa22334455 255 | """ 256 | self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) 257 | self._params = ConnectionParams( 258 | addresses=addresses if addresses else [str(address)], 259 | port=port, 260 | password=password, 261 | client_info=client_info, 262 | keepalive=keepalive, 263 | zeroconf_manager=ZeroconfManager(zeroconf_instance), 264 | # treat empty '' psk string as missing (like password) 265 | noise_psk=_stringify_or_none(noise_psk) or None, 266 | expected_name=_stringify_or_none(expected_name) or None, 267 | expected_mac=_stringify_or_none(expected_mac) or None, 268 | ) 269 | self._connection: APIConnection | None = None 270 | self.cached_name: str | None = None 271 | self._background_tasks: set[asyncio.Task[Any]] = set() 272 | self._loop = asyncio.get_running_loop() 273 | self._set_log_name() 274 | 275 | def set_debug(self, enabled: bool) -> None: 276 | """Enable debug logging.""" 277 | self._debug_enabled = enabled 278 | if self._connection is not None: 279 | self._connection.set_debug(enabled) 280 | 281 | @property 282 | def zeroconf_manager(self) -> ZeroconfManager: 283 | return self._params.zeroconf_manager 284 | 285 | @property 286 | def expected_name(self) -> str | None: 287 | return self._params.expected_name 288 | 289 | @expected_name.setter 290 | def expected_name(self, value: str | None) -> None: 291 | self._params.expected_name = value 292 | 293 | @property 294 | def address(self) -> str: 295 | return self._params.addresses[0] 296 | 297 | @property 298 | def api_version(self) -> APIVersion | None: 299 | if self._connection is None: 300 | return None 301 | return self._connection.api_version 302 | 303 | def _set_log_name(self) 
-> None: 304 | """Set the log name of the device.""" 305 | connected_address: str | None = None 306 | if self._connection is not None and self._connection.connected_address: 307 | connected_address = self._connection.connected_address 308 | self.log_name = build_log_name( 309 | self.cached_name, 310 | self._params.addresses, 311 | connected_address, 312 | ) 313 | if self._connection is not None: 314 | self._connection.set_log_name(self.log_name) 315 | 316 | def _set_name_from_device(self, name: str_) -> None: 317 | """Set the name from a DeviceInfo message.""" 318 | self.cached_name = str(name) # May be Estr from esphome 319 | self._set_log_name() 320 | 321 | def set_cached_name_if_unset(self, name: str_) -> None: 322 | """Set the cached name of the device if not set.""" 323 | if not self.cached_name: 324 | self._set_name_from_device(name) 325 | 326 | def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None: 327 | """Create a background task and add it to the background tasks set.""" 328 | task = create_eager_task(coro) 329 | self._background_tasks.add(task) 330 | task.add_done_callback(self._background_tasks.discard) 331 | 332 | def _get_connection(self) -> APIConnection: 333 | if self._connection is None: 334 | raise APIConnectionError(f"Not connected to {self.log_name}!") 335 | if not self._connection.is_connected: 336 | raise APIConnectionError( 337 | f"Authenticated connection not ready yet for {self.log_name}; " 338 | f"current state is {self._connection.connection_state}!" 
339 | ) 340 | return self._connection 341 | -------------------------------------------------------------------------------- /aioesphomeapi/connection.pxd: -------------------------------------------------------------------------------- 1 | import cython 2 | 3 | from ._frame_helper.base cimport APIFrameHelper 4 | 5 | 6 | cdef dict MESSAGE_TYPE_TO_PROTO 7 | cdef dict PROTO_TO_MESSAGE_TYPE 8 | 9 | cdef set OPEN_STATES 10 | 11 | cdef float KEEP_ALIVE_TIMEOUT_RATIO 12 | cdef object HANDSHAKE_TIMEOUT 13 | 14 | cdef bint TYPE_CHECKING 15 | cdef bint _WIN32 16 | 17 | cdef object WRITE_EXCEPTIONS 18 | 19 | cdef object DISCONNECT_REQUEST_MESSAGE 20 | cdef tuple DISCONNECT_RESPONSE_MESSAGES 21 | cdef tuple PING_REQUEST_MESSAGES 22 | cdef tuple PING_RESPONSE_MESSAGES 23 | cdef object NO_PASSWORD_CONNECT_REQUEST 24 | 25 | cdef object asyncio_timeout 26 | cdef object CancelledError 27 | cdef object asyncio_TimeoutError 28 | 29 | cdef object ConnectRequest, ConnectResponse 30 | cdef object DisconnectRequest 31 | cdef object PingRequest 32 | cdef object GetTimeRequest, GetTimeResponse 33 | cdef object HelloRequest, HelloResponse 34 | 35 | cdef object APIVersion 36 | 37 | cdef object partial 38 | 39 | cdef object hr 40 | 41 | cdef object CONNECT_AND_SETUP_TIMEOUT, CONNECT_REQUEST_TIMEOUT 42 | 43 | cdef object APIConnectionError 44 | cdef object BadNameAPIError 45 | cdef object HandshakeAPIError 46 | cdef object PingFailedAPIError 47 | cdef object ReadFailedAPIError 48 | cdef object TimeoutAPIError 49 | cdef object SocketAPIError 50 | cdef object InvalidAuthAPIError 51 | cdef object SocketClosedAPIError 52 | 53 | cdef object astuple 54 | 55 | cdef object CONNECTION_STATE_INITIALIZED 56 | cdef object CONNECTION_STATE_HOST_RESOLVED 57 | cdef object CONNECTION_STATE_SOCKET_OPENED 58 | cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE 59 | cdef object CONNECTION_STATE_CONNECTED 60 | cdef object CONNECTION_STATE_CLOSED 61 | 62 | cdef object make_hello_request 63 | 64 | cpdef void 
handle_timeout(object fut) 65 | cpdef void handle_complex_message( 66 | object fut, 67 | list responses, 68 | object do_append, 69 | object do_stop, 70 | object resp, 71 | ) 72 | 73 | cdef object _handle_timeout 74 | cdef object _handle_complex_message 75 | 76 | cdef tuple MESSAGE_NUMBER_TO_PROTO 77 | 78 | 79 | @cython.dataclasses.dataclass 80 | cdef class ConnectionParams: 81 | 82 | cdef public list addresses 83 | cdef public object port 84 | cdef public object password 85 | cdef public object client_info 86 | cdef public object keepalive 87 | cdef public object zeroconf_manager 88 | cdef public object noise_psk 89 | cdef public object expected_name 90 | cdef public object expected_mac 91 | 92 | 93 | cdef class APIConnection: 94 | 95 | cdef ConnectionParams _params 96 | cdef public object on_stop 97 | cdef public object _socket 98 | cdef public APIFrameHelper _frame_helper 99 | cdef public object api_version 100 | cdef public object connection_state 101 | cdef public dict _message_handlers 102 | cdef public str log_name 103 | cdef set _read_exception_futures 104 | cdef object _ping_timer 105 | cdef object _pong_timer 106 | cdef float _keep_alive_interval 107 | cdef float _keep_alive_timeout 108 | cdef object _resolve_host_future 109 | cdef object _start_connect_future 110 | cdef object _finish_connect_future 111 | cdef public Exception _fatal_exception 112 | cdef bint _expected_disconnect 113 | cdef object _loop 114 | cdef bint _send_pending_ping 115 | cdef public bint is_connected 116 | cdef bint _handshake_complete 117 | cdef bint _debug_enabled 118 | cdef public str received_name 119 | cdef public str connected_address 120 | cdef list _addrs_info 121 | 122 | cpdef void send_message(self, object msg) except * 123 | 124 | @cython.locals(msg_type=tuple) 125 | cdef void send_messages(self, tuple messages) except * 126 | 127 | @cython.locals(handlers=set, handlers_copy=set, klass_merge=tuple) 128 | cpdef void process_packet( 129 | self, 130 | unsigned int 
msg_type_proto, 131 | object data 132 | ) except * 133 | 134 | cdef void _async_cancel_pong_timer(self) except * 135 | 136 | cdef void _async_schedule_keep_alive(self, object now) except * 137 | 138 | cdef void _cleanup(self) except * 139 | 140 | cpdef set_log_name(self, str name) 141 | 142 | cdef _make_connect_request(self) 143 | 144 | cdef void _process_hello_resp(self, object resp) except * 145 | 146 | cdef void _process_login_response(self, object hello_response) except * 147 | 148 | cdef void _set_connection_state(self, object state) except * 149 | 150 | cpdef void report_fatal_error(self, Exception err) except * 151 | 152 | @cython.locals(handlers=set) 153 | cdef void _add_message_callback_without_remove( 154 | self, 155 | object on_message, 156 | tuple msg_types 157 | ) except * 158 | 159 | cpdef add_message_callback(self, object on_message, tuple msg_types) 160 | 161 | @cython.locals(handlers=set) 162 | cpdef void _remove_message_callback( 163 | self, 164 | object on_message, 165 | tuple msg_types 166 | ) except * 167 | 168 | cpdef void _handle_disconnect_request_internal(self, object msg) except * 169 | 170 | cpdef void _handle_ping_request_internal(self, object msg) except * 171 | 172 | cpdef void _handle_get_time_request_internal(self, object msg) except * 173 | 174 | cdef void _set_fatal_exception_if_unset(self, Exception err) except * 175 | 176 | cdef void _register_internal_message_handlers(self) except * 177 | 178 | cdef void _increase_recv_buffer_size(self) except * 179 | 180 | cdef void _set_start_connect_future(self) except * 181 | 182 | cdef void _set_finish_connect_future(self) except * 183 | -------------------------------------------------------------------------------- /aioesphomeapi/core.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import re 5 | 6 | from aioesphomeapi.model import BluetoothGATTError 7 | 8 | from .api_pb2 import ( # type: 
ignore 9 | AlarmControlPanelCommandRequest, 10 | AlarmControlPanelStateResponse, 11 | BinarySensorStateResponse, 12 | BluetoothConnectionsFreeResponse, 13 | BluetoothDeviceClearCacheResponse, 14 | BluetoothDeviceConnectionResponse, 15 | BluetoothDevicePairingResponse, 16 | BluetoothDeviceRequest, 17 | BluetoothDeviceUnpairingResponse, 18 | BluetoothGATTErrorResponse, 19 | BluetoothGATTGetServicesDoneResponse, 20 | BluetoothGATTGetServicesRequest, 21 | BluetoothGATTGetServicesResponse, 22 | BluetoothGATTNotifyDataResponse, 23 | BluetoothGATTNotifyRequest, 24 | BluetoothGATTNotifyResponse, 25 | BluetoothGATTReadDescriptorRequest, 26 | BluetoothGATTReadRequest, 27 | BluetoothGATTReadResponse, 28 | BluetoothGATTWriteDescriptorRequest, 29 | BluetoothGATTWriteRequest, 30 | BluetoothGATTWriteResponse, 31 | BluetoothLEAdvertisementResponse, 32 | BluetoothLERawAdvertisementsResponse, 33 | BluetoothScannerSetModeRequest, 34 | BluetoothScannerStateResponse, 35 | ButtonCommandRequest, 36 | CameraImageRequest, 37 | CameraImageResponse, 38 | ClimateCommandRequest, 39 | ClimateStateResponse, 40 | ConnectRequest, 41 | ConnectResponse, 42 | CoverCommandRequest, 43 | CoverStateResponse, 44 | DateCommandRequest, 45 | DateStateResponse, 46 | DateTimeCommandRequest, 47 | DateTimeStateResponse, 48 | DeviceInfoRequest, 49 | DeviceInfoResponse, 50 | DisconnectRequest, 51 | DisconnectResponse, 52 | EventResponse, 53 | ExecuteServiceRequest, 54 | FanCommandRequest, 55 | FanStateResponse, 56 | GetTimeRequest, 57 | GetTimeResponse, 58 | HelloRequest, 59 | HelloResponse, 60 | HomeassistantServiceResponse, 61 | HomeAssistantStateResponse, 62 | LightCommandRequest, 63 | LightStateResponse, 64 | ListEntitiesAlarmControlPanelResponse, 65 | ListEntitiesBinarySensorResponse, 66 | ListEntitiesButtonResponse, 67 | ListEntitiesCameraResponse, 68 | ListEntitiesClimateResponse, 69 | ListEntitiesCoverResponse, 70 | ListEntitiesDateResponse, 71 | ListEntitiesDateTimeResponse, 72 | ListEntitiesDoneResponse, 
73 | ListEntitiesEventResponse, 74 | ListEntitiesFanResponse, 75 | ListEntitiesLightResponse, 76 | ListEntitiesLockResponse, 77 | ListEntitiesMediaPlayerResponse, 78 | ListEntitiesNumberResponse, 79 | ListEntitiesRequest, 80 | ListEntitiesSelectResponse, 81 | ListEntitiesSensorResponse, 82 | ListEntitiesServicesResponse, 83 | ListEntitiesSirenResponse, 84 | ListEntitiesSwitchResponse, 85 | ListEntitiesTextResponse, 86 | ListEntitiesTextSensorResponse, 87 | ListEntitiesTimeResponse, 88 | ListEntitiesUpdateResponse, 89 | ListEntitiesValveResponse, 90 | LockCommandRequest, 91 | LockStateResponse, 92 | MediaPlayerCommandRequest, 93 | MediaPlayerStateResponse, 94 | NoiseEncryptionSetKeyRequest, 95 | NoiseEncryptionSetKeyResponse, 96 | NumberCommandRequest, 97 | NumberStateResponse, 98 | PingRequest, 99 | PingResponse, 100 | SelectCommandRequest, 101 | SelectStateResponse, 102 | SensorStateResponse, 103 | SirenCommandRequest, 104 | SirenStateResponse, 105 | SubscribeBluetoothConnectionsFreeRequest, 106 | SubscribeBluetoothLEAdvertisementsRequest, 107 | SubscribeHomeassistantServicesRequest, 108 | SubscribeHomeAssistantStateResponse, 109 | SubscribeHomeAssistantStatesRequest, 110 | SubscribeLogsRequest, 111 | SubscribeLogsResponse, 112 | SubscribeStatesRequest, 113 | SubscribeVoiceAssistantRequest, 114 | SwitchCommandRequest, 115 | SwitchStateResponse, 116 | TextCommandRequest, 117 | TextSensorStateResponse, 118 | TextStateResponse, 119 | TimeCommandRequest, 120 | TimeStateResponse, 121 | UnsubscribeBluetoothLEAdvertisementsRequest, 122 | UpdateCommandRequest, 123 | UpdateStateResponse, 124 | ValveCommandRequest, 125 | ValveStateResponse, 126 | VoiceAssistantAnnounceFinished, 127 | VoiceAssistantAnnounceRequest, 128 | VoiceAssistantAudio, 129 | VoiceAssistantConfigurationRequest, 130 | VoiceAssistantConfigurationResponse, 131 | VoiceAssistantEventResponse, 132 | VoiceAssistantRequest, 133 | VoiceAssistantResponse, 134 | VoiceAssistantSetConfiguration, 135 | 
VoiceAssistantTimerEventResponse, 136 | ) 137 | 138 | TWO_CHAR = re.compile(r".{2}") 139 | 140 | # Taken from esp_gatt_status_t in esp_gatt_defs.h 141 | ESPHOME_GATT_ERRORS = { 142 | -1: "Not connected", # Custom ESPHome error 143 | 1: "Invalid handle", 144 | 2: "Read not permitted", 145 | 3: "Write not permitted", 146 | 4: "Invalid PDU", 147 | 5: "Insufficient authentication", 148 | 6: "Request not supported", 149 | 7: "Invalid offset", 150 | 8: "Insufficient authorization", 151 | 9: "Prepare queue full", 152 | 10: "Attribute not found", 153 | 11: "Attribute not long", 154 | 12: "Insufficient key size", 155 | 13: "Invalid attribute length", 156 | 14: "Unlikely error", 157 | 15: "Insufficient encryption", 158 | 16: "Unsupported group type", 159 | 17: "Insufficient resources", 160 | 128: "Application error", 161 | 129: "Internal error", 162 | 130: "Wrong state", 163 | 131: "Database full", 164 | 132: "Busy", 165 | 133: "Error", 166 | 134: "Command started", 167 | 135: "Illegal parameter", 168 | 136: "Pending", 169 | 137: "Auth fail", 170 | 138: "More", 171 | 139: "Invalid configuration", 172 | 140: "Service started", 173 | 141: "Encrypted no mitm", 174 | 142: "Not encrypted", 175 | 143: "Congested", 176 | 144: "Duplicate registration", 177 | 145: "Already open", 178 | 146: "Cancel", 179 | 224: "Stack RSP", 180 | 225: "App RSP", 181 | 239: "Unknown error", 182 | 253: "CCC config error", 183 | 254: "Procedure already in progress", 184 | 255: "Out of range", 185 | } 186 | 187 | 188 | class APIConnectionError(Exception): 189 | pass 190 | 191 | 192 | class APIConnectionCancelledError(APIConnectionError): 193 | pass 194 | 195 | 196 | class InvalidAuthAPIError(APIConnectionError): 197 | pass 198 | 199 | 200 | class ResolveAPIError(APIConnectionError): 201 | """Raised when a resolve error occurs.""" 202 | 203 | 204 | class ResolveTimeoutAPIError(ResolveAPIError, asyncio.TimeoutError): 205 | """Raised when a resolve timeout occurs.""" 206 | 207 | 208 | class 
ProtocolAPIError(APIConnectionError): 209 | pass 210 | 211 | 212 | class RequiresEncryptionAPIError(ProtocolAPIError): 213 | pass 214 | 215 | 216 | class SocketAPIError(APIConnectionError): 217 | pass 218 | 219 | 220 | class SocketClosedAPIError(SocketAPIError): 221 | pass 222 | 223 | 224 | class HandshakeAPIError(APIConnectionError): 225 | pass 226 | 227 | 228 | class ConnectionNotEstablishedAPIError(APIConnectionError): 229 | pass 230 | 231 | 232 | class BadNameAPIError(APIConnectionError): 233 | """Raised when a name received from the remote but does not much the expected name.""" 234 | 235 | def __init__(self, msg: str, received_name: str) -> None: 236 | super().__init__(f"{msg}: received_name={received_name}") 237 | self.received_name = received_name 238 | 239 | 240 | class BadMACAddressAPIError(APIConnectionError): 241 | """Raised when a MAC address received from the remote but does not much the expected MAC address.""" 242 | 243 | def __init__(self, msg: str, received_name: str, received_mac: str) -> None: 244 | super().__init__( 245 | f"{msg}: received_name={received_name}, received_mac={received_mac}" 246 | ) 247 | self.received_name = received_name 248 | self.received_mac = received_mac 249 | 250 | 251 | class InvalidEncryptionKeyAPIError(HandshakeAPIError): 252 | """Raised when the encryption key is invalid.""" 253 | 254 | def __init__( 255 | self, 256 | msg: str | None = None, 257 | received_name: str | None = None, 258 | received_mac: str | None = None, 259 | ) -> None: 260 | super().__init__( 261 | f"{msg}: received_name={received_name}, received_mac={received_mac}" 262 | ) 263 | self.received_name = received_name 264 | self.received_mac = received_mac 265 | 266 | 267 | class EncryptionErrorAPIError(InvalidEncryptionKeyAPIError): 268 | """Raised when an encryption error occurs after handshake.""" 269 | 270 | 271 | class EncryptionHelloAPIError(HandshakeAPIError): 272 | """Raised when an encryption error occurs during hello.""" 273 | 274 | 275 | class 
EncryptionPlaintextAPIError(HandshakeAPIError): 276 | """Raised when the ESP is using plaintext during noise handshake.""" 277 | 278 | 279 | class PingFailedAPIError(APIConnectionError): 280 | """Raised when a ping fails.""" 281 | 282 | 283 | class TimeoutAPIError(APIConnectionError): 284 | """Raised when a timeout occurs.""" 285 | 286 | 287 | class ReadFailedAPIError(APIConnectionError): 288 | """Raised when a read fails.""" 289 | 290 | 291 | class UnhandledAPIConnectionError(APIConnectionError): 292 | """Raised when an unhandled error occurs.""" 293 | 294 | 295 | class BluetoothConnectionDroppedError(APIConnectionError): 296 | """Raised when a Bluetooth connection is dropped.""" 297 | 298 | 299 | def to_human_readable_address(address: int) -> str: 300 | """Convert a MAC address to a human readable format.""" 301 | return ":".join(TWO_CHAR.findall(f"{address:012X}")) 302 | 303 | 304 | def to_human_readable_gatt_error(error: int) -> str: 305 | """Convert a GATT error to a human readable format.""" 306 | return ESPHOME_GATT_ERRORS.get(error, "Unknown error") 307 | 308 | 309 | class BluetoothGATTAPIError(APIConnectionError): 310 | def __init__(self, error: BluetoothGATTError) -> None: 311 | super().__init__( 312 | f"Bluetooth GATT Error " 313 | f"address={to_human_readable_address(error.address)} " 314 | f"handle={error.handle} " 315 | f"error={error.error} " 316 | f"description={to_human_readable_gatt_error(error.error)}" 317 | ) 318 | self.error = error 319 | 320 | 321 | MESSAGE_TYPE_TO_PROTO = { 322 | 1: HelloRequest, 323 | 2: HelloResponse, 324 | 3: ConnectRequest, 325 | 4: ConnectResponse, 326 | 5: DisconnectRequest, 327 | 6: DisconnectResponse, 328 | 7: PingRequest, 329 | 8: PingResponse, 330 | 9: DeviceInfoRequest, 331 | 10: DeviceInfoResponse, 332 | 11: ListEntitiesRequest, 333 | 12: ListEntitiesBinarySensorResponse, 334 | 13: ListEntitiesCoverResponse, 335 | 14: ListEntitiesFanResponse, 336 | 15: ListEntitiesLightResponse, 337 | 16: 
ListEntitiesSensorResponse, 338 | 17: ListEntitiesSwitchResponse, 339 | 18: ListEntitiesTextSensorResponse, 340 | 19: ListEntitiesDoneResponse, 341 | 20: SubscribeStatesRequest, 342 | 21: BinarySensorStateResponse, 343 | 22: CoverStateResponse, 344 | 23: FanStateResponse, 345 | 24: LightStateResponse, 346 | 25: SensorStateResponse, 347 | 26: SwitchStateResponse, 348 | 27: TextSensorStateResponse, 349 | 28: SubscribeLogsRequest, 350 | 29: SubscribeLogsResponse, 351 | 30: CoverCommandRequest, 352 | 31: FanCommandRequest, 353 | 32: LightCommandRequest, 354 | 33: SwitchCommandRequest, 355 | 34: SubscribeHomeassistantServicesRequest, 356 | 35: HomeassistantServiceResponse, 357 | 36: GetTimeRequest, 358 | 37: GetTimeResponse, 359 | 38: SubscribeHomeAssistantStatesRequest, 360 | 39: SubscribeHomeAssistantStateResponse, 361 | 40: HomeAssistantStateResponse, 362 | 41: ListEntitiesServicesResponse, 363 | 42: ExecuteServiceRequest, 364 | 43: ListEntitiesCameraResponse, 365 | 44: CameraImageResponse, 366 | 45: CameraImageRequest, 367 | 46: ListEntitiesClimateResponse, 368 | 47: ClimateStateResponse, 369 | 48: ClimateCommandRequest, 370 | 49: ListEntitiesNumberResponse, 371 | 50: NumberStateResponse, 372 | 51: NumberCommandRequest, 373 | 52: ListEntitiesSelectResponse, 374 | 53: SelectStateResponse, 375 | 54: SelectCommandRequest, 376 | 55: ListEntitiesSirenResponse, 377 | 56: SirenStateResponse, 378 | 57: SirenCommandRequest, 379 | 58: ListEntitiesLockResponse, 380 | 59: LockStateResponse, 381 | 60: LockCommandRequest, 382 | 61: ListEntitiesButtonResponse, 383 | 62: ButtonCommandRequest, 384 | 63: ListEntitiesMediaPlayerResponse, 385 | 64: MediaPlayerStateResponse, 386 | 65: MediaPlayerCommandRequest, 387 | 66: SubscribeBluetoothLEAdvertisementsRequest, 388 | 67: BluetoothLEAdvertisementResponse, 389 | 68: BluetoothDeviceRequest, 390 | 69: BluetoothDeviceConnectionResponse, 391 | 70: BluetoothGATTGetServicesRequest, 392 | 71: BluetoothGATTGetServicesResponse, 393 | 72: 
BluetoothGATTGetServicesDoneResponse, 394 | 73: BluetoothGATTReadRequest, 395 | 74: BluetoothGATTReadResponse, 396 | 75: BluetoothGATTWriteRequest, 397 | 76: BluetoothGATTReadDescriptorRequest, 398 | 77: BluetoothGATTWriteDescriptorRequest, 399 | 78: BluetoothGATTNotifyRequest, 400 | 79: BluetoothGATTNotifyDataResponse, 401 | 80: SubscribeBluetoothConnectionsFreeRequest, 402 | 81: BluetoothConnectionsFreeResponse, 403 | 82: BluetoothGATTErrorResponse, 404 | 83: BluetoothGATTWriteResponse, 405 | 84: BluetoothGATTNotifyResponse, 406 | 85: BluetoothDevicePairingResponse, 407 | 86: BluetoothDeviceUnpairingResponse, 408 | 87: UnsubscribeBluetoothLEAdvertisementsRequest, 409 | 88: BluetoothDeviceClearCacheResponse, 410 | 89: SubscribeVoiceAssistantRequest, 411 | 90: VoiceAssistantRequest, 412 | 91: VoiceAssistantResponse, 413 | 92: VoiceAssistantEventResponse, 414 | 93: BluetoothLERawAdvertisementsResponse, 415 | 94: ListEntitiesAlarmControlPanelResponse, 416 | 95: AlarmControlPanelStateResponse, 417 | 96: AlarmControlPanelCommandRequest, 418 | 97: ListEntitiesTextResponse, 419 | 98: TextStateResponse, 420 | 99: TextCommandRequest, 421 | 100: ListEntitiesDateResponse, 422 | 101: DateStateResponse, 423 | 102: DateCommandRequest, 424 | 103: ListEntitiesTimeResponse, 425 | 104: TimeStateResponse, 426 | 105: TimeCommandRequest, 427 | 106: VoiceAssistantAudio, 428 | 107: ListEntitiesEventResponse, 429 | 108: EventResponse, 430 | 109: ListEntitiesValveResponse, 431 | 110: ValveStateResponse, 432 | 111: ValveCommandRequest, 433 | 112: ListEntitiesDateTimeResponse, 434 | 113: DateTimeStateResponse, 435 | 114: DateTimeCommandRequest, 436 | 115: VoiceAssistantTimerEventResponse, 437 | 116: ListEntitiesUpdateResponse, 438 | 117: UpdateStateResponse, 439 | 118: UpdateCommandRequest, 440 | 119: VoiceAssistantAnnounceRequest, 441 | 120: VoiceAssistantAnnounceFinished, 442 | 121: VoiceAssistantConfigurationRequest, 443 | 122: VoiceAssistantConfigurationResponse, 444 | 123: 
VoiceAssistantSetConfiguration, 445 | 124: NoiseEncryptionSetKeyRequest, 446 | 125: NoiseEncryptionSetKeyResponse, 447 | 126: BluetoothScannerStateResponse, 448 | 127: BluetoothScannerSetModeRequest, 449 | } 450 | 451 | MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values()) 452 | -------------------------------------------------------------------------------- /aioesphomeapi/discover.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | 5 | # Helper script and aioesphomeapi to discover api devices 6 | import asyncio 7 | import contextlib 8 | import logging 9 | import sys 10 | 11 | from zeroconf import IPVersion, ServiceStateChange, Zeroconf 12 | from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf 13 | 14 | FORMAT = "{: <7}|{: <32}|{: <15}|{: <12}|{: <16}|{: <10}|{: <32}" 15 | COLUMN_NAMES = ("Status", "Name", "Address", "MAC", "Version", "Platform", "Board") 16 | UNKNOWN = "unknown" 17 | 18 | 19 | def decode_bytes_or_unknown(data: str | bytes | None) -> str: 20 | """Decode bytes or return unknown.""" 21 | if data is None: 22 | return UNKNOWN 23 | if isinstance(data, bytes): 24 | return data.decode() 25 | return data 26 | 27 | 28 | def async_service_update( 29 | zeroconf: Zeroconf, 30 | service_type: str, 31 | name: str, 32 | state_change: ServiceStateChange, 33 | ) -> None: 34 | """Service state changed.""" 35 | short_name = name.partition(".")[0] 36 | state = "OFFLINE" if state_change is ServiceStateChange.Removed else "ONLINE" 37 | info = AsyncServiceInfo(service_type, name) 38 | info.load_from_cache(zeroconf) 39 | properties = info.properties 40 | mac = decode_bytes_or_unknown(properties.get(b"mac")) 41 | version = decode_bytes_or_unknown(properties.get(b"version")) 42 | platform = decode_bytes_or_unknown(properties.get(b"platform")) 43 | board = decode_bytes_or_unknown(properties.get(b"board")) 44 | address = "" 45 | if 
addresses := info.ip_addresses_by_version(IPVersion.V4Only): 46 | address = str(addresses[0]) 47 | 48 | print(FORMAT.format(state, short_name, address, mac, version, platform, board)) 49 | 50 | 51 | async def main(argv: list[str]) -> None: 52 | parser = argparse.ArgumentParser("aioesphomeapi-discover") 53 | parser.add_argument("-v", "--verbose", action="store_true") 54 | args = parser.parse_args(argv[1:]) 55 | logging.basicConfig( 56 | format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s", 57 | level=logging.DEBUG if args.verbose else logging.INFO, 58 | datefmt="%Y-%m-%d %H:%M:%S", 59 | ) 60 | if args.verbose: 61 | logging.getLogger("zeroconf").setLevel(logging.DEBUG) 62 | 63 | aiozc = AsyncZeroconf() 64 | browser = AsyncServiceBrowser( 65 | aiozc.zeroconf, "_esphomelib._tcp.local.", handlers=[async_service_update] 66 | ) 67 | print(FORMAT.format(*COLUMN_NAMES)) 68 | print("-" * 120) 69 | 70 | try: 71 | await asyncio.Event().wait() 72 | finally: 73 | await browser.async_cancel() 74 | await aiozc.async_close() 75 | 76 | 77 | def cli_entry_point() -> None: 78 | """Run the CLI.""" 79 | with contextlib.suppress(KeyboardInterrupt): 80 | asyncio.run(main(sys.argv)) 81 | 82 | 83 | if __name__ == "__main__": 84 | cli_entry_point() 85 | sys.exit(0) 86 | -------------------------------------------------------------------------------- /aioesphomeapi/host_resolver.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from collections import defaultdict 5 | from collections.abc import Coroutine 6 | from contextlib import suppress 7 | from dataclasses import dataclass 8 | from ipaddress import IPv4Address, IPv6Address, ip_address 9 | import itertools 10 | import logging 11 | import socket 12 | from typing import TYPE_CHECKING, Any, cast 13 | 14 | from zeroconf import IPVersion 15 | from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf 16 | 17 | from .core import 
ResolveAPIError, ResolveTimeoutAPIError 18 | from .util import ( 19 | address_is_local, 20 | asyncio_timeout, 21 | create_eager_task, 22 | host_is_name_part, 23 | ) 24 | from .zeroconf import ZeroconfManager 25 | 26 | _LOGGER = logging.getLogger(__name__) 27 | 28 | 29 | SERVICE_TYPE = "_esphomelib._tcp.local." 30 | RESOLVE_TIMEOUT = 30.0 31 | 32 | 33 | @dataclass(frozen=True) 34 | class Sockaddr: 35 | """Base socket address.""" 36 | 37 | address: str 38 | port: int 39 | 40 | 41 | @dataclass(frozen=True) 42 | class IPv4Sockaddr(Sockaddr): 43 | """IPv4 socket address.""" 44 | 45 | 46 | @dataclass(frozen=True) 47 | class IPv6Sockaddr(Sockaddr): 48 | """IPv6 socket address.""" 49 | 50 | flowinfo: int 51 | scope_id: int 52 | 53 | 54 | @dataclass(frozen=True) 55 | class AddrInfo: 56 | family: int 57 | type: int 58 | proto: int 59 | sockaddr: IPv4Sockaddr | IPv6Sockaddr 60 | 61 | 62 | async def _async_zeroconf_get_service_info( 63 | aiozc: AsyncZeroconf, 64 | short_host: str, 65 | timeout: float, 66 | ) -> AsyncServiceInfo: 67 | info = _make_service_info_for_short_host(short_host) 68 | try: 69 | await info.async_request(aiozc.zeroconf, int(timeout * 1000)) 70 | except Exception as exc: 71 | raise ResolveAPIError( 72 | f"Error resolving mDNS {short_host} via mDNS: {exc}" 73 | ) from exc 74 | return info 75 | 76 | 77 | def _scope_id_to_int(value: str | None) -> int: 78 | """Convert a scope id to int if possible.""" 79 | if value is None: 80 | return 0 81 | try: 82 | return int(value) 83 | except ValueError: 84 | return 0 85 | 86 | 87 | def _make_service_info_for_short_host(host: str) -> AsyncServiceInfo: 88 | """Make service info for an ESPHome host.""" 89 | service_name = f"{host}.{SERVICE_TYPE}" 90 | server = f"{host}.local." 
91 | return AsyncServiceInfo(SERVICE_TYPE, service_name, server=server) 92 | 93 | 94 | async def _async_resolve_short_host_zeroconf( 95 | aiozc: AsyncZeroconf, 96 | short_host: str, 97 | port: int, 98 | *, 99 | timeout: float = 3.0, 100 | ) -> list[AddrInfo]: 101 | _LOGGER.debug("Resolving host %s via mDNS", short_host) 102 | service_info = await _async_zeroconf_get_service_info(aiozc, short_host, timeout) 103 | return service_info_to_addr_info(service_info, port) 104 | 105 | 106 | def service_info_to_addr_info(info: AsyncServiceInfo, port: int) -> list[AddrInfo]: 107 | return [ 108 | _async_ip_address_to_addrinfo(ip, port) 109 | for version in (IPVersion.V6Only, IPVersion.V4Only) 110 | for ip in info.ip_addresses_by_version(version) 111 | ] 112 | 113 | 114 | async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo]: 115 | loop = asyncio.get_running_loop() 116 | try: 117 | # Limit to TCP IP protocol and SOCK_STREAM 118 | res = await loop.getaddrinfo( 119 | host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP 120 | ) 121 | except OSError as err: 122 | raise ResolveAPIError(f"Error resolving {host} to IP address: {err}") 123 | 124 | addrs: list[AddrInfo] = [] 125 | for family, type_, proto, _, raw in res: 126 | sockaddr: IPv4Sockaddr | IPv6Sockaddr 127 | if family == socket.AF_INET: 128 | raw = cast(tuple[str, int], raw) 129 | address, port = raw 130 | sockaddr = IPv4Sockaddr(address=address, port=port) 131 | elif family == socket.AF_INET6: 132 | raw = cast(tuple[str, int, int, int], raw) 133 | address, port, flowinfo, scope_id = raw 134 | sockaddr = IPv6Sockaddr( 135 | address=address, port=port, flowinfo=flowinfo, scope_id=scope_id 136 | ) 137 | else: 138 | # Unknown family 139 | continue 140 | 141 | addrs.append( 142 | AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr) 143 | ) 144 | return addrs 145 | 146 | 147 | def _async_ip_address_to_addrinfo(ip: IPv4Address | IPv6Address, port: int) -> AddrInfo: 148 | 
"""Convert an ipaddress to AddrInfo.""" 149 | is_ipv6 = ip.version == 6 150 | sockaddr: IPv6Sockaddr | IPv4Sockaddr 151 | if is_ipv6: 152 | if TYPE_CHECKING: 153 | assert isinstance(ip, IPv6Address) 154 | sockaddr = IPv6Sockaddr( 155 | address=str(ip).partition("%")[0], 156 | port=port, 157 | flowinfo=0, 158 | scope_id=_scope_id_to_int(ip.scope_id), 159 | ) 160 | else: 161 | sockaddr = IPv4Sockaddr( 162 | address=str(ip), 163 | port=port, 164 | ) 165 | 166 | return AddrInfo( 167 | family=socket.AF_INET6 if is_ipv6 else socket.AF_INET, 168 | type=socket.SOCK_STREAM, 169 | proto=socket.IPPROTO_TCP, 170 | sockaddr=sockaddr, 171 | ) 172 | 173 | 174 | def host_is_local_name(host: str) -> bool: 175 | """Check if the host is a local name.""" 176 | return host_is_name_part(host) or address_is_local(host) 177 | 178 | 179 | async def async_resolve_host( 180 | hosts: list[str], 181 | port: int, 182 | zeroconf_manager: ZeroconfManager | None = None, 183 | timeout: float = RESOLVE_TIMEOUT, 184 | ) -> list[AddrInfo]: 185 | """Resolve hosts in parallel. 186 | 187 | We will try to resolve the host in the following order: 188 | - If the host is an IP address, we will return that and skip 189 | trying to resolve it at all. 190 | 191 | - If the host is a local name, we will try to resolve it via mDNS 192 | - Otherwise, we will use getaddrinfo to resolve it as well 193 | 194 | Once we know which hosts to resolve and which methods, all 195 | resolution runs in parallel and we will return the first 196 | result we get for each host. 197 | """ 198 | manager: ZeroconfManager | None = None 199 | had_zeroconf_instance: bool = False 200 | resolve_results: defaultdict[str, list[AddrInfo]] = defaultdict(list) 201 | aiozc: AsyncZeroconf | None = None 202 | tried_to_create_zeroconf: bool = False 203 | exceptions: list[BaseException] = [] 204 | 205 | # First try to handle the cases where we do not need to 206 | # do any network calls at all. 
207 | # - If the host is an IP address, we can just return that 208 | # - If we have a zeroconf manager and the host is in the cache 209 | # we can return that as well 210 | for host in hosts: 211 | # If its an IP address, we can convert it to an AddrInfo 212 | # and we are done with this host 213 | try: 214 | ip_addr_info = _async_ip_address_to_addrinfo(ip_address(host), port) 215 | except ValueError: 216 | pass 217 | else: 218 | if ip_addr_info: 219 | resolve_results[host].append(ip_addr_info) 220 | continue 221 | 222 | if not host_is_local_name(host): 223 | continue 224 | 225 | # If its a local name, we can try to fetch it from the zeroconf cache 226 | if not tried_to_create_zeroconf: 227 | tried_to_create_zeroconf = True 228 | manager = zeroconf_manager or ZeroconfManager() 229 | had_zeroconf_instance = manager.has_instance 230 | try: 231 | aiozc = manager.get_async_zeroconf() 232 | except Exception as original_exc: 233 | new_exc = ResolveAPIError( 234 | f"Cannot start mDNS sockets while resolving {host}: " 235 | f"{original_exc}, is this a docker container " 236 | "without host network mode? 
" 237 | ) 238 | new_exc.__cause__ = original_exc 239 | exceptions.append(new_exc) 240 | 241 | if aiozc: 242 | short_host = host.partition(".")[0] 243 | service_info = _make_service_info_for_short_host(short_host) 244 | if service_info.load_from_cache(aiozc.zeroconf) and ( 245 | addr_infos := service_info_to_addr_info(service_info, port) 246 | ): 247 | resolve_results[host].extend(addr_infos) 248 | 249 | try: 250 | if len(resolve_results) != len(hosts): 251 | # If we have not resolved all hosts yet, we need to do some network calls 252 | try: 253 | async with asyncio_timeout(timeout): 254 | await _async_resolve_host( 255 | hosts, port, resolve_results, exceptions, aiozc, timeout 256 | ) 257 | except asyncio.TimeoutError as err: 258 | raise ResolveTimeoutAPIError( 259 | f"Timeout while resolving IP address for {hosts}" 260 | ) from err 261 | finally: 262 | if manager and not had_zeroconf_instance: 263 | await asyncio.shield(create_eager_task(manager.async_close())) 264 | 265 | if addrs := list(itertools.chain.from_iterable(resolve_results.values())): 266 | return addrs 267 | 268 | if exceptions: 269 | raise ResolveAPIError(" ,".join([str(exc) for exc in exceptions])) 270 | raise ResolveAPIError(f"Could not resolve host {hosts} - got no results from OS") 271 | 272 | 273 | async def _async_resolve_host( 274 | hosts: list[str], 275 | port: int, 276 | resolve_results: defaultdict[str, list[AddrInfo]], 277 | exceptions: list[BaseException], 278 | aiozc: AsyncZeroconf | None, 279 | timeout: float, 280 | ) -> None: 281 | """Resolve hosts in parallel. 282 | 283 | As soon as we get a result for a host, we will cancel 284 | all other tasks trying to resolve that host. 285 | 286 | This function will resolve hosts in parallel using 287 | both mDNS and getaddrinfo. 288 | 289 | This function is also designed to be cancellable, so 290 | if we get cancelled, we will cancel all tasks, and 291 | clean up after ourselves. 
292 | """ 293 | resolve_task_to_host: dict[asyncio.Task[list[AddrInfo]], str] = {} 294 | host_tasks: defaultdict[str, set[asyncio.Task[list[AddrInfo]]]] = defaultdict(set) 295 | 296 | try: 297 | for host in hosts: 298 | coros: list[Coroutine[Any, Any, list[AddrInfo]]] = [] 299 | if aiozc and host_is_local_name(host): 300 | short_host = host.partition(".")[0] 301 | coros.append( 302 | _async_resolve_short_host_zeroconf( 303 | aiozc, short_host, port, timeout=timeout 304 | ) 305 | ) 306 | 307 | coros.append(_async_resolve_host_getaddrinfo(host, port)) 308 | 309 | for coro in coros: 310 | task = create_eager_task(coro) 311 | if task.done(): 312 | if exc := task.exception(): 313 | exceptions.append(exc) 314 | else: 315 | resolve_results[host].extend(task.result()) 316 | else: 317 | resolve_task_to_host[task] = host 318 | host_tasks[host].add(task) 319 | 320 | while resolve_task_to_host: 321 | done, _ = await asyncio.wait( 322 | resolve_task_to_host, 323 | return_when=asyncio.FIRST_COMPLETED, 324 | ) 325 | finished_hosts: set[str] = set() 326 | for task in done: 327 | host = resolve_task_to_host.pop(task) 328 | host_tasks[host].discard(task) 329 | if exc := task.exception(): 330 | exceptions.append(exc) 331 | elif result := task.result(): 332 | resolve_results[host].extend(result) 333 | finished_hosts.add(host) 334 | 335 | # We got a result for a host, cancel 336 | # any other tasks trying to resolve 337 | # it as we are done with that host 338 | for host in finished_hosts: 339 | for task in host_tasks.pop(host, ()): 340 | resolve_task_to_host.pop(task, None) 341 | task.cancel() 342 | with suppress(asyncio.CancelledError): 343 | await task 344 | finally: 345 | # We likely get here if we get cancelled 346 | # because of a timeout 347 | for task in resolve_task_to_host: 348 | task.cancel() 349 | 350 | # Await all remaining tasks only after cancelling 351 | # them in case we get cancelled ourselves 352 | for task in resolve_task_to_host: 353 | with 
suppress(asyncio.CancelledError): 354 | await task 355 | -------------------------------------------------------------------------------- /aioesphomeapi/log_parser.py: -------------------------------------------------------------------------------- 1 | """Log parser for ESPHome log messages with ANSI color support.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Iterable 6 | import re 7 | 8 | # Pre-compiled regex for ANSI escape sequences 9 | ANSI_ESCAPE = re.compile( 10 | r"(?:\x1B[@-Z\\-_]|[\x80-\x9A\x9C-\x9F]|(?:\x1B\[|\x9B)[0-?]*[ -/]*[@-~])" 11 | ) 12 | 13 | # ANSI reset sequences 14 | ANSI_RESET_CODES = ("\033[0m", "\x1b[0m") 15 | ANSI_RESET = "\033[0m" 16 | 17 | 18 | def _extract_prefix_and_color(line: str, strip_ansi: bool) -> tuple[str, str, str]: 19 | """Extract ESPHome prefix and ANSI color code from line. 20 | 21 | Returns: 22 | Tuple of (prefix, color_code, line_without_color) 23 | """ 24 | color_code = "" 25 | line_no_color = line 26 | 27 | # Extract ANSI color code at the beginning if present 28 | if not strip_ansi and (color_match := ANSI_ESCAPE.match(line)): 29 | color_code = color_match.group(0) 30 | line_no_color = line[len(color_code) :] 31 | 32 | # Find the ESPHome prefix 33 | bracket_colon = line_no_color.find("]:") 34 | prefix = line_no_color[: bracket_colon + 2] if bracket_colon != -1 else "" 35 | 36 | return prefix, color_code, line_no_color 37 | 38 | 39 | def _needs_reset(line: str) -> bool: 40 | """Check if line needs ANSI reset code appended.""" 41 | return bool( 42 | line 43 | and not line.endswith(ANSI_RESET_CODES) 44 | and ("\033[" in line or "\x1b[" in line) 45 | ) 46 | 47 | 48 | def _format_continuation_line( 49 | timestamp: str, 50 | prefix: str, 51 | line: str, 52 | color_code: str = "", 53 | strip_ansi: bool = False, 54 | ) -> str: 55 | """Format a continuation line with prefix and optional color.""" 56 | line_content = f"{prefix} {line}" if prefix else line 57 | 58 | if color_code and not 
strip_ansi: 59 | reset = "" if line.endswith(ANSI_RESET_CODES) else ANSI_RESET 60 | return f"{timestamp}{color_code}{line_content}{reset}" 61 | 62 | return f"{timestamp}{line_content}" 63 | 64 | 65 | class LogParser: 66 | """Stateful parser for processing log messages one line at a time. 67 | 68 | This parser is designed for streaming input where log messages come 69 | line by line rather than in complete multi-line blocks. 70 | """ 71 | 72 | def __init__(self, strip_ansi_escapes: bool = False) -> None: 73 | """Initialize the parser. 74 | 75 | Args: 76 | strip_ansi_escapes: If True, remove all ANSI escape sequences from output 77 | """ 78 | self.strip_ansi_escapes = strip_ansi_escapes 79 | self._current_prefix = "" 80 | self._current_color_code = "" 81 | 82 | def parse_line(self, line: str, timestamp: str) -> str: 83 | """Parse a single line and return formatted output. 84 | 85 | Args: 86 | line: A single line of log text (without newline) 87 | timestamp: The timestamp string to prepend (e.g., "[08:00:00.000]") 88 | 89 | Returns: 90 | Formatted line ready to be printed. 
91 | """ 92 | # Strip any trailing newline if present 93 | line = line.rstrip("\n\r") 94 | 95 | # Strip ANSI escapes if requested 96 | if self.strip_ansi_escapes: 97 | line = ANSI_ESCAPE.sub("", line) 98 | 99 | # Empty line handling 100 | if not line: 101 | return "" 102 | 103 | # Check if this is a new log entry or a continuation 104 | is_continuation = line[0].isspace() 105 | 106 | if not is_continuation: 107 | # This is a new log entry - update state 108 | self._current_prefix = "" 109 | self._current_color_code = "" 110 | 111 | # Extract prefix and color for potential multi-line messages 112 | if line and not line[0].isspace(): 113 | self._current_prefix, self._current_color_code, _ = ( 114 | _extract_prefix_and_color(line, self.strip_ansi_escapes) 115 | ) 116 | 117 | # Format the first line 118 | output = f"{timestamp}{line}" 119 | 120 | # Add reset if line has color but no reset at end 121 | if not self.strip_ansi_escapes and _needs_reset(line): 122 | output += ANSI_RESET 123 | 124 | return output 125 | 126 | # This is a continuation line 127 | if not line.strip(): 128 | return "" 129 | 130 | return _format_continuation_line( 131 | timestamp, 132 | self._current_prefix, 133 | line, 134 | self._current_color_code, 135 | self.strip_ansi_escapes, 136 | ) 137 | 138 | 139 | def parse_log_message( 140 | text: str, timestamp: str, *, strip_ansi_escapes: bool = False 141 | ) -> Iterable[str]: 142 | """Parse a log message and format it with timestamps and color preservation. 143 | 144 | Args: 145 | text: The log message text, potentially with ANSI codes and newlines 146 | timestamp: The timestamp string to prepend (e.g., "[08:00:00.000]") 147 | strip_ansi_escapes: If True, remove all ANSI escape sequences from output 148 | 149 | Returns: 150 | Iterable of formatted lines ready to be printed. 151 | For single-line logs, returns a tuple for efficiency. 152 | For multi-line logs, returns a list. 
153 | """ 154 | # Strip ANSI escapes if requested 155 | if strip_ansi_escapes: 156 | text = ANSI_ESCAPE.sub("", text) 157 | 158 | # Fast path for single line (most common case) 159 | if "\n" not in text: 160 | return (f"{timestamp}{text}",) 161 | 162 | # Multi-line handling 163 | lines = text.split("\n") 164 | 165 | # Remove trailing empty line or ANSI reset codes 166 | if lines and (lines[-1] == "" or lines[-1] in ANSI_RESET_CODES): 167 | lines.pop() 168 | result: list[str] = [] 169 | 170 | # Process the first line 171 | first_line_output = f"{timestamp}{lines[0]}" 172 | 173 | # Check if first line has color but no reset at end (to prevent bleeding) 174 | if not strip_ansi_escapes and _needs_reset(lines[0]): 175 | first_line_output += ANSI_RESET 176 | 177 | result.append(first_line_output) 178 | 179 | # Extract prefix and color from the first line 180 | first_line = lines[0] 181 | prefix = "" 182 | color_code = "" 183 | 184 | # Extract prefix if first line doesn't start with space 185 | if first_line and not first_line[0].isspace(): 186 | prefix, color_code, _ = _extract_prefix_and_color( 187 | first_line, strip_ansi_escapes 188 | ) 189 | 190 | # Process subsequent lines 191 | for line in lines[1:]: 192 | if not line.strip(): # Only process non-empty lines 193 | # Empty line 194 | result.append("") 195 | continue 196 | if not line[0].isspace(): # If line starts with whitespace, it's a continuation 197 | # This is a new log entry within the same message 198 | result.append(f"{timestamp}{line}") 199 | continue 200 | # Apply timestamp, color, prefix, and the continuation line 201 | result.append( 202 | _format_continuation_line( 203 | timestamp, prefix, line, color_code, strip_ansi_escapes 204 | ) 205 | ) 206 | 207 | return result 208 | -------------------------------------------------------------------------------- /aioesphomeapi/log_reader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 
# Helper script using aioesphomeapi to view logs from an ESPHome device.
import argparse
import asyncio
import contextlib
from datetime import datetime
import logging
import sys

from .api_pb2 import SubscribeLogsResponse  # type: ignore
from .client import APIClient
from .log_parser import parse_log_message
from .log_runner import async_run


async def main(argv: list[str]) -> None:
    """Parse command line arguments, connect to a device and print its logs.

    ``argv`` is the full argument vector; its first element (the program name)
    is skipped. Runs until cancelled, then stops the log subscription.
    """
    parser = argparse.ArgumentParser("aioesphomeapi-logs")
    parser.add_argument("--port", type=int, default=6053)
    parser.add_argument("--password", type=str)
    parser.add_argument("--noise-psk", type=str)
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("address")
    args = parser.parse_args(argv[1:])

    logging.basicConfig(
        format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    cli = APIClient(
        args.address,
        args.port,
        args.password or "",
        noise_psk=args.noise_psk,
        keepalive=10,
    )

    def on_log(msg: SubscribeLogsResponse) -> None:
        """Print one log message prefixed with the local time it was received."""
        time_ = datetime.now()
        message: bytes = msg.message
        # Undecodable bytes are rendered as backslash escapes instead of raising.
        text = message.decode("utf8", "backslashreplace")
        # datetime only exposes microseconds; the timestamp shows milliseconds.
        milliseconds = time_.microsecond // 1000
        timestamp = (
            f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{milliseconds:03}]"
        )

        # Parse and print the log message
        for line in parse_log_message(text, timestamp):
            print(line)

    stop = await async_run(cli, on_log)
    try:
        # Wait indefinitely; the finally block stops the logs once this
        # coroutine is cancelled or interrupted.
        await asyncio.Event().wait()
    finally:
        await stop()


def cli_entry_point() -> None:
    """Run the CLI."""
    with contextlib.suppress(KeyboardInterrupt):
        asyncio.run(main(sys.argv))


if __name__ == "__main__":
    cli_entry_point()
    sys.exit(0)
-------------------------------------------------------------------------------- /aioesphomeapi/log_runner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Coroutine 4 | import logging 5 | from typing import Any, Callable 6 | 7 | from zeroconf.asyncio import AsyncZeroconf 8 | 9 | from .api_pb2 import SubscribeLogsResponse # type: ignore 10 | from .client import APIClient 11 | from .core import APIConnectionError 12 | from .model import LogLevel 13 | from .reconnect_logic import ReconnectLogic 14 | 15 | _LOGGER = logging.getLogger(__name__) 16 | 17 | 18 | async def async_run( 19 | cli: APIClient, 20 | on_log: Callable[[SubscribeLogsResponse], None], 21 | log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE, 22 | aio_zeroconf_instance: AsyncZeroconf | None = None, 23 | dump_config: bool = True, 24 | name: str | None = None, 25 | ) -> Callable[[], Coroutine[Any, Any, None]]: 26 | """Run logs until canceled. 27 | 28 | Returns a coroutine that can be awaited to stop the logs. 
29 | """ 30 | dumped_config = not dump_config 31 | 32 | async def on_connect() -> None: 33 | """Handle a connection.""" 34 | nonlocal dumped_config 35 | try: 36 | cli.subscribe_logs( 37 | on_log, 38 | log_level=log_level, 39 | dump_config=not dumped_config, 40 | ) 41 | dumped_config = True 42 | except APIConnectionError: 43 | await cli.disconnect() 44 | 45 | async def on_disconnect( # pylint: disable=unused-argument 46 | expected_disconnect: bool, 47 | ) -> None: 48 | _LOGGER.warning("Disconnected from API") 49 | 50 | logic = ReconnectLogic( 51 | client=cli, 52 | on_connect=on_connect, 53 | on_disconnect=on_disconnect, 54 | zeroconf_instance=aio_zeroconf_instance, 55 | name=name, 56 | ) 57 | await logic.start() 58 | 59 | async def _stop() -> None: 60 | await logic.stop() 61 | await cli.disconnect() 62 | 63 | return _stop 64 | -------------------------------------------------------------------------------- /aioesphomeapi/model_conversions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from .api_pb2 import ( # type: ignore 6 | AlarmControlPanelStateResponse, 7 | BinarySensorStateResponse, 8 | ClimateStateResponse, 9 | CoverStateResponse, 10 | DateStateResponse, 11 | DateTimeStateResponse, 12 | EventResponse, 13 | FanStateResponse, 14 | LightStateResponse, 15 | ListEntitiesAlarmControlPanelResponse, 16 | ListEntitiesBinarySensorResponse, 17 | ListEntitiesButtonResponse, 18 | ListEntitiesCameraResponse, 19 | ListEntitiesClimateResponse, 20 | ListEntitiesCoverResponse, 21 | ListEntitiesDateResponse, 22 | ListEntitiesDateTimeResponse, 23 | ListEntitiesEventResponse, 24 | ListEntitiesFanResponse, 25 | ListEntitiesLightResponse, 26 | ListEntitiesLockResponse, 27 | ListEntitiesMediaPlayerResponse, 28 | ListEntitiesNumberResponse, 29 | ListEntitiesSelectResponse, 30 | ListEntitiesSensorResponse, 31 | ListEntitiesServicesResponse, 32 | ListEntitiesSirenResponse, 
33 | ListEntitiesSwitchResponse, 34 | ListEntitiesTextResponse, 35 | ListEntitiesTextSensorResponse, 36 | ListEntitiesTimeResponse, 37 | ListEntitiesUpdateResponse, 38 | ListEntitiesValveResponse, 39 | LockStateResponse, 40 | MediaPlayerStateResponse, 41 | NumberStateResponse, 42 | SelectStateResponse, 43 | SensorStateResponse, 44 | SirenStateResponse, 45 | SwitchStateResponse, 46 | TextSensorStateResponse, 47 | TextStateResponse, 48 | TimeStateResponse, 49 | UpdateStateResponse, 50 | ValveStateResponse, 51 | ) 52 | from .model import ( 53 | AlarmControlPanelEntityState, 54 | AlarmControlPanelInfo, 55 | BinarySensorInfo, 56 | BinarySensorState, 57 | ButtonInfo, 58 | CameraInfo, 59 | ClimateInfo, 60 | ClimateState, 61 | CoverInfo, 62 | CoverState, 63 | DateInfo, 64 | DateState, 65 | DateTimeInfo, 66 | DateTimeState, 67 | EntityInfo, 68 | EntityState, 69 | Event, 70 | EventInfo, 71 | FanInfo, 72 | FanState, 73 | LightInfo, 74 | LightState, 75 | LockEntityState, 76 | LockInfo, 77 | MediaPlayerEntityState, 78 | MediaPlayerInfo, 79 | NumberInfo, 80 | NumberState, 81 | SelectInfo, 82 | SelectState, 83 | SensorInfo, 84 | SensorState, 85 | SirenInfo, 86 | SirenState, 87 | SwitchInfo, 88 | SwitchState, 89 | TextInfo, 90 | TextSensorInfo, 91 | TextSensorState, 92 | TextState, 93 | TimeInfo, 94 | TimeState, 95 | UpdateInfo, 96 | UpdateState, 97 | ValveInfo, 98 | ValveState, 99 | ) 100 | 101 | SUBSCRIBE_STATES_RESPONSE_TYPES: dict[Any, type[EntityState]] = { 102 | AlarmControlPanelStateResponse: AlarmControlPanelEntityState, 103 | BinarySensorStateResponse: BinarySensorState, 104 | ClimateStateResponse: ClimateState, 105 | CoverStateResponse: CoverState, 106 | DateStateResponse: DateState, 107 | DateTimeStateResponse: DateTimeState, 108 | EventResponse: Event, 109 | FanStateResponse: FanState, 110 | LightStateResponse: LightState, 111 | LockStateResponse: LockEntityState, 112 | MediaPlayerStateResponse: MediaPlayerEntityState, 113 | NumberStateResponse: NumberState, 114 | 
SelectStateResponse: SelectState, 115 | SensorStateResponse: SensorState, 116 | SirenStateResponse: SirenState, 117 | SwitchStateResponse: SwitchState, 118 | TextSensorStateResponse: TextSensorState, 119 | TextStateResponse: TextState, 120 | TimeStateResponse: TimeState, 121 | UpdateStateResponse: UpdateState, 122 | ValveStateResponse: ValveState, 123 | } 124 | 125 | LIST_ENTITIES_SERVICES_RESPONSE_TYPES: dict[Any, type[EntityInfo] | None] = { 126 | ListEntitiesAlarmControlPanelResponse: AlarmControlPanelInfo, 127 | ListEntitiesBinarySensorResponse: BinarySensorInfo, 128 | ListEntitiesButtonResponse: ButtonInfo, 129 | ListEntitiesCameraResponse: CameraInfo, 130 | ListEntitiesClimateResponse: ClimateInfo, 131 | ListEntitiesCoverResponse: CoverInfo, 132 | ListEntitiesDateResponse: DateInfo, 133 | ListEntitiesDateTimeResponse: DateTimeInfo, 134 | ListEntitiesEventResponse: EventInfo, 135 | ListEntitiesFanResponse: FanInfo, 136 | ListEntitiesLightResponse: LightInfo, 137 | ListEntitiesLockResponse: LockInfo, 138 | ListEntitiesMediaPlayerResponse: MediaPlayerInfo, 139 | ListEntitiesNumberResponse: NumberInfo, 140 | ListEntitiesSelectResponse: SelectInfo, 141 | ListEntitiesSensorResponse: SensorInfo, 142 | ListEntitiesServicesResponse: None, 143 | ListEntitiesSirenResponse: SirenInfo, 144 | ListEntitiesSwitchResponse: SwitchInfo, 145 | ListEntitiesTextResponse: TextInfo, 146 | ListEntitiesTextSensorResponse: TextSensorInfo, 147 | ListEntitiesTimeResponse: TimeInfo, 148 | ListEntitiesUpdateResponse: UpdateInfo, 149 | ListEntitiesValveResponse: ValveInfo, 150 | } 151 | -------------------------------------------------------------------------------- /aioesphomeapi/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/esphome/aioesphomeapi/a74676a3176f3d34d7a3345367f24eba0f160b61/aioesphomeapi/py.typed -------------------------------------------------------------------------------- /aioesphomeapi/util.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from asyncio import AbstractEventLoop, Task, get_running_loop 4 | from collections.abc import Coroutine 5 | import math 6 | import sys 7 | from typing import Any, TypeVar 8 | 9 | _T = TypeVar("_T") 10 | 11 | 12 | if sys.version_info[:2] < (3, 11): 13 | from async_timeout import timeout as asyncio_timeout 14 | else: 15 | from asyncio import timeout as asyncio_timeout 16 | 17 | 18 | def fix_float_single_double_conversion(value: float) -> float: 19 | """Fix precision for single-precision floats and return what was probably 20 | meant as a float. 21 | 22 | In ESPHome we work with single-precision floats internally for performance. 23 | But python uses double-precision floats, and when protobuf reads the message 24 | it's auto-converted to a double (which is possible losslessly). 25 | 26 | Unfortunately the float representation of 0.1 converted to a double is not the 27 | double representation of 0.1, but 0.10000000149011612. 28 | 29 | This methods tries to round to the closest decimal value that a float of this 30 | magnitude can accurately represent. 31 | """ 32 | if value == 0 or not math.isfinite(value): 33 | return value 34 | abs_val = abs(value) 35 | # assume ~7 decimals of precision for floats to be safe 36 | l10 = math.ceil(math.log10(abs_val)) 37 | prec = 7 - l10 38 | return round(value, prec) 39 | 40 | 41 | def host_is_name_part(address: str) -> bool: 42 | """Return True if a host is the name part.""" 43 | return "." 
not in address and ":" not in address 44 | 45 | 46 | def address_is_local(address: str) -> bool: 47 | """Return True if the address is a local address.""" 48 | return address.removesuffix(".").endswith(".local") 49 | 50 | 51 | def build_log_name( 52 | name: str | None, addresses: list[str], connected_address: str | None 53 | ) -> str: 54 | """Return a log name for a connection.""" 55 | preferred_address = connected_address 56 | for address in addresses: 57 | if (not name and address_is_local(address)) or host_is_name_part(address): 58 | name = address.partition(".")[0] 59 | elif not preferred_address: 60 | preferred_address = address 61 | if not preferred_address: 62 | return name or addresses[0] 63 | if ( 64 | name 65 | and name != preferred_address 66 | and not preferred_address.startswith(f"{name}.") 67 | ): 68 | return f"{name} @ {preferred_address}" 69 | return preferred_address 70 | 71 | 72 | if sys.version_info >= (3, 12, 0): 73 | 74 | def create_eager_task( 75 | coro: Coroutine[Any, Any, _T], 76 | *, 77 | name: str | None = None, 78 | loop: AbstractEventLoop | None = None, 79 | ) -> Task[_T]: 80 | """Create a task from a coroutine and schedule it to run immediately.""" 81 | return Task( 82 | coro, 83 | loop=loop or get_running_loop(), 84 | name=name, 85 | eager_start=True, # type: ignore[call-arg] 86 | ) 87 | 88 | else: 89 | 90 | def create_eager_task( 91 | coro: Coroutine[Any, Any, _T], 92 | *, 93 | name: str | None = None, 94 | loop: AbstractEventLoop | None = None, 95 | ) -> Task[_T]: 96 | """Create a task from a coroutine.""" 97 | return Task(coro, loop=loop or get_running_loop(), name=name) 98 | 99 | 100 | __all__ = ( 101 | "address_is_local", 102 | "asyncio_timeout", 103 | "build_log_name", 104 | "create_eager_task", 105 | "fix_float_single_double_conversion", 106 | "host_is_name_part", 107 | ) 108 | -------------------------------------------------------------------------------- /aioesphomeapi/zeroconf.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import TYPE_CHECKING, Union 5 | 6 | from zeroconf import Zeroconf 7 | from zeroconf.asyncio import AsyncZeroconf 8 | 9 | ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf] 10 | 11 | _LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | class ZeroconfManager: 15 | """Manage the Zeroconf objects. 16 | 17 | This class is used to manage the Zeroconf objects. It is used to create 18 | the Zeroconf objects and to close them. It attempts to avoid creating 19 | a Zeroconf object unless one is actually needed. 20 | """ 21 | 22 | def __init__(self, zeroconf: ZeroconfInstanceType | None = None) -> None: 23 | """Initialize the ZeroconfManager.""" 24 | self._created = False 25 | self._aiozc: AsyncZeroconf | None = None 26 | if zeroconf is not None: 27 | self.set_instance(zeroconf) 28 | 29 | @property 30 | def has_instance(self) -> bool: 31 | """Return True if a Zeroconf instance is set.""" 32 | return self._aiozc is not None 33 | 34 | def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None: 35 | """Set the AsyncZeroconf instance.""" 36 | if self._aiozc: 37 | if isinstance(zc, AsyncZeroconf) and self._aiozc.zeroconf is zc.zeroconf: 38 | return 39 | if isinstance(zc, Zeroconf) and self._aiozc.zeroconf is zc: 40 | self._aiozc = AsyncZeroconf(zc=zc) 41 | return 42 | raise RuntimeError("Zeroconf instance already set to a different instance") 43 | self._aiozc = zc if isinstance(zc, AsyncZeroconf) else AsyncZeroconf(zc=zc) 44 | 45 | def _create_async_zeroconf(self) -> None: 46 | """Create an AsyncZeroconf instance.""" 47 | _LOGGER.debug("Creating new AsyncZeroconf instance") 48 | self._aiozc = AsyncZeroconf() 49 | self._created = True 50 | 51 | def get_async_zeroconf(self) -> AsyncZeroconf: 52 | """Get the AsyncZeroconf instance.""" 53 | if not self._aiozc: 54 | self._create_async_zeroconf() 55 | if TYPE_CHECKING: 56 | 
assert self._aiozc is not None 57 | return self._aiozc 58 | 59 | async def async_close(self) -> None: 60 | """Close the Zeroconf connection.""" 61 | if not self._created or not self._aiozc: 62 | return 63 | await self._aiozc.async_close() 64 | self._aiozc = None 65 | self._created = False 66 | -------------------------------------------------------------------------------- /bench/raw_ble_plain_text.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | 3 | from aioesphomeapi import APIConnection 4 | from aioesphomeapi._frame_helper import APIPlaintextFrameHelper 5 | from aioesphomeapi._frame_helper.packets import _cached_varuint_to_bytes 6 | from aioesphomeapi.api_pb2 import ( 7 | BluetoothLERawAdvertisement, 8 | BluetoothLERawAdvertisementsResponse, 9 | ) 10 | 11 | # cythonize -X language_level=3 -a -i aioesphomeapi/_frame_helper/plain_text.py 12 | # cythonize -X language_level=3 -a -i aioesphomeapi/_frame_helper/base.py 13 | # cythonize -X language_level=3 -a -i aioesphomeapi/connection.py 14 | 15 | adv = BluetoothLERawAdvertisementsResponse() 16 | fake_adv = BluetoothLERawAdvertisement( 17 | address=1, 18 | rssi=-86, 19 | address_type=2, 20 | data=( 21 | b"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75" 22 | b"657a2f686369302f64656c04010134000000e25389019500000001016f002500" 23 | b"00002f6f72672f626c75657a2f686369302f6465" 24 | ), 25 | ) 26 | for i in range(5): 27 | adv.advertisements.append(fake_adv) 28 | 29 | type_ = 93 30 | data = adv.SerializeToString() 31 | data = ( 32 | b"\0" + _cached_varuint_to_bytes(len(data)) + _cached_varuint_to_bytes(type_) + data 33 | ) 34 | 35 | 36 | class MockConnection(APIConnection): 37 | def __init__(self, *args, **kwargs): 38 | pass 39 | 40 | def process_packet(self, type_: int, data: bytes): 41 | pass 42 | 43 | def report_fatal_error(self, exc: Exception): 44 | raise exc 45 | 46 | 47 | connection = MockConnection() 48 | 49 | helper = 
APIPlaintextFrameHelper( 50 | connection=connection, client_info="my client", log_name="test" 51 | ) 52 | 53 | 54 | def process_incoming_msg(): 55 | helper.data_received(data) 56 | 57 | 58 | count = 3000000 59 | time = timeit.Timer(process_incoming_msg).timeit(count) 60 | print(f"Processed {count} bluetooth messages took {time} seconds") 61 | -------------------------------------------------------------------------------- /bench/raw_ble_plain_text_with_callback.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import timeit 3 | 4 | from aioesphomeapi import APIConnection 5 | from aioesphomeapi.api_pb2 import ( 6 | BluetoothLERawAdvertisement, 7 | BluetoothLERawAdvertisementsResponse, 8 | ) 9 | from aioesphomeapi.client import APIClient 10 | from aioesphomeapi.client_base import on_ble_raw_advertisement_response 11 | 12 | # cythonize -X language_level=3 -a -i aioesphomeapi/client_base.py 13 | # cythonize -X language_level=3 -a -i aioesphomeapi/connection.py 14 | 15 | 16 | class MockConnection(APIConnection): 17 | pass 18 | 19 | 20 | client = APIClient("fake.address", 6052, None) 21 | connection = MockConnection( 22 | client._params, lambda expected_disconnect: None, False, None 23 | ) 24 | 25 | 26 | def process_incoming_msg(): 27 | connection.process_packet( 28 | 93, 29 | b'\n\xb2\x01\x08\x01\x10\xab\x01\x18\x02"\xa8\x016c04010134000000' 30 | b"e25389019500000001016f00250000002f6f72672f626c75657a2f686369302" 31 | b"f64656c04010134000000e25389019500000001016f00250000002f6f72672f" 32 | b"626c75657a2f686369302f6465\n\xb2\x01\x08\x01\x10\xab\x01\x18\x02" 33 | b'"\xa8\x016c04010134000000e25389019500000001016f00250000002f6f726' 34 | b"72f626c75657a2f686369302f64656c04010134000000e253890195000000010" 35 | b"16f00250000002f6f72672f626c75657a2f686369302f6465\n\xb2\x01\x08" 36 | b'\x01\x10\xab\x01\x18\x02"\xa8\x016c04010134000000e25389019500000' 37 | 
b"001016f00250000002f6f72672f626c75657a2f686369302f64656c040101340" 38 | b"00000e25389019500000001016f00250000002f6f72672f626c75657a2f68636" 39 | b'9302f6465\n\xb2\x01\x08\x01\x10\xab\x01\x18\x02"\xa8\x016c040101' 40 | b"34000000e25389019500000001016f00250000002f6f72672f626c75657a2f68" 41 | b"6369302f64656c04010134000000e25389019500000001016f00250000002f6f" 42 | b"72672f626c75657a2f686369302f6465\n\xb2\x01\x08\x01\x10\xab\x01" 43 | b'\x18\x02"\xa8\x016c04010134000000e25389019500000001016f002500000' 44 | b"02f6f72672f626c75657a2f686369302f64656c04010134000000e2538901950" 45 | b"0000001016f00250000002f6f72672f626c75657a2f686369302f6465", 46 | ) 47 | 48 | 49 | def on_advertisements(msgs: list[BluetoothLERawAdvertisement]): 50 | pass 51 | 52 | 53 | connection.add_message_callback( 54 | partial(on_ble_raw_advertisement_response, on_advertisements), 55 | (BluetoothLERawAdvertisementsResponse,), 56 | ) 57 | 58 | count = 3000000 59 | time = timeit.Timer(process_incoming_msg).timeit(count) 60 | print(f"Processed {count} bluetooth messages took {time} seconds") 61 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.9 3 | show_error_codes = true 4 | 5 | strict = true 6 | warn_unreachable = true 7 | 8 | [mypy-async_timeout.*] 9 | ignore_missing_imports = True 10 | 11 | [mypy-noise.*] 12 | ignore_missing_imports = True 13 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | required-version = ">=0.5.0" 3 | exclude = [ 4 | "aioesphomeapi/api_pb2.py", 5 | "aioesphomeapi/api_options_pb2.py", 6 | ] 7 | 8 | [tool.ruff.lint] 9 | select = [ 10 | "ASYNC", # async rules 11 | "E", # pycodestyle 12 | "F", # pyflakes/autoflake 13 | "FLY", # flynt 14 | "FURB", # refurb 15 | "G", # 
flake8-logging-format 16 | "I", # isort 17 | "PERF", # Perflint 18 | "PIE", # flake8-pie 19 | "PL", # pylint 20 | "UP", # pyupgrade 21 | "RUF", # ruff 22 | "SIM", # flake8-SIM 23 | "SLOT", # flake8-slots 24 | "TID", # Tidy imports 25 | "TRY", # try rules 26 | "PERF", # performance 27 | ] 28 | 29 | ignore = [ 30 | "E501", # line too long 31 | "E721", # We want type() check for protobuf messages 32 | "PLR0911", # Too many return statements ({returns} > {max_returns}) 33 | "PLR0912", # Too many branches ({branches} > {max_branches}) 34 | "PLR0913", # Too many arguments to function call ({c_args} > {max_args}) 35 | "PLR0915", # Too many statements ({statements} > {max_statements}) 36 | "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable 37 | "PLW2901", # Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target 38 | "TRY003", # Too many to fix - Avoid specifying long messages outside the exception class 39 | "TID252", # Prefer absolute imports over relative imports from parent modules 40 | ] 41 | 42 | [tool.ruff.lint.isort] 43 | force-sort-within-sections = true 44 | known-first-party = [ 45 | "aioesphomeapi", "tests" 46 | ] 47 | combine-as-imports = true 48 | split-on-trailing-comma = false 49 | 50 | [build-system] 51 | requires = ['setuptools>=65.4.1', 'wheel', 'Cython>=3.0.2'] 52 | 53 | [tool.pytest.ini_options] 54 | asyncio_mode = "auto" 55 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs>=2.3.0 2 | async-interrupt>=1.2.0 3 | protobuf>=4 4 | zeroconf>=0.143.0,<1.0 5 | chacha20poly1305-reuseable>=0.13.2 6 | cryptography>=43.0.0 7 | noiseprotocol>=0.3.1,<1.0 8 | async-timeout>=4.0;python_version<'3.11' 9 | -------------------------------------------------------------------------------- /requirements/test.txt: 
-------------------------------------------------------------------------------- 1 | pylint==3.3.7 2 | ruff==0.11.13 3 | flake8==7.2.0 4 | isort==6.0.1 5 | mypy==1.16.0 6 | types-protobuf==6.30.2.20250516 7 | pytest>=6.2.4,<9 8 | pytest-asyncio==0.26.0 9 | pytest-codspeed==3.2.0 10 | pytest-cov>=4.1.0 11 | pytest-timeout==2.4.0 12 | -------------------------------------------------------------------------------- /script/gen-protoc: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from pathlib import Path 5 | from subprocess import check_call 6 | 7 | root_dir = Path(__file__).absolute().parent.parent 8 | os.chdir(root_dir) 9 | 10 | check_call( 11 | [ 12 | "protoc", 13 | "--python_out=aioesphomeapi", 14 | "-I", 15 | "aioesphomeapi", 16 | "aioesphomeapi/api.proto", 17 | "aioesphomeapi/api_options.proto", 18 | ] 19 | ) 20 | 21 | # https://github.com/protocolbuffers/protobuf/issues/1491 22 | api_file = root_dir / "aioesphomeapi" / "api_pb2.py" 23 | content = api_file.read_text().replace( 24 | "import api_options_pb2 as api__options__pb2", 25 | "from . import api_options_pb2 as api__options__pb2", 26 | ) 27 | api_file.write_text(content) 28 | 29 | for fname in ["api_pb2.py", "api_options_pb2.py"]: 30 | file = root_dir / "aioesphomeapi" / fname 31 | content = "# type: ignore\n" + file.read_text() 32 | file.write_text(content) 33 | -------------------------------------------------------------------------------- /script/lint: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$(dirname "$0")/.." 
4 | set -euxo pipefail 5 | 6 | black --safe aioesphomeapi tests 7 | pylint aioesphomeapi 8 | flake8 aioesphomeapi 9 | isort aioesphomeapi tests 10 | mypy aioesphomeapi 11 | pytest tests 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | # Following 4 for black compatibility 4 | # E501: line too long 5 | # W503: Line break occurred before a binary operator 6 | # E203: Whitespace before ':' 7 | # D202 No blank lines allowed after function docstring 8 | 9 | ignore = 10 | E501, 11 | W503, 12 | E203, 13 | D202, 14 | 15 | exclude = api_pb2.py, api_options_pb2.py 16 | 17 | [bdist_wheel] 18 | universal = 1 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """aioesphomeapi setup script.""" 3 | 4 | import contextlib 5 | from distutils.command.build_ext import build_ext 6 | import os 7 | from typing import Any 8 | 9 | from setuptools import find_packages, setup 10 | 11 | try: 12 | from setuptools import Extension 13 | except ImportError: 14 | from distutils.core import Extension 15 | 16 | TO_CYTHONIZE = [ 17 | "aioesphomeapi/client_base.py", 18 | "aioesphomeapi/connection.py", 19 | "aioesphomeapi/_frame_helper/base.py", 20 | "aioesphomeapi/_frame_helper/noise.py", 21 | "aioesphomeapi/_frame_helper/noise_encryption.py", 22 | "aioesphomeapi/_frame_helper/packets.py", 23 | "aioesphomeapi/_frame_helper/plain_text.py", 24 | "aioesphomeapi/_frame_helper/pack.pyx", 25 | ] 26 | 27 | EXTENSIONS = [ 28 | Extension( 29 | ext.removesuffix(".py").removesuffix(".pyx").replace("/", "."), 30 | [ext], 31 | language="c", 32 | extra_compile_args=["-O3", "-g0"], 33 | ) 34 | for ext in TO_CYTHONIZE 35 | ] 36 | 37 | 38 | here = os.path.abspath(os.path.dirname(__file__)) 39 
| 40 | with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file: 41 | long_description = readme_file.read() 42 | 43 | 44 | VERSION = "32.2.2" 45 | PROJECT_NAME = "aioesphomeapi" 46 | PROJECT_PACKAGE_NAME = "aioesphomeapi" 47 | PROJECT_LICENSE = "MIT" 48 | PROJECT_AUTHOR = "Otto Winter" 49 | PROJECT_COPYRIGHT = " 2019-2020, Otto Winter" 50 | PROJECT_URL = "https://esphome.io/" 51 | PROJECT_EMAIL = "esphome@nabucasa.com" 52 | 53 | PROJECT_GITHUB_USERNAME = "esphome" 54 | PROJECT_GITHUB_REPOSITORY = "aioesphomeapi" 55 | 56 | PYPI_URL = f"https://pypi.python.org/pypi/{PROJECT_PACKAGE_NAME}" 57 | GITHUB_PATH = f"{PROJECT_GITHUB_USERNAME}/{PROJECT_GITHUB_REPOSITORY}" 58 | GITHUB_URL = f"https://github.com/{GITHUB_PATH}" 59 | 60 | DOWNLOAD_URL = f"{GITHUB_URL}/archive/{VERSION}.zip" 61 | 62 | with open(os.path.join(here, "requirements/base.txt")) as requirements_txt: 63 | REQUIRES = requirements_txt.read().splitlines() 64 | 65 | pkgs = find_packages(exclude=["tests", "tests.*"]) 66 | 67 | setup_kwargs = { 68 | "name": PROJECT_PACKAGE_NAME, 69 | "version": VERSION, 70 | "url": PROJECT_URL, 71 | "download_url": DOWNLOAD_URL, 72 | "author": PROJECT_AUTHOR, 73 | "author_email": PROJECT_EMAIL, 74 | "description": "Python API for interacting with ESPHome devices.", 75 | "long_description": long_description, 76 | "license": PROJECT_LICENSE, 77 | "packages": pkgs, 78 | "exclude_package_data": {pkg: ["*.c"] for pkg in pkgs}, 79 | "include_package_data": True, 80 | "zip_safe": False, 81 | "install_requires": REQUIRES, 82 | "python_requires": ">=3.9", 83 | "test_suite": "tests", 84 | "entry_points": { 85 | "console_scripts": [ 86 | "aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point", 87 | "aioesphomeapi-discover=aioesphomeapi.discover:cli_entry_point", 88 | ], 89 | }, 90 | } 91 | 92 | 93 | class OptionalBuildExt(build_ext): 94 | def build_extensions(self) -> None: 95 | with contextlib.suppress(Exception): 96 | super().build_extensions() 97 | 98 | 99 | 
def cythonize_if_available(setup_kwargs: dict[str, Any]) -> None: 100 | if os.environ.get("SKIP_CYTHON"): 101 | return 102 | try: 103 | from Cython.Build import cythonize 104 | 105 | setup_kwargs.update( 106 | dict( 107 | ext_modules=cythonize( 108 | EXTENSIONS, 109 | compiler_directives={"language_level": "3"}, # Python 3 110 | ), 111 | cmdclass=dict(build_ext=OptionalBuildExt), 112 | ) 113 | ) 114 | except Exception: 115 | if os.environ.get("REQUIRE_CYTHON"): 116 | raise 117 | 118 | 119 | cythonize_if_available(setup_kwargs) 120 | 121 | setup(**setup_kwargs) 122 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Init tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | 7 | logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG) 8 | -------------------------------------------------------------------------------- /tests/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # Ensure debug logging is off for benchmarks 4 | logging.getLogger("aioesphomeapi").setLevel(logging.WARNING) 5 | -------------------------------------------------------------------------------- /tests/benchmarks/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(autouse=True) 7 | def no_debug_logging(): 8 | # Ensure debug logging is off for benchmarks (restored after each test) 9 | aioesphomeapi_logger = logging.getLogger("aioesphomeapi") 10 | original_level = aioesphomeapi_logger.level 11 | aioesphomeapi_logger.setLevel(logging.WARNING) 12 | yield 13 | aioesphomeapi_logger.setLevel(original_level) 14 | -------------------------------------------------------------------------------- /tests/benchmarks/test_bluetooth.py: 
-------------------------------------------------------------------------------- 1 | """Benchmarks.""" 2 | 3 | from functools import partial 4 | 5 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped] 6 | 7 | from aioesphomeapi import APIConnection 8 | from aioesphomeapi._frame_helper.packets import _cached_varuint_to_bytes 9 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper 10 | from aioesphomeapi.api_pb2 import ( 11 | BluetoothLERawAdvertisement, 12 | BluetoothLERawAdvertisementsResponse, 13 | ) 14 | from aioesphomeapi.client import APIClient 15 | 16 | 17 | async def test_raw_ble_plain_text_with_callback(benchmark: BenchmarkFixture) -> None: 18 | """Benchmark raw BLE plaintext with callback.""" 19 | 20 | class MockConnection(APIConnection): 21 | pass 22 | 23 | client = APIClient("fake.address", 6052, None) 24 | connection = MockConnection( 25 | client._params, lambda expected_disconnect: None, False, None 26 | ) 27 | 28 | process_incoming_msg = partial( 29 | connection.process_packet, 30 | 93, 31 | b'\n$\x08\xe3\x8a\x83\xad\x9c\xa3\x1d\x10\xbd\x01\x18\x01"\x15\x02\x01\x1a' 32 | b"\x02\n\x06\x0e\xffL\x00\x0f\x05\x90\x00\xb5B\x9c\x10\x02)\x04\n!" 
33 | b'\x08\x9e\x9a\xb1\xfc\x9e\x890\x10\xbf\x01"\x14\x02\x01\x06\x10\xff\xa9\x0b' 34 | b"\x01\x05\x00\x0b\x04\x18\n\x1cM\x8c\xefI\xc0\n.\x08\x9f\x89\x85\xe6" 35 | b'\xf3\xe8\x17\x10\x8d\x01\x18\x01"\x1f\x02\x01\x02\x14\xff\xa7' 36 | b"\x05\x06\x00\x12 %\x00\xca\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x02\n" 37 | b'\x0f\x03\x03\x07\xfe\n\x1f\x08\x9e\xbf\xb5\x87\x98\xce7\x10_\x18\x01"' 38 | b"\x11\x02\x01\x06\r\xffi\t\xdeq\x80\xed_\x9e\x0bC\x08 \n \x08\xc6\x8a\xa9" 39 | b'\xed\xb9\xc4>\x10\xab\x01\x18\x01"\x11\x02\x01\x06\x07\xff\t\x04\x8c\x01' 40 | b'a\x01\x05\tRZSS\n \x08\xd7\xc6\xe8\xe8\x91\xb85\x10\xa5\x01\x18\x01"' 41 | b"\x11\x02\x01\x06\x07\xff\t\x04\x8c\x01`\x01\x05\tRZSS\n-\x08\xca\xb0\x91" 42 | b'\xf4\xbc\xe6<\x10}\x18\x01"\x1f\x02\x01\x04\x03\x03\x07\xfe\x14\xff\xa7' 43 | b"\x05\x06\x00\x12 %\x00\xca\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x02\n" 44 | b'\x00\n)\x08\xf9\xdd\x95\xac\xb9\x95\r\x10\x87\x01"\x1c\x02\x01\x06\x03' 45 | b"\x03\x12\x18\x10\tLOOKin_98F330B4\x03\x19\xc1\x03", 46 | ) 47 | 48 | def on_advertisements(msgs: list[BluetoothLERawAdvertisement]): 49 | """Callback for advertisements.""" 50 | 51 | connection.add_message_callback( 52 | on_advertisements, 53 | (BluetoothLERawAdvertisementsResponse,), 54 | ) 55 | 56 | benchmark(process_incoming_msg) 57 | 58 | 59 | async def test_raw_ble_plain_text(benchmark: BenchmarkFixture) -> None: 60 | """Benchmark raw BLE plaintext.""" 61 | adv = BluetoothLERawAdvertisementsResponse() 62 | fake_adv = BluetoothLERawAdvertisement( 63 | address=1, 64 | rssi=-86, 65 | address_type=2, 66 | data=( 67 | b"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75" 68 | b"657a2f686369302f64656c04010134000000e25389019500000001016f002500" 69 | b"00002f6f72672f626c75657a2f686369302f6465" 70 | ), 71 | ) 72 | for i in range(5): 73 | adv.advertisements.append(fake_adv) 74 | 75 | type_ = 93 76 | data = adv.SerializeToString() 77 | data = ( 78 | b"\0" 79 | + _cached_varuint_to_bytes(len(data)) 80 | + 
_cached_varuint_to_bytes(type_) 81 | + data 82 | ) 83 | 84 | class MockConnection(APIConnection): 85 | def __init__(self, *args, **kwargs): 86 | """Initialize the connection.""" 87 | 88 | def process_packet(self, type_: int, data: bytes): 89 | """Process a packet.""" 90 | 91 | def report_fatal_error(self, exc: Exception): 92 | raise exc 93 | 94 | connection = MockConnection() 95 | 96 | helper = APIPlaintextFrameHelper( 97 | connection=connection, client_info="my client", log_name="test" 98 | ) 99 | 100 | process_incoming_msg = partial(helper.data_received, data) 101 | 102 | benchmark(process_incoming_msg) 103 | 104 | 105 | async def test_raw_ble_plain_text_different_advs(benchmark: BenchmarkFixture) -> None: 106 | """Benchmark raw BLE plaintext with different advertisements.""" 107 | data = ( 108 | b"\x01\x01\x07\x98\xaa7\xd7\xc5s\xe2\xdd\xc2\x96aG\xb1\xac:\xd3\xde" 109 | b"\x18\xefz\x00\xca@\xa9\xc8\xeb-\xe6`}\xa1\x00=\xae\x0e\xee\xc4Iy\xd6\x95" 110 | b"c\xed\x12S\xed\x14 \xa4\x9c&VcE\x0c=\xa8?\xaa\xe851\xdc=\xd6\xeeg\xffb" 111 | b"\x9a\xf5\xc9\xf6\x0b\r\xb9~\x11\xe3p$\xd9\xa9k\xcd\x1f\x03\x87f\xb8\x0c!\xac" 112 | b"\xb8:\xf5\x15jC@&\xf1\x13\xca\x89\x96r\xf9\xbd\xf1\xfe\xa0-\xfa\x87\x0cP" 113 | b"\xa7J+\xbaD,/\xf6\xc3\xf7\\\x1d\xcb#\xda@\xe0\n\xa7\xe0\xf0a\x16\xfb" 114 | b'\xb5\xfc\\\xbd1\xfb\xd25\x04\x94\x1e/"E\x90,J\xfd\x0f\xbc\xe5>\x96\xba' 115 | b"\x1bc\xa8\x1eQ\xbd|\xd9\xef\xc1\xffr\x04\x15i7\xea\x8clm`\xaa\x034" 116 | b"\x0b\xe5\xfe\x06\xfc\xb9\x9fc\xddE\xc93\xc0\x13\xe3\xe3$\xb1\xf2\x93" 117 | b"\xdb\x1dJ\xbf\x08edi.|\x93\x18\x7f\x83\x7fx\xbe\x01I\x1b\x8c\xe9\xf2\x06" 118 | b"\x8e\x08\xbe\xb0R&^7[\x1f4\x8f\xe0\xa1jf\xefL\x1b\x1el\xbb\x1c\x99" 119 | b"\x0f\x94r\xc2=\x10" 120 | ) 121 | 122 | type_ = 93 123 | data = ( 124 | b"\0" 125 | + _cached_varuint_to_bytes(len(data)) 126 | + _cached_varuint_to_bytes(type_) 127 | + data 128 | ) 129 | 130 | class MockConnection(APIConnection): 131 | def __init__(self, *args, **kwargs): 132 | """Initialize the 
connection.""" 133 | 134 | def process_packet(self, type_: int, data: bytes): 135 | """Process a packet.""" 136 | 137 | def report_fatal_error(self, exc: Exception): 138 | raise exc 139 | 140 | connection = MockConnection() 141 | 142 | helper = APIPlaintextFrameHelper( 143 | connection=connection, client_info="my client", log_name="test" 144 | ) 145 | 146 | process_incoming_msg = partial(helper.data_received, data) 147 | 148 | benchmark(process_incoming_msg) 149 | 150 | 151 | async def test_multiple_ble_adv_messages_single_read( 152 | benchmark: BenchmarkFixture, 153 | ) -> None: 154 | """Benchmark multiple raw ble advertisement messages in a single read.""" 155 | data = ( 156 | b"\x01\x01\x07\x98\xaa7\xd7\xc5s\xe2\xdd\xc2\x96aG\xb1\xac:\xd3\xde" 157 | b"\x18\xefz\x00\xca@\xa9\xc8\xeb-\xe6`}\xa1\x00=\xae\x0e\xee\xc4Iy\xd6\x95" 158 | b"c\xed\x12S\xed\x14 \xa4\x9c&VcE\x0c=\xa8?\xaa\xe851\xdc=\xd6\xeeg\xffb" 159 | b"\x9a\xf5\xc9\xf6\x0b\r\xb9~\x11\xe3p$\xd9\xa9k\xcd\x1f\x03\x87f\xb8\x0c!\xac" 160 | b"\xb8:\xf5\x15jC@&\xf1\x13\xca\x89\x96r\xf9\xbd\xf1\xfe\xa0-\xfa\x87\x0cP" 161 | b"\xa7J+\xbaD,/\xf6\xc3\xf7\\\x1d\xcb#\xda@\xe0\n\xa7\xe0\xf0a\x16\xfb" 162 | b'\xb5\xfc\\\xbd1\xfb\xd25\x04\x94\x1e/"E\x90,J\xfd\x0f\xbc\xe5>\x96\xba' 163 | b"\x1bc\xa8\x1eQ\xbd|\xd9\xef\xc1\xffr\x04\x15i7\xea\x8clm`\xaa\x034" 164 | b"\x0b\xe5\xfe\x06\xfc\xb9\x9fc\xddE\xc93\xc0\x13\xe3\xe3$\xb1\xf2\x93" 165 | b"\xdb\x1dJ\xbf\x08edi.|\x93\x18\x7f\x83\x7fx\xbe\x01I\x1b\x8c\xe9\xf2\x06" 166 | b"\x8e\x08\xbe\xb0R&^7[\x1f4\x8f\xe0\xa1jf\xefL\x1b\x1el\xbb\x1c\x99" 167 | b"\x0f\x94r\xc2=\x10" 168 | ) 169 | 170 | type_ = 93 171 | data = ( 172 | b"\0" 173 | + _cached_varuint_to_bytes(len(data)) 174 | + _cached_varuint_to_bytes(type_) 175 | + data 176 | ) 177 | 178 | class MockConnection(APIConnection): 179 | def __init__(self, *args, **kwargs): 180 | """Initialize the connection.""" 181 | 182 | def process_packet(self, type_: int, data: bytes): 183 | """Process a packet.""" 184 | 185 | def 
report_fatal_error(self, exc: Exception): 186 | raise exc 187 | 188 | connection = MockConnection() 189 | 190 | helper = APIPlaintextFrameHelper( 191 | connection=connection, client_info="my client", log_name="test" 192 | ) 193 | 194 | process_incoming_msg = partial(helper.data_received, data * 5) 195 | 196 | benchmark(process_incoming_msg) 197 | -------------------------------------------------------------------------------- /tests/benchmarks/test_noise.py: -------------------------------------------------------------------------------- 1 | """Benchmarks for noise.""" 2 | 3 | import asyncio 4 | import base64 5 | from collections.abc import Iterable 6 | 7 | import pytest 8 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped] 9 | 10 | from aioesphomeapi._frame_helper.noise_encryption import EncryptCipher 11 | 12 | from ..common import ( 13 | MockAPINoiseFrameHelper, 14 | _extract_encrypted_payload_from_handshake, 15 | _make_encrypted_packet, 16 | _make_mock_connection, 17 | _make_noise_handshake_pkt, 18 | _make_noise_hello_pkt, 19 | _mock_responder_proto, 20 | mock_data_received, 21 | ) 22 | 23 | 24 | @pytest.mark.parametrize("payload_size", [0, 64, 128, 1024, 16 * 1024]) 25 | async def test_noise_messages(benchmark: BenchmarkFixture, payload_size: int) -> None: 26 | """Benchmark raw noise protocol.""" 27 | noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" 28 | psk_bytes = base64.b64decode(noise_psk) 29 | writes = [] 30 | 31 | def _writelines(data: Iterable[bytes]): 32 | writes.append(b"".join(data)) 33 | 34 | connection, packets = _make_mock_connection() 35 | 36 | helper = MockAPINoiseFrameHelper( 37 | connection=connection, 38 | noise_psk=noise_psk, 39 | expected_name="servicetest", 40 | expected_mac=None, 41 | client_info="my client", 42 | log_name="test", 43 | writer=_writelines, 44 | ) 45 | 46 | proto = _mock_responder_proto(psk_bytes) 47 | 48 | await asyncio.sleep(0) # let the task run to read the hello packet 49 | 50 | assert 
len(writes) == 1 51 | handshake_pkt = writes.pop() 52 | 53 | encrypted_payload = _extract_encrypted_payload_from_handshake(handshake_pkt) 54 | decrypted = proto.read_message(encrypted_payload) 55 | assert decrypted == b"" 56 | 57 | hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") 58 | mock_data_received(helper, hello_pkt_with_header) 59 | 60 | handshake_with_header = _make_noise_handshake_pkt(proto) 61 | mock_data_received(helper, handshake_with_header) 62 | 63 | assert not writes 64 | 65 | await helper.ready_future 66 | helper.write_packets([(1, b"to device")], True) 67 | encrypted_packet = writes.pop() 68 | header = encrypted_packet[0:1] 69 | assert header == b"\x01" 70 | pkg_length_high = encrypted_packet[1] 71 | pkg_length_low = encrypted_packet[2] 72 | pkg_length = (pkg_length_high << 8) + pkg_length_low 73 | assert len(encrypted_packet) == 3 + pkg_length 74 | 75 | helper.write_packets([(1, b"to device")], True) 76 | 77 | def _empty_writelines(data: Iterable[bytes]): 78 | """Empty writelines.""" 79 | 80 | helper._writelines = _empty_writelines 81 | 82 | payload = b"x" * payload_size 83 | encrypt_cipher = EncryptCipher(proto.noise_protocol.cipher_state_encrypt) 84 | 85 | @benchmark 86 | def process_encrypted_packets(): 87 | for _ in range(100): 88 | helper.data_received(_make_encrypted_packet(encrypt_cipher, 42, payload)) 89 | 90 | helper.close() 91 | -------------------------------------------------------------------------------- /tests/benchmarks/test_requests.py: -------------------------------------------------------------------------------- 1 | """Benchmarks.""" 2 | 3 | import asyncio 4 | 5 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped] 6 | 7 | from aioesphomeapi import APIConnection 8 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper 9 | from aioesphomeapi.client import APIClient 10 | 11 | 12 | def test_sending_light_command_request_with_bool( 13 | benchmark: BenchmarkFixture, 14 
| api_client: tuple[ 15 | APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper 16 | ], 17 | ) -> None: 18 | client, connection, transport, protocol = api_client 19 | 20 | connection._frame_helper._writelines = lambda lines: None 21 | 22 | @benchmark 23 | def send_request(): 24 | client.light_command(1, True) 25 | 26 | 27 | def test_sending_empty_light_command_request( 28 | benchmark: BenchmarkFixture, 29 | api_client: tuple[ 30 | APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper 31 | ], 32 | ) -> None: 33 | client, connection, transport, protocol = api_client 34 | 35 | connection._frame_helper._writelines = lambda lines: None 36 | 37 | @benchmark 38 | def send_request(): 39 | client.light_command(1) 40 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from collections.abc import Awaitable 5 | from datetime import datetime, timezone 6 | from functools import partial 7 | import time 8 | from typing import Any, Callable 9 | from unittest.mock import AsyncMock, MagicMock, patch 10 | 11 | from google.protobuf import message 12 | from noise.connection import NoiseConnection # type: ignore[import-untyped] 13 | from zeroconf import Zeroconf 14 | from zeroconf.asyncio import AsyncZeroconf 15 | 16 | from aioesphomeapi import APIClient, APIConnection 17 | from aioesphomeapi._frame_helper.noise import APINoiseFrameHelper 18 | from aioesphomeapi._frame_helper.noise_encryption import ( 19 | ESPHOME_NOISE_BACKEND, 20 | EncryptCipher, 21 | ) 22 | from aioesphomeapi._frame_helper.packets import _cached_varuint_to_bytes 23 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper 24 | from aioesphomeapi.api_pb2 import ( 25 | ConnectResponse, 26 | HelloResponse, 27 | PingRequest, 28 | PingResponse, 29 | ) 30 | from aioesphomeapi.client 
import ConnectionParams 31 | from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, SocketClosedAPIError 32 | from aioesphomeapi.zeroconf import ZeroconfManager 33 | 34 | UTC = timezone.utc 35 | _MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution 36 | # We use a partial here since it is implemented in native code 37 | # and avoids the global lookup of UTC 38 | utcnow: partial[datetime] = partial(datetime.now, UTC) 39 | utcnow.__doc__ = "Get now in UTC time." 40 | 41 | PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} 42 | 43 | 44 | PREAMBLE = b"\x00" 45 | 46 | NOISE_HELLO = b"\x01\x00\x00" 47 | KEEP_ALIVE_INTERVAL = 15.0 48 | 49 | 50 | def get_mock_connection_params() -> ConnectionParams: 51 | return ConnectionParams( 52 | addresses=["fake.address"], 53 | port=6052, 54 | password=None, 55 | client_info="Tests client", 56 | keepalive=KEEP_ALIVE_INTERVAL, 57 | zeroconf_manager=ZeroconfManager(), 58 | noise_psk=None, 59 | expected_name=None, 60 | expected_mac=None, 61 | ) 62 | 63 | 64 | def mock_data_received( 65 | protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes 66 | ) -> None: 67 | """Mock data received on the protocol.""" 68 | try: 69 | protocol.data_received(data) 70 | except Exception as err: # pylint: disable=broad-except 71 | loop = asyncio.get_running_loop() 72 | loop.call_soon( 73 | protocol.connection_lost, 74 | err, 75 | ) 76 | 77 | 78 | def get_mock_zeroconf() -> MagicMock: 79 | with patch("zeroconf.Zeroconf.start"): 80 | zc = Zeroconf() 81 | zc.close = MagicMock() 82 | return zc 83 | 84 | 85 | def get_mock_async_zeroconf() -> AsyncZeroconf: 86 | aiozc = AsyncZeroconf(zc=get_mock_zeroconf()) 87 | aiozc.async_close = AsyncMock() 88 | return aiozc 89 | 90 | 91 | class Estr(str): 92 | """A subclassed string.""" 93 | 94 | __slots__ = () 95 | 96 | 97 | def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]: 98 | type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__] 99 | bytes_ = 
msg.SerializeToString() 100 | return [ 101 | b"\0", 102 | _cached_varuint_to_bytes(len(bytes_)), 103 | _cached_varuint_to_bytes(type_), 104 | bytes_, 105 | ] 106 | 107 | 108 | def generate_plaintext_packet(msg: message.Message) -> bytes: 109 | return b"".join(generate_split_plaintext_packet(msg)) 110 | 111 | 112 | def as_utc(dattim: datetime) -> datetime: 113 | """Return a datetime as UTC time.""" 114 | if dattim.tzinfo == UTC: 115 | return dattim 116 | return dattim.astimezone(UTC) 117 | 118 | 119 | def async_fire_time_changed( 120 | datetime_: datetime | None = None, fire_all: bool = False 121 | ) -> None: 122 | """Fire a time changed event at an exact microsecond. 123 | 124 | Consider that it is not possible to actually achieve an exact 125 | microsecond in production as the event loop is not precise enough. 126 | If your code relies on this level of precision, consider a different 127 | approach, as this is only for testing. 128 | """ 129 | loop = asyncio.get_running_loop() 130 | utc_datetime = datetime.now(UTC) if datetime_ is None else as_utc(datetime_) 131 | 132 | timestamp = utc_datetime.timestamp() 133 | for task in list(loop._scheduled): 134 | if not isinstance(task, asyncio.TimerHandle): 135 | continue 136 | if task.cancelled(): 137 | continue 138 | 139 | mock_seconds_into_future = timestamp - time.time() 140 | future_seconds = task.when() - (loop.time() + _MONOTONIC_RESOLUTION) 141 | 142 | if fire_all or mock_seconds_into_future >= future_seconds: 143 | task._run() 144 | task.cancel() 145 | 146 | 147 | async def connect(conn: APIConnection, login: bool = True): 148 | """Wrapper for connection logic to do both parts.""" 149 | await conn.start_resolve_host() 150 | await conn.start_connection() 151 | await conn.finish_connection(login=login) 152 | 153 | 154 | async def connect_client( 155 | client: APIClient, 156 | login: bool = True, 157 | on_stop: Callable[[bool], Awaitable[None]] | None = None, 158 | ) -> None: 159 | """Wrapper for connection logic to 
do both parts.""" 160 | await client.start_resolve_host(on_stop=on_stop) 161 | await client.start_connection() 162 | await client.finish_connection(login=login) 163 | 164 | 165 | def send_plaintext_hello( 166 | protocol: APIPlaintextFrameHelper, 167 | major: int | None = None, 168 | minor: int | None = None, 169 | ) -> None: 170 | hello_response: message.Message = HelloResponse() 171 | hello_response.api_version_major = 1 if major is None else major 172 | hello_response.api_version_minor = 9 if minor is None else minor 173 | hello_response.name = "fake" 174 | protocol.data_received(generate_plaintext_packet(hello_response)) 175 | 176 | 177 | def send_plaintext_connect_response( 178 | protocol: APIPlaintextFrameHelper, invalid_password: bool 179 | ) -> None: 180 | connect_response: message.Message = ConnectResponse() 181 | connect_response.invalid_password = invalid_password 182 | protocol.data_received(generate_plaintext_packet(connect_response)) 183 | 184 | 185 | def send_ping_response(protocol: APIPlaintextFrameHelper) -> None: 186 | ping_response: message.Message = PingResponse() 187 | protocol.data_received(generate_plaintext_packet(ping_response)) 188 | 189 | 190 | def send_ping_request(protocol: APIPlaintextFrameHelper) -> None: 191 | ping_request: message.Message = PingRequest() 192 | protocol.data_received(generate_plaintext_packet(ping_request)) 193 | 194 | 195 | def get_mock_protocol(conn: APIConnection): 196 | protocol = APIPlaintextFrameHelper( 197 | connection=conn, 198 | client_info="mock", 199 | log_name="mock_device", 200 | ) 201 | transport = MagicMock() 202 | protocol.connection_made(transport) 203 | return protocol 204 | 205 | 206 | def _create_mock_transport_protocol( 207 | transport: asyncio.Transport, 208 | connected: asyncio.Event, 209 | create_func: Callable[[], APIPlaintextFrameHelper], 210 | **kwargs, 211 | ) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]: 212 | protocol: APIPlaintextFrameHelper = create_func() 213 | 
protocol.connection_made(transport) 214 | connected.set() 215 | return transport, protocol 216 | 217 | 218 | def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes: 219 | noise_hello = handshake_pkt[0:3] 220 | pkt_header = handshake_pkt[3:6] 221 | assert noise_hello == NOISE_HELLO 222 | assert pkt_header[0] == 1 # type 223 | pkg_length_high = pkt_header[1] 224 | pkg_length_low = pkt_header[2] 225 | pkg_length = (pkg_length_high << 8) + pkg_length_low 226 | assert pkg_length == 49 227 | noise_prefix = handshake_pkt[6:7] 228 | assert noise_prefix == b"\x00" 229 | return handshake_pkt[7:] 230 | 231 | 232 | def _make_noise_hello_pkt(hello_pkt: bytes) -> bytes: 233 | """Make a noise hello packet.""" 234 | preamble = 1 235 | hello_pkg_length = len(hello_pkt) 236 | hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF 237 | hello_pkg_length_low = hello_pkg_length & 0xFF 238 | hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) 239 | return hello_header + hello_pkt 240 | 241 | 242 | def _make_noise_handshake_pkt(proto: NoiseConnection) -> bytes: 243 | handshake = proto.write_message(b"") 244 | handshake_pkt = b"\x00" + handshake 245 | preamble = 1 246 | handshake_pkg_length = len(handshake_pkt) 247 | handshake_pkg_length_high = (handshake_pkg_length >> 8) & 0xFF 248 | handshake_pkg_length_low = handshake_pkg_length & 0xFF 249 | handshake_header = bytes( 250 | (preamble, handshake_pkg_length_high, handshake_pkg_length_low) 251 | ) 252 | 253 | return handshake_header + handshake_pkt 254 | 255 | 256 | def _make_encrypted_packet( 257 | cipher: EncryptCipher, msg_type: int, payload: bytes 258 | ) -> bytes: 259 | msg_type = 42 260 | msg_type_high = (msg_type >> 8) & 0xFF 261 | msg_type_low = msg_type & 0xFF 262 | msg_length = len(payload) 263 | msg_length_high = (msg_length >> 8) & 0xFF 264 | msg_length_low = msg_length & 0xFF 265 | msg_header = bytes((msg_type_high, msg_type_low, msg_length_high, msg_length_low)) 266 | 
encrypted_payload = cipher.encrypt(msg_header + payload) 267 | return _make_encrypted_packet_from_encrypted_payload(encrypted_payload) 268 | 269 | 270 | def _make_encrypted_packet_from_encrypted_payload(encrypted_payload: bytes) -> bytes: 271 | preamble = 1 272 | encrypted_pkg_length = len(encrypted_payload) 273 | encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF 274 | encrypted_pkg_length_low = encrypted_pkg_length & 0xFF 275 | encrypted_header = bytes( 276 | (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low) 277 | ) 278 | return encrypted_header + encrypted_payload 279 | 280 | 281 | def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection: 282 | proto = NoiseConnection.from_name( 283 | b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND 284 | ) 285 | proto.set_as_responder() 286 | proto.set_psks(psk_bytes) 287 | proto.set_prologue(b"NoiseAPIInit\x00\x00") 288 | proto.start_handshake() 289 | return proto 290 | 291 | 292 | def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]: 293 | """Make a mock connection.""" 294 | packets: list[tuple[int, bytes]] = [] 295 | 296 | class MockConnection(APIConnection): 297 | def __init__(self, *args: Any, **kwargs: Any) -> None: 298 | """Swallow args.""" 299 | super().__init__( 300 | get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs 301 | ) 302 | 303 | def process_packet(self, type_: int, data: bytes): 304 | packets.append((type_, data)) 305 | 306 | connection = MockConnection() 307 | return connection, packets 308 | 309 | 310 | class MockAPINoiseFrameHelper(APINoiseFrameHelper): 311 | def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None: 312 | """Swallow args.""" 313 | super().__init__(*args, **kwargs) 314 | transport = MagicMock() 315 | transport.writelines = writer or MagicMock() 316 | self.__transport = transport 317 | self.connection_made(transport) 318 | 319 | def connection_made(self, transport: Any) -> 
None: 320 | return super().connection_made(self.__transport) 321 | 322 | def mock_write_frame(self, frame: bytes) -> None: 323 | """Write a packet to the socket. 324 | 325 | The entire packet must be written in a single call to write. 326 | """ 327 | frame_len = len(frame) 328 | header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) 329 | try: 330 | self._writelines([header, frame]) 331 | except (RuntimeError, ConnectionResetError, OSError) as err: 332 | raise SocketClosedAPIError( 333 | f"{self._log_name}: Error while writing data: {err}" 334 | ) from err 335 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Test fixtures.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | from collections.abc import Generator 7 | from contextlib import contextmanager 8 | from dataclasses import replace 9 | from functools import partial 10 | import reprlib 11 | import socket 12 | from unittest.mock import AsyncMock, MagicMock, create_autospec, patch 13 | 14 | import pytest 15 | import pytest_asyncio 16 | 17 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper 18 | from aioesphomeapi.client import APIClient, ConnectionParams 19 | from aioesphomeapi.connection import APIConnection 20 | from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr 21 | 22 | from .common import ( 23 | _create_mock_transport_protocol, 24 | connect, 25 | connect_client, 26 | get_mock_async_zeroconf, 27 | get_mock_connection_params, 28 | send_plaintext_hello, 29 | ) 30 | 31 | _MOCK_RESOLVE_RESULT = [ 32 | AddrInfo( 33 | family=socket.AF_INET, 34 | type=socket.SOCK_STREAM, 35 | proto=socket.IPPROTO_TCP, 36 | sockaddr=IPv4Sockaddr("10.0.0.512", 6052), 37 | ) 38 | ] 39 | 40 | 41 | class PatchableAPIConnection(APIConnection): 42 | """Patchable APIConnection for testing.""" 43 | 44 | 45 | class 
PatchableAPIClient(APIClient): 46 | """Patchable APIClient for testing.""" 47 | 48 | 49 | @pytest.fixture 50 | def async_zeroconf(): 51 | return get_mock_async_zeroconf() 52 | 53 | 54 | @pytest.fixture 55 | def resolve_host() -> Generator[AsyncMock]: 56 | with patch("aioesphomeapi.host_resolver.async_resolve_host") as func: 57 | func.return_value = _MOCK_RESOLVE_RESULT 58 | yield func 59 | 60 | 61 | @pytest.fixture 62 | async def patchable_api_client() -> APIClient: 63 | cli = PatchableAPIClient( 64 | address="127.0.0.1", 65 | port=6052, 66 | password=None, 67 | ) 68 | return cli 69 | 70 | 71 | @pytest.fixture 72 | async def auth_client(): 73 | client = PatchableAPIClient( 74 | address="fake.address", 75 | port=6052, 76 | password=None, 77 | ) 78 | mock_connection = PatchableAPIConnection( 79 | params=client._params, 80 | on_stop=client._on_stop, 81 | debug_enabled=False, 82 | log_name=client.log_name, 83 | ) 84 | mock_connection.is_connected = True 85 | with patch.object(client, "_connection", mock_connection): 86 | yield client 87 | 88 | 89 | @pytest.fixture 90 | def connection_params(event_loop: asyncio.AbstractEventLoop) -> ConnectionParams: 91 | return get_mock_connection_params() 92 | 93 | 94 | def mock_on_stop(expected_disconnect: bool) -> None: 95 | pass 96 | 97 | 98 | @pytest.fixture 99 | async def conn(connection_params: ConnectionParams) -> APIConnection: 100 | return PatchableAPIConnection(connection_params, mock_on_stop, True, None) 101 | 102 | 103 | @pytest.fixture 104 | async def conn_with_password(connection_params: ConnectionParams) -> APIConnection: 105 | connection_params = replace(connection_params, password="password") 106 | return PatchableAPIConnection(connection_params, mock_on_stop, True, None) 107 | 108 | 109 | @pytest.fixture 110 | async def noise_conn(connection_params: ConnectionParams) -> APIConnection: 111 | connection_params = replace( 112 | connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=" 113 | ) 114 | 
return PatchableAPIConnection(connection_params, mock_on_stop, True, None) 115 | 116 | 117 | @pytest.fixture 118 | async def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection: 119 | connection_params = replace(connection_params, expected_name="test") 120 | return PatchableAPIConnection(connection_params, mock_on_stop, True, None) 121 | 122 | 123 | @pytest.fixture() 124 | async def aiohappyeyeballs_start_connection(): 125 | with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: 126 | mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) 127 | mock_socket.type = socket.SOCK_STREAM 128 | mock_socket.fileno.return_value = 1 129 | mock_socket.getpeername.return_value = ("10.0.0.512", 323) 130 | func.return_value = mock_socket 131 | yield func 132 | 133 | 134 | @pytest_asyncio.fixture(name="plaintext_connect_task_no_login") 135 | async def plaintext_connect_task_no_login( 136 | conn: APIConnection, 137 | resolve_host, 138 | aiohappyeyeballs_start_connection, 139 | ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: 140 | loop = asyncio.get_running_loop() 141 | transport = MagicMock() 142 | connected = asyncio.Event() 143 | 144 | with patch.object( 145 | loop, 146 | "create_connection", 147 | side_effect=partial(_create_mock_transport_protocol, transport, connected), 148 | ): 149 | connect_task = asyncio.create_task(connect(conn, login=False)) 150 | await connected.wait() 151 | yield conn, transport, conn._frame_helper, connect_task 152 | conn.force_disconnect() 153 | 154 | 155 | @pytest_asyncio.fixture(name="plaintext_connect_task_expected_name") 156 | async def plaintext_connect_task_no_login_with_expected_name( 157 | conn_with_expected_name: APIConnection, 158 | resolve_host, 159 | aiohappyeyeballs_start_connection, 160 | ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: 161 | event_loop = asyncio.get_running_loop() 162 | 
transport = MagicMock() 163 | connected = asyncio.Event() 164 | 165 | with patch.object( 166 | event_loop, 167 | "create_connection", 168 | side_effect=partial(_create_mock_transport_protocol, transport, connected), 169 | ): 170 | connect_task = asyncio.create_task( 171 | connect(conn_with_expected_name, login=False) 172 | ) 173 | await connected.wait() 174 | yield ( 175 | conn_with_expected_name, 176 | transport, 177 | conn_with_expected_name._frame_helper, 178 | connect_task, 179 | ) 180 | conn_with_expected_name.force_disconnect() 181 | 182 | 183 | @pytest_asyncio.fixture(name="plaintext_connect_task_with_login") 184 | async def plaintext_connect_task_with_login( 185 | conn_with_password: APIConnection, 186 | resolve_host, 187 | aiohappyeyeballs_start_connection, 188 | ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: 189 | transport = MagicMock() 190 | connected = asyncio.Event() 191 | event_loop = asyncio.get_running_loop() 192 | 193 | with patch.object( 194 | event_loop, 195 | "create_connection", 196 | side_effect=partial(_create_mock_transport_protocol, transport, connected), 197 | ): 198 | connect_task = asyncio.create_task(connect(conn_with_password, login=True)) 199 | await connected.wait() 200 | yield ( 201 | conn_with_password, 202 | transport, 203 | conn_with_password._frame_helper, 204 | connect_task, 205 | ) 206 | conn_with_password.force_disconnect() 207 | 208 | 209 | @pytest_asyncio.fixture(name="api_client") 210 | async def api_client( 211 | resolve_host, aiohappyeyeballs_start_connection 212 | ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: 213 | event_loop = asyncio.get_running_loop() 214 | protocol: APIPlaintextFrameHelper | None = None 215 | transport = MagicMock() 216 | connected = asyncio.Event() 217 | client = APIClient( 218 | address="mydevice.local", 219 | port=6052, 220 | password=None, 221 | ) 222 | 223 | with ( 224 | patch.object( 225 | event_loop, 226 | 
"create_connection", 227 | side_effect=partial(_create_mock_transport_protocol, transport, connected), 228 | ), 229 | patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), 230 | ): 231 | connect_task = asyncio.create_task(connect_client(client, login=False)) 232 | await connected.wait() 233 | conn = client._connection 234 | protocol = conn._frame_helper 235 | send_plaintext_hello(protocol) 236 | await connect_task 237 | transport.reset_mock() 238 | yield client, conn, transport, protocol 239 | conn.force_disconnect() 240 | 241 | 242 | def get_scheduled_timer_handles( 243 | loop: asyncio.AbstractEventLoop, 244 | ) -> list[asyncio.TimerHandle]: 245 | """Return a list of scheduled TimerHandles.""" 246 | handles: list[asyncio.TimerHandle] = loop._scheduled # type: ignore[attr-defined] 247 | return handles 248 | 249 | 250 | @contextmanager 251 | def long_repr_strings() -> Generator[None]: 252 | """Increase reprlib maxstring and maxother to 300.""" 253 | arepr = reprlib.aRepr 254 | original_maxstring = arepr.maxstring 255 | original_maxother = arepr.maxother 256 | arepr.maxstring = 300 257 | arepr.maxother = 300 258 | try: 259 | yield 260 | finally: 261 | arepr.maxstring = original_maxstring 262 | arepr.maxother = original_maxother 263 | 264 | 265 | @pytest.fixture(autouse=True) 266 | def verify_no_lingering_tasks( 267 | event_loop: asyncio.AbstractEventLoop, 268 | ) -> Generator[None]: 269 | """Verify that all tasks are cleaned up.""" 270 | tasks_before = asyncio.all_tasks(event_loop) 271 | yield 272 | 273 | tasks = asyncio.all_tasks(event_loop) - tasks_before 274 | for task in tasks: 275 | pytest.fail(f"Task still running: {task!r}") 276 | task.cancel() 277 | if tasks: 278 | event_loop.run_until_complete(asyncio.wait(tasks)) 279 | 280 | for handle in get_scheduled_timer_handles(event_loop): 281 | if not handle.cancelled(): 282 | with long_repr_strings(): 283 | pytest.fail(f"Lingering timer after test {handle!r}") 284 | handle.cancel() 285 | 
-------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO 4 | 5 | 6 | def test_order_and_no_missing_numbers_in_message_type_to_proto(): 7 | """Test that MESSAGE_TYPE_TO_PROTO has no missing numbers.""" 8 | for idx, (k, v) in enumerate(MESSAGE_TYPE_TO_PROTO.items()): 9 | assert idx + 1 == k 10 | -------------------------------------------------------------------------------- /tests/test_log_runner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from datetime import timedelta 5 | from functools import partial 6 | from unittest.mock import MagicMock, patch 7 | 8 | from google.protobuf import message 9 | import pytest 10 | 11 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper 12 | from aioesphomeapi.api_pb2 import ( 13 | DisconnectRequest, 14 | DisconnectResponse, 15 | SubscribeLogsResponse, # type: ignore 16 | ) 17 | from aioesphomeapi.client import APIClient 18 | from aioesphomeapi.connection import APIConnection 19 | from aioesphomeapi.core import APIConnectionError 20 | from aioesphomeapi.log_runner import async_run 21 | from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN 22 | 23 | from .common import ( 24 | Estr, 25 | async_fire_time_changed, 26 | generate_plaintext_packet, 27 | get_mock_async_zeroconf, 28 | mock_data_received, 29 | send_plaintext_connect_response, 30 | send_plaintext_hello, 31 | utcnow, 32 | ) 33 | 34 | 35 | async def test_log_runner( 36 | conn: APIConnection, 37 | aiohappyeyeballs_start_connection, 38 | ): 39 | """Test the log runner logic.""" 40 | loop = asyncio.get_running_loop() 41 | protocol: APIPlaintextFrameHelper | None = None 42 | transport = MagicMock() 43 | connected = 
asyncio.Event() 44 | 45 | class PatchableAPIClient(APIClient): 46 | pass 47 | 48 | async_zeroconf = get_mock_async_zeroconf() 49 | 50 | cli = PatchableAPIClient( 51 | address=Estr("127.0.0.1"), 52 | port=6052, 53 | password=None, 54 | noise_psk=None, 55 | expected_name=Estr("fake"), 56 | zeroconf_instance=async_zeroconf.zeroconf, 57 | ) 58 | messages = [] 59 | 60 | def on_log(msg: SubscribeLogsResponse) -> None: 61 | messages.append(msg) 62 | 63 | def _create_mock_transport_protocol(create_func, **kwargs): 64 | nonlocal protocol 65 | protocol = create_func() 66 | protocol.connection_made(transport) 67 | connected.set() 68 | return transport, protocol 69 | 70 | subscribed = asyncio.Event() 71 | original_subscribe_logs = cli.subscribe_logs 72 | 73 | def _wait_subscribe_cli(*args, **kwargs): 74 | original_subscribe_logs(*args, **kwargs) 75 | subscribed.set() 76 | 77 | with ( 78 | patch.object( 79 | loop, "create_connection", side_effect=_create_mock_transport_protocol 80 | ), 81 | patch.object(cli, "subscribe_logs", _wait_subscribe_cli), 82 | ): 83 | stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) 84 | await connected.wait() 85 | protocol = cli._connection._frame_helper 86 | send_plaintext_hello(protocol) 87 | send_plaintext_connect_response(protocol, False) 88 | await subscribed.wait() 89 | 90 | response: message.Message = SubscribeLogsResponse() 91 | response.message = b"Hello world" 92 | mock_data_received(protocol, generate_plaintext_packet(response)) 93 | assert len(messages) == 1 94 | assert messages[0].message == b"Hello world" 95 | stop_task = asyncio.create_task(stop()) 96 | await asyncio.sleep(0) 97 | disconnect_response = DisconnectResponse() 98 | mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) 99 | await stop_task 100 | 101 | 102 | async def test_log_runner_reconnects_on_disconnect( 103 | conn: APIConnection, 104 | caplog: pytest.LogCaptureFixture, 105 | aiohappyeyeballs_start_connection, 106 | ) -> 
None: 107 | """Test the log runner reconnects on disconnect.""" 108 | loop = asyncio.get_running_loop() 109 | protocol: APIPlaintextFrameHelper | None = None 110 | transport = MagicMock() 111 | connected = asyncio.Event() 112 | 113 | class PatchableAPIClient(APIClient): 114 | pass 115 | 116 | async_zeroconf = get_mock_async_zeroconf() 117 | 118 | cli = PatchableAPIClient( 119 | address=Estr("127.0.0.1"), 120 | port=6052, 121 | password=None, 122 | noise_psk=None, 123 | expected_name=Estr("fake"), 124 | zeroconf_instance=async_zeroconf.zeroconf, 125 | ) 126 | messages = [] 127 | 128 | def on_log(msg: SubscribeLogsResponse) -> None: 129 | messages.append(msg) 130 | 131 | def _create_mock_transport_protocol(create_func, **kwargs): 132 | nonlocal protocol 133 | protocol = create_func() 134 | protocol.connection_made(transport) 135 | connected.set() 136 | return transport, protocol 137 | 138 | subscribed = asyncio.Event() 139 | original_subscribe_logs = cli.subscribe_logs 140 | 141 | def _wait_subscribe_cli(*args, **kwargs): 142 | original_subscribe_logs(*args, **kwargs) 143 | subscribed.set() 144 | 145 | with ( 146 | patch.object( 147 | loop, "create_connection", side_effect=_create_mock_transport_protocol 148 | ), 149 | patch.object(cli, "subscribe_logs", _wait_subscribe_cli), 150 | ): 151 | stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) 152 | await connected.wait() 153 | protocol = cli._connection._frame_helper 154 | send_plaintext_hello(protocol) 155 | send_plaintext_connect_response(protocol, False) 156 | await subscribed.wait() 157 | 158 | response: message.Message = SubscribeLogsResponse() 159 | response.message = b"Hello world" 160 | mock_data_received(protocol, generate_plaintext_packet(response)) 161 | assert len(messages) == 1 162 | assert messages[0].message == b"Hello world" 163 | 164 | with patch.object(cli, "start_resolve_host") as mock_start_resolve_host: 165 | response: message.Message = DisconnectRequest() 166 | 
mock_data_received(protocol, generate_plaintext_packet(response)) 167 | 168 | await asyncio.sleep(0) 169 | assert cli._connection is None 170 | async_fire_time_changed( 171 | utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN) 172 | ) 173 | await asyncio.sleep(0) 174 | 175 | assert "Disconnected from API" in caplog.text 176 | assert mock_start_resolve_host.called 177 | 178 | await stop() 179 | 180 | 181 | async def test_log_runner_reconnects_on_subscribe_failure( 182 | conn: APIConnection, 183 | caplog: pytest.LogCaptureFixture, 184 | aiohappyeyeballs_start_connection, 185 | ) -> None: 186 | """Test the log runner reconnects on subscribe failure.""" 187 | loop = asyncio.get_running_loop() 188 | protocol: APIPlaintextFrameHelper | None = None 189 | transport = MagicMock() 190 | connected = asyncio.Event() 191 | 192 | class PatchableAPIClient(APIClient): 193 | pass 194 | 195 | async_zeroconf = get_mock_async_zeroconf() 196 | 197 | cli = PatchableAPIClient( 198 | address=Estr("127.0.0.1"), 199 | port=6052, 200 | password=None, 201 | noise_psk=None, 202 | expected_name=Estr("fake"), 203 | zeroconf_instance=async_zeroconf.zeroconf, 204 | ) 205 | messages = [] 206 | 207 | def on_log(msg: SubscribeLogsResponse) -> None: 208 | messages.append(msg) 209 | 210 | def _create_mock_transport_protocol(create_func, **kwargs): 211 | nonlocal protocol 212 | protocol = create_func() 213 | protocol.connection_made(transport) 214 | connected.set() 215 | return transport, protocol 216 | 217 | subscribed = asyncio.Event() 218 | 219 | def _wait_and_fail_subscribe_cli(*args, **kwargs): 220 | subscribed.set() 221 | raise APIConnectionError("subscribed force to fail") 222 | 223 | with ( 224 | patch.object(cli, "disconnect", partial(cli.disconnect, force=True)), 225 | patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli), 226 | ): 227 | with patch.object( 228 | loop, "create_connection", side_effect=_create_mock_transport_protocol 229 | ): 230 | stop = await async_run(cli, 
on_log, aio_zeroconf_instance=async_zeroconf) 231 | await connected.wait() 232 | protocol = cli._connection._frame_helper 233 | send_plaintext_hello(protocol) 234 | send_plaintext_connect_response(protocol, False) 235 | 236 | await subscribed.wait() 237 | 238 | assert cli._connection is None 239 | 240 | with ( 241 | patch.object( 242 | loop, "create_connection", side_effect=_create_mock_transport_protocol 243 | ), 244 | patch.object(cli, "subscribe_logs"), 245 | ): 246 | connected.clear() 247 | await asyncio.sleep(0) 248 | async_fire_time_changed( 249 | utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN) 250 | ) 251 | await asyncio.sleep(0) 252 | 253 | stop_task = asyncio.create_task(stop()) 254 | await asyncio.sleep(0) 255 | 256 | send_plaintext_connect_response(protocol, False) 257 | send_plaintext_hello(protocol) 258 | 259 | disconnect_response = DisconnectResponse() 260 | mock_data_received(protocol, generate_plaintext_packet(disconnect_response)) 261 | 262 | await stop_task 263 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import math 3 | import sys 4 | 5 | import pytest 6 | 7 | from aioesphomeapi import util 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "input, output", 12 | [ 13 | (0, 0), 14 | (float("inf"), float("inf")), 15 | (float("-inf"), float("-inf")), 16 | (0.1, 0.1), 17 | (-0.0, -0.0), 18 | (0.10000000149011612, 0.1), 19 | (1, 1), 20 | (-1, -1), 21 | (-0.10000000149011612, -0.1), 22 | (-152198557936981706463557226105667584, -152198600000000000000000000000000000), 23 | (-0.0030539485160261, -0.003053949), 24 | (0.5, 0.5), 25 | (0.0000000000000019, 0.0000000000000019), 26 | ], 27 | ) 28 | def test_fix_float_single_double_conversion(input, output): 29 | assert util.fix_float_single_double_conversion(input) == output 30 | 31 | 32 | def test_fix_float_single_double_conversion_nan(): 33 | 
assert math.isnan(util.fix_float_single_double_conversion(float("nan"))) 34 | 35 | 36 | @pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+") 37 | async def test_create_eager_task_312() -> None: 38 | """Test create_eager_task schedules a task eagerly in the event loop. 39 | 40 | For Python 3.12+, the task is scheduled eagerly in the event loop. 41 | """ 42 | events = [] 43 | 44 | async def _normal_task(): 45 | events.append("normal") 46 | 47 | async def _eager_task(): 48 | events.append("eager") 49 | 50 | task1 = util.create_eager_task(_eager_task()) 51 | task2 = asyncio.create_task(_normal_task()) 52 | 53 | assert events == ["eager"] 54 | 55 | await asyncio.sleep(0) 56 | assert events == ["eager", "normal"] 57 | await task1 58 | await task2 59 | 60 | 61 | @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12") 62 | async def test_create_eager_task_pre_312() -> None: 63 | """Test create_eager_task schedules a task in the event loop. 64 | 65 | For older python versions, the task is scheduled normally. 
66 | """ 67 | events = [] 68 | 69 | async def _normal_task(): 70 | events.append("normal") 71 | 72 | async def _eager_task(): 73 | events.append("eager") 74 | 75 | task1 = util.create_eager_task(_eager_task()) 76 | task2 = asyncio.create_task(_normal_task()) 77 | 78 | assert events == [] 79 | 80 | await asyncio.sleep(0) 81 | assert events == ["eager", "normal"] 82 | await task1 83 | await task2 84 | -------------------------------------------------------------------------------- /tests/test_zeroconf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | from zeroconf.asyncio import AsyncZeroconf 7 | 8 | from aioesphomeapi.zeroconf import ZeroconfManager 9 | 10 | from .common import get_mock_async_zeroconf 11 | 12 | 13 | async def test_does_not_closed_passed_in_async_instance(async_zeroconf: AsyncZeroconf): 14 | """Test that the passed in instance is not closed.""" 15 | manager = ZeroconfManager() 16 | manager.set_instance(async_zeroconf) 17 | await manager.async_close() 18 | assert async_zeroconf.async_close.call_count == 0 19 | 20 | 21 | async def test_does_not_closed_passed_in_sync_instance(async_zeroconf: AsyncZeroconf): 22 | """Test that the passed in instance is not closed.""" 23 | manager = ZeroconfManager() 24 | manager.set_instance(async_zeroconf.zeroconf) 25 | await manager.async_close() 26 | assert async_zeroconf.async_close.call_count == 0 27 | 28 | 29 | async def test_closes_created_instance(async_zeroconf: AsyncZeroconf): 30 | """Test that the created instance is closed.""" 31 | with patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf): 32 | manager = ZeroconfManager() 33 | assert manager.get_async_zeroconf() is async_zeroconf 34 | await manager.async_close() 35 | assert async_zeroconf.async_close.call_count == 1 36 | 37 | 38 | async def test_runtime_error_multiple_instances(async_zeroconf: 
AsyncZeroconf): 39 | """Test runtime error is raised on multiple instances.""" 40 | manager = ZeroconfManager(async_zeroconf) 41 | new_instance = get_mock_async_zeroconf() 42 | with pytest.raises(RuntimeError): 43 | manager.set_instance(new_instance) 44 | manager.set_instance(async_zeroconf) 45 | manager.set_instance(async_zeroconf.zeroconf) 46 | manager.set_instance(async_zeroconf) 47 | await manager.async_close() 48 | assert async_zeroconf.async_close.call_count == 0 49 | --------------------------------------------------------------------------------