├── .coveragerc
├── .devcontainer
└── devcontainer.json
├── .dockerignore
├── .github
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
├── labeler.yml
├── release-drafter.yml
└── workflows
│ ├── ci.yml
│ ├── docker.yml
│ ├── labeler.yml
│ ├── matchers
│ ├── flake8.json
│ ├── isort.json
│ ├── mypy.json
│ ├── pylint.json
│ └── pytest.json
│ ├── release-drafter.yml
│ └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
└── tasks.json
├── Dockerfile
├── LICENSE
├── MAINTAINERS.md
├── MANIFEST.in
├── README.rst
├── aioesphomeapi
├── __init__.py
├── _frame_helper
│ ├── __init__.py
│ ├── base.pxd
│ ├── base.py
│ ├── noise.pxd
│ ├── noise.py
│ ├── noise_encryption.pxd
│ ├── noise_encryption.py
│ ├── pack.pxd
│ ├── pack.pyx
│ ├── packets.pxd
│ ├── packets.py
│ ├── plain_text.pxd
│ └── plain_text.py
├── api.proto
├── api_options.proto
├── api_options_pb2.py
├── api_pb2.py
├── ble_defs.py
├── client.py
├── client_base.pxd
├── client_base.py
├── connection.pxd
├── connection.py
├── core.py
├── discover.py
├── host_resolver.py
├── log_parser.py
├── log_reader.py
├── log_runner.py
├── model.py
├── model_conversions.py
├── py.typed
├── reconnect_logic.py
├── util.py
└── zeroconf.py
├── bench
├── raw_ble_plain_text.py
└── raw_ble_plain_text_with_callback.py
├── mypy.ini
├── pyproject.toml
├── requirements
├── base.txt
└── test.txt
├── script
├── gen-protoc
└── lint
├── setup.cfg
├── setup.py
└── tests
├── __init__.py
├── benchmarks
├── __init__.py
├── conftest.py
├── test_bluetooth.py
├── test_noise.py
└── test_requests.py
├── common.py
├── conftest.py
├── test__frame_helper.py
├── test_client.py
├── test_connection.py
├── test_core.py
├── test_host_resolver.py
├── test_log_parser.py
├── test_log_runner.py
├── test_model.py
├── test_reconnect_logic.py
├── test_util.py
└── test_zeroconf.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = aioesphomeapi
3 |
4 | omit =
5 | aioesphomeapi/api_options_pb2.py
6 | aioesphomeapi/api_pb2.py
7 | aioesphomeapi/log_reader.py
8 | aioesphomeapi/discover.py
9 | bench/*.py
10 |
11 | [report]
12 | # Regexes for lines to exclude from consideration
13 | exclude_lines =
14 | # Have to re-enable the standard pragma
15 | pragma: no cover
16 |
17 | # Don't complain about missing debug-only code:
18 | def __repr__
19 |
20 | # Don't complain if tests don't hit defensive assertion code:
21 | raise AssertionError
22 | raise NotImplementedError
23 | raise exceptions.NotSupportedError
24 |
25 | # TYPE_CHECKING and @overload blocks are never executed during pytest run
26 | if TYPE_CHECKING:
27 | @overload
28 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "ESPHome API Client Dev",
3 | "image": "ghcr.io/esphome/aioesphomeapi-proto-builder:latest",
4 | "features": {
5 | "ghcr.io/devcontainers/features/github-cli:1": {}
6 | },
7 | "customizations": {
8 | "vscode": {
9 | "extensions": [
10 | "ms-python.python",
11 | "visualstudioexptteam.vscodeintellicode",
12 | // yaml
13 | "redhat.vscode-yaml",
14 | // editorconfig
15 | "editorconfig.editorconfig",
16 | // protobuf
17 | "pbkit.vscode-pbkit"
18 | ],
19 | "settings": {
20 | "python.languageServer": "Pylance",
21 | "python.pythonPath": "/usr/bin/python3",
22 | "python.formatting.provider": "black",
23 | "editor.formatOnPaste": false,
24 | "editor.formatOnSave": true,
25 | "editor.formatOnType": true,
26 | "files.trimTrailingWhitespace": true,
27 | "terminal.integrated.defaultProfile.linux": "bash",
28 | "files.exclude": {
29 | "**/.git": true,
30 | "**/.DS_Store": true,
31 | "**/*.pyc": {
32 | "when": "$(basename).py"
33 | },
34 | "**/__pycache__": true
35 | },
36 | "files.associations": {
37 | "**/.vscode/*.json": "jsonc"
38 | }
39 | }
40 | }
41 | },
42 | "postCreateCommand": "pip3 install -e ."
43 | }
44 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Hide sublime text stuff
2 | *.sublime-project
3 | *.sublime-workspace
4 |
5 | # Hide some OS X stuff
6 | .DS_Store
7 | .AppleDouble
8 | .LSOverride
9 | Icon
10 |
11 | # Thumbnails
12 | ._*
13 |
14 | # IntelliJ IDEA
15 | .idea
16 | *.iml
17 |
18 | # pytest
19 | .pytest_cache
20 | .cache
21 |
22 | # GITHUB Proposed Python stuff:
23 | *.py[cod]
24 |
25 | # C extensions
26 | *.so
27 |
28 | # Packages
29 | *.egg
30 | *.egg-info
31 | dist
32 | build
33 | eggs
34 | .eggs
35 | parts
36 | bin
37 | var
38 | sdist
39 | develop-eggs
40 | .installed.cfg
41 | lib
42 | lib64
43 |
44 | # Logs
45 | *.log
46 | pip-log.txt
47 |
48 | # Unit test / coverage reports
49 | .coverage
50 | .tox
51 | nosetests.xml
52 | htmlcov/
53 |
54 | # Translations
55 | *.mo
56 |
57 | # Mr Developer
58 | .mr.developer.cfg
59 | .project
60 | .pydevproject
61 |
62 | .python-version
63 |
64 | # emacs auto backups
65 | *~
66 | *#
67 | *.orig
68 |
69 | # venv stuff
70 | pyvenv.cfg
71 | pip-selfcheck.json
72 | venv
73 | .venv
74 | Pipfile*
75 | share/*
76 |
77 | # vimmy stuff
78 | *.swp
79 | *.swo
80 |
81 | ctags.tmp
82 |
83 | # Visual Studio Code
84 | .vscode
85 |
86 | # Built docs
87 | docs/build
88 |
89 | # Windows Explorer
90 | desktop.ini
91 | /.vs/*
92 |
93 | # mypy
94 | /.mypy_cache/*
95 |
96 | .git*
97 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # What does this implement/fix?
2 |
3 |
4 |
5 | ## Types of changes
6 |
7 |
14 |
15 | - [ ] Bugfix (non-breaking change which fixes an issue)
16 | - [ ] New feature (non-breaking change which adds functionality)
17 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
18 | - [ ] Code quality improvements to existing code or addition of tests
19 | - [ ] Other
20 |
21 | **Related issue or feature (if applicable):**
22 |
23 | - fixes
24 |
25 | **Pull request in [esphome](https://github.com/esphome/esphome) (if applicable):**
26 |
27 | - esphome/esphome#
28 |
29 | ## Checklist:
30 | - [ ] The code change is tested and works locally.
31 | - [ ] If api.proto was modified, a linked pull request has been made to [esphome](https://github.com/esphome/esphome) with the same changes.
32 | - [ ] Tests have been added to verify that the new code works (under `tests/` folder).
33 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # dependabot config
2 | version: 2
3 | updates:
4 | - package-ecosystem: "pip"
5 | directory: "/"
6 | schedule:
7 | interval: "daily"
8 | - package-ecosystem: "github-actions"
9 | directory: "/"
10 | schedule:
11 | interval: "daily"
12 |
--------------------------------------------------------------------------------
/.github/labeler.yml:
--------------------------------------------------------------------------------
1 | dependencies:
2 | - changed-files:
3 | - any-glob-to-any-file:
4 | - '.pre-commit-config.yaml'
5 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name-template: "$RESOLVED_VERSION"
2 | tag-template: "v$RESOLVED_VERSION"
3 | categories:
4 | - title: "Breaking Changes"
5 | label: "breaking-change"
6 | - title: "Dependencies"
7 | collapse-after: 1
8 | labels:
9 | - "dependencies"
10 |
11 | version-resolver:
12 | major:
13 | labels:
14 | - "major"
15 | - "breaking-change"
16 | minor:
17 | labels:
18 | - "minor"
19 | - "new-feature"
20 | patch:
21 | labels:
22 | - "bugfix"
23 | - "dependencies"
24 | - "documentation"
25 | - "enhancement"
26 | default: patch
27 |
28 | template: |
29 | ## What's Changed
30 |
31 | $CHANGES
32 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 |
8 | permissions:
9 | contents: read
10 |
11 | concurrency:
12 | # yamllint disable-line rule:line-length
13 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
14 | cancel-in-progress: true
15 |
16 | jobs:
17 | ci:
18 | name: ${{ matrix.name }} py ${{ matrix.python-version }} on ${{ matrix.os }} (${{ matrix.extension }})
19 | runs-on: ${{ matrix.os }}
20 | strategy:
21 | fail-fast: false
22 | matrix:
23 | python-version:
24 | - "3.9"
25 | - "3.10"
26 | - "3.11"
27 | - "3.12"
28 | - "3.13"
29 | os:
30 | - ubuntu-latest
31 | - windows-latest
32 | - macos-latest
33 | extension:
34 | - "skip_cython"
35 | - "use_cython"
36 | exclude:
37 | - python-version: "3.9"
38 | os: windows-latest
39 | - python-version: "3.10"
40 | os: windows-latest
41 | - python-version: "3.11"
42 | os: windows-latest
43 | - python-version: "3.13"
44 | os: windows-latest
45 | - python-version: "3.9"
46 | os: macos-latest
47 | - python-version: "3.10"
48 | os: macos-latest
49 | - python-version: "3.11"
50 | os: macos-latest
51 | - python-version: "3.13"
52 | os: macos-latest
53 | - extension: "use_cython"
54 | os: windows-latest
55 | - extension: "use_cython"
56 | os: macos-latest
57 | steps:
58 | - uses: actions/checkout@v4
59 | - name: Set up Python
60 | uses: actions/setup-python@v5
61 | id: python
62 | with:
63 | python-version: ${{ matrix.python-version }}
64 | allow-prereleases: true
65 |
66 | - name: Get pip cache dir
67 | id: pip-cache
68 | shell: bash
69 | run: |
70 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
71 | - name: Restore PIP cache
72 | uses: actions/cache@v4
73 | with:
74 | path: ${{ steps.pip-cache.outputs.dir }}
75 | key: pip-${{ steps.python.outputs.python-version }}-${{ matrix.extension }}-${{ hashFiles('requirements/base.txt', 'requirements/test.txt') }}
76 | restore-keys: |
77 | pip-${{ steps.python.outputs.python-version }}-${{ matrix.extension }}-
78 | - name: Set up Python environment (no cython)
79 | if: ${{ matrix.extension == 'skip_cython' }}
80 | env:
81 | SKIP_CYTHON: 1
82 | shell: bash
83 | run: |
84 | pip3 install -r requirements/base.txt -r requirements/test.txt
85 | pip3 install -e .
86 | - name: Set up Python environment (cython)
87 | if: ${{ matrix.extension == 'use_cython' }}
88 | env:
89 | REQUIRE_CYTHON: 1
90 | shell: bash
91 | run: |
92 | pip3 install -r requirements/base.txt -r requirements/test.txt
93 | pip3 install -e .
94 | - name: Register problem matchers
95 | shell: bash
96 | run: |
97 | echo "::add-matcher::.github/workflows/matchers/flake8.json"
98 | echo "::add-matcher::.github/workflows/matchers/pylint.json"
99 | echo "::add-matcher::.github/workflows/matchers/mypy.json"
100 | echo "::add-matcher::.github/workflows/matchers/pytest.json"
101 |
102 | - run: flake8 aioesphomeapi
103 | name: Lint with flake8
104 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
105 | - run: ruff format --check aioesphomeapi tests
106 | name: Check ruff formatting
107 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
108 | - run: ruff check aioesphomeapi tests
109 | name: Check with ruff
110 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
111 | - run: mypy aioesphomeapi
112 | name: Check typing with mypy
113 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
114 | - run: |
115 | docker run \
116 | -v "$PWD":/aioesphomeapi \
117 | -u "$(id -u):$(id -g)" \
118 | ghcr.io/esphome/aioesphomeapi-proto-builder:latest
119 | if ! git diff --quiet; then
120 | echo "You have altered the generated proto files but they do not match what is expected."
121 | echo "Please run the following to update the generated files:"
122 | echo 'docker run -v "$PWD":/aioesphomeapi ghcr.io/esphome/aioesphomeapi-proto-builder:latest'
123 | exit 1
124 | fi
125 | name: Check protobuf files match
126 | if: ${{ matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
127 | - name: Show changes
128 | run: git diff
129 | if: ${{ failure() && matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
130 | - name: Archive artifacts
131 | uses: actions/upload-artifact@v4
132 | with:
133 | name: generated-proto-files
134 | path: aioesphomeapi/*pb2.py
135 | if: ${{ failure() && matrix.python-version == '3.12' && matrix.extension == 'skip_cython' && matrix.os == 'ubuntu-latest' }}
136 | - run: pytest -vv --cov=aioesphomeapi --cov-report=xml --timeout=4 --tb=native tests
137 | name: Run tests with pytest
138 | - name: Upload coverage to Codecov
139 | uses: codecov/codecov-action@v5
140 | env:
141 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
142 |
143 | benchmarks:
144 | name: Run benchmarks
145 | runs-on: ubuntu-22.04
146 | steps:
147 | - uses: actions/checkout@v4
148 | - uses: actions/setup-python@v5
149 | with:
150 | python-version: "3.13"
151 | cache: 'pip' # caching pip dependencies
152 | - name: Set up Python environment (cython)
153 | env:
154 | REQUIRE_CYTHON: 1
155 | shell: bash
156 | run: |
157 | pip3 install -r requirements/base.txt -r requirements/test.txt
158 | pip3 install -e .
159 | - name: Run benchmarks
160 | uses: CodSpeedHQ/action@v3
161 | with:
162 | token: ${{ secrets.CODSPEED_TOKEN }}
163 | run: pytest tests/benchmarks --codspeed
164 |
--------------------------------------------------------------------------------
/.github/workflows/docker.yml:
--------------------------------------------------------------------------------
1 | name: Build docker image on changes
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - Dockerfile
9 | - requirements/test.txt
10 | - .github/workflows/docker.yml
11 |
12 | permissions:
13 | contents: read
14 | packages: write
15 |
16 | concurrency:
17 | # yamllint disable-line rule:line-length
18 | group: ${{ github.workflow }}-${{ github.ref }}
19 | cancel-in-progress: true
20 |
21 | jobs:
22 | build-image:
23 | runs-on: ubuntu-latest
24 | name: Build and push Docker image
25 | steps:
26 | -
27 | name: Checkout source code
28 | uses: actions/checkout@v4
29 | -
30 | name: Log in to docker hub
31 | uses: docker/login-action@v3.4.0
32 | with:
33 | username: ${{ secrets.DOCKER_USER }}
34 | password: ${{ secrets.DOCKER_PASSWORD }}
35 | -
36 | name: Log in to the GitHub container registry
37 | uses: docker/login-action@v3.4.0
38 | with:
39 | registry: ghcr.io
40 | username: ${{ github.actor }}
41 | password: ${{ secrets.GITHUB_TOKEN }}
42 | -
43 | name: Set up QEMU
44 | uses: docker/setup-qemu-action@v3.4.0
45 | -
46 | name: Set up Docker Buildx
47 | uses: docker/setup-buildx-action@v3.10.0
48 | -
49 | name: Build and Push
50 | uses: docker/build-push-action@v6.18.0
51 | with:
52 | context: .
53 | tags: |
54 | ghcr.io/esphome/aioesphomeapi-proto-builder:latest
55 | esphome/aioesphomeapi-proto-builder:latest
56 | push: true
57 | pull: true
58 | cache-to: type=inline
59 | cache-from: ghcr.io/esphome/aioesphomeapi-proto-builder:latest
60 | platforms: linux/amd64,linux/arm64
61 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: Add labels
2 |
3 | on:
4 | - pull_request_target
5 |
6 | jobs:
7 | labeler:
8 | permissions:
9 | contents: read
10 | pull-requests: write
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/labeler@v5.0.0
14 |
--------------------------------------------------------------------------------
/.github/workflows/matchers/flake8.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "flake8-error",
5 | "severity": "error",
6 | "pattern": [
7 | {
8 | "regexp": "^(.*):(\\d+):(\\d+):\\s([EF]\\d{3}\\s.*)$",
9 | "file": 1,
10 | "line": 2,
11 | "column": 3,
12 | "message": 4
13 | }
14 | ]
15 | },
16 | {
17 | "owner": "flake8-warning",
18 | "severity": "warning",
19 | "pattern": [
20 | {
21 | "regexp": "^(.*):(\\d+):(\\d+):\\s([CDNW]\\d{3}\\s.*)$",
22 | "file": 1,
23 | "line": 2,
24 | "column": 3,
25 | "message": 4
26 | }
27 | ]
28 | }
29 | ]
30 | }
31 |
--------------------------------------------------------------------------------
/.github/workflows/matchers/isort.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "isort",
5 | "pattern": [
6 | {
7 | "regexp": "^ERROR:\\s+(.+)\\s+(.+)$",
8 | "file": 1,
9 | "message": 2
10 | }
11 | ]
12 | }
13 | ]
14 | }
15 |
--------------------------------------------------------------------------------
/.github/workflows/matchers/mypy.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "mypy",
5 | "pattern": [
6 | {
7 | "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
8 | "file": 1,
9 | "line": 2,
10 | "severity": 3,
11 | "message": 4
12 | }
13 | ]
14 | }
15 | ]
16 | }
17 |
--------------------------------------------------------------------------------
/.github/workflows/matchers/pylint.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "pylint-error",
5 | "severity": "error",
6 | "pattern": [
7 | {
8 | "regexp": "^(.+):(\\d+):(\\d+):\\s(([EF]\\d{4}):\\s.+)$",
9 | "file": 1,
10 | "line": 2,
11 | "column": 3,
12 | "message": 4,
13 | "code": 5
14 | }
15 | ]
16 | },
17 | {
18 | "owner": "pylint-warning",
19 | "severity": "warning",
20 | "pattern": [
21 | {
22 | "regexp": "^(.+):(\\d+):(\\d+):\\s(([CRW]\\d{4}):\\s.+)$",
23 | "file": 1,
24 | "line": 2,
25 | "column": 3,
26 | "message": 4,
27 | "code": 5
28 | }
29 | ]
30 | }
31 | ]
32 | }
33 |
--------------------------------------------------------------------------------
/.github/workflows/matchers/pytest.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "pytest",
5 | "fileLocation": "absolute",
6 | "pattern": [
7 | {
8 | "regexp": "^\\s+File \"(.*)\", line (\\d+), in (.*)$",
9 | "file": 1,
10 | "line": 2
11 | },
12 | {
13 | "regexp": "^\\s+(.*)$",
14 | "message": 1
15 | }
16 | ]
17 | }
18 | ]
19 | }
20 |
--------------------------------------------------------------------------------
/.github/workflows/release-drafter.yml:
--------------------------------------------------------------------------------
1 |
2 | name: Release Drafter
3 |
4 | on:
5 | push:
6 | branches:
7 | - main
8 |
9 | permissions:
10 | contents: write
11 | pull-requests: read
12 |
13 | jobs:
14 | update_release_draft:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: release-drafter/release-drafter@v6
18 | id: release-draft
19 | env:
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 |
22 | # Update version number in setup.py
23 | - uses: actions/checkout@v4
24 | - name: Calculate version
25 | id: version
26 | run: |
27 | tag="${{ steps.release-draft.outputs.tag_name }}"
28 | echo "version=${tag:1}" >> $GITHUB_OUTPUT
29 |
30 | - name: Set setup.py version
31 | run: |
32 | sed -i "s/VERSION = .*/VERSION = \"${{ steps.version.outputs.version }}\"/g" setup.py
33 | # github actions email from here: https://github.community/t/github-actions-bot-email-address/17204
34 | - name: Commit changes
35 | run: |
36 | if ! git diff --quiet; then
37 | git config --global user.name "github-actions[bot]"
38 | git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
39 | git commit -am "Bump version to ${{ steps.version.outputs.version }}"
40 | git push
41 | fi
42 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Publish Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | build_wheels:
12 | name: Wheels for ${{ matrix.os }} (${{ matrix.musl == 'musllinux' && 'musllinux' || 'manylinux' }}) ${{ matrix.qemu }} ${{ matrix.pyver }}
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | os:
17 | [
18 | ubuntu-24.04-arm,
19 | ubuntu-latest,
20 | macos-13,
21 | macos-latest,
22 | ]
23 | qemu: [""]
24 | musl: [""]
25 | pyver: [""]
26 | include:
27 | - os: ubuntu-latest
28 | musl: "musllinux"
29 | - os: ubuntu-24.04-arm
30 | musl: "musllinux"
31 | # qemu is slow, make a single
32 | # runner per Python version
33 | - os: ubuntu-latest
34 | qemu: armv7l
35 | musl: "musllinux"
36 | pyver: cp39
37 | - os: ubuntu-latest
38 | qemu: armv7l
39 | musl: "musllinux"
40 | pyver: cp310
41 | - os: ubuntu-latest
42 | qemu: armv7l
43 | musl: "musllinux"
44 | pyver: cp311
45 | - os: ubuntu-latest
46 | qemu: armv7l
47 | musl: "musllinux"
48 | pyver: cp312
49 | - os: ubuntu-latest
50 | qemu: armv7l
51 | musl: "musllinux"
52 | pyver: cp313
53 | # qemu is slow, make a single
54 | # runner per Python version
55 | - os: ubuntu-latest
56 | qemu: armv7l
57 | musl: ""
58 | pyver: cp39
59 | - os: ubuntu-latest
60 | qemu: armv7l
61 | musl: ""
62 | pyver: cp310
63 | - os: ubuntu-latest
64 | qemu: armv7l
65 | musl: ""
66 | pyver: cp311
67 | - os: ubuntu-latest
68 | qemu: armv7l
69 | musl: ""
70 | pyver: cp312
71 | - os: ubuntu-latest
72 | qemu: armv7l
73 | musl: ""
74 | pyver: cp313
75 | steps:
76 | - uses: actions/checkout@v4
77 | with:
78 | fetch-depth: 0
79 | # Used to host cibuildwheel
80 | - name: Set up Python
81 | uses: actions/setup-python@v5
82 | with:
83 | python-version: "3.12"
84 | - name: Set up QEMU
85 | if: ${{ matrix.qemu }}
86 | uses: docker/setup-qemu-action@v3
87 | with:
88 | platforms: all
89 | # This should be temporary
90 | # xref https://github.com/docker/setup-qemu-action/issues/188
91 | # xref https://github.com/tonistiigi/binfmt/issues/215
92 | image: tonistiigi/binfmt:qemu-v8.1.5
93 | id: qemu
94 | - name: Prepare emulation
95 | if: ${{ matrix.qemu }}
96 | run: |
97 | if [[ -n "${{ matrix.qemu }}" ]]; then
98 | # Build emulated architectures only if QEMU is set,
99 | # use default "auto" otherwise
100 | echo "CIBW_ARCHS_LINUX=${{ matrix.qemu }}" >> $GITHUB_ENV
101 | fi
102 | - name: Limit to a specific Python version on slow QEMU
103 | if: ${{ matrix.pyver }}
104 | run: |
105 | if [[ -n "${{ matrix.pyver }}" ]]; then
106 | echo "CIBW_BUILD=${{ matrix.pyver }}*" >> $GITHUB_ENV
107 | fi
108 | - name: Build wheels
109 | uses: pypa/cibuildwheel@v2.23.3
110 | env:
111 | CIBW_SKIP: cp36-* cp37-* cp38-* pp* ${{ matrix.musl == 'musllinux' && '*manylinux*' || '*musllinux*' }}
112 | CIBW_BEFORE_ALL_LINUX: apt-get install -y gcc || yum install -y gcc || apk add gcc
113 | REQUIRE_CYTHON: 1
114 |
115 | - uses: actions/upload-artifact@v4
116 | with:
117 | name: wheels-${{ matrix.os }}-${{ matrix.musl }}-${{ matrix.pyver }}-${{ matrix.qemu }}
118 | path: ./wheelhouse/*.whl
119 |
120 | build_sdist:
121 | name: Build source distribution
122 | runs-on: ubuntu-latest
123 | steps:
124 | - uses: actions/checkout@v4
125 |
126 | - name: Build sdist
127 | run: pipx run build --sdist
128 |
129 | - uses: actions/upload-artifact@v4
130 | with:
131 | name: sdist
132 | path: dist/*.tar.gz
133 |
134 | upload_pypi:
135 | needs: [build_wheels, build_sdist]
136 | runs-on: ubuntu-latest
137 | environment: pypi
138 | permissions:
139 | id-token: write
140 | if: github.event_name == 'release' && github.event.action == 'published'
141 | steps:
142 | - name: Download artifacts
143 | uses: actions/download-artifact@v4
144 | with:
145 | path: dist
146 | merge-multiple: true
147 |
148 | - name: Publish to PyPI
149 | uses: pypa/gh-action-pypi-publish@v1.12.4
150 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Hide sublime text stuff
2 | *.sublime-project
3 | *.sublime-workspace
4 |
5 | # Hide some OS X stuff
6 | .DS_Store
7 | .AppleDouble
8 | .LSOverride
9 | Icon
10 |
11 | # Thumbnails
12 | ._*
13 |
14 | # IntelliJ IDEA
15 | .idea
16 | *.iml
17 |
18 | # pytest
19 | .pytest_cache
20 | .cache
21 |
22 | # GITHUB Proposed Python stuff:
23 | *.py[cod]
24 |
25 | # C extensions
26 | *.so
27 |
28 | # Packages
29 | *.egg
30 | *.egg-info
31 | dist
32 | build
33 | eggs
34 | .eggs
35 | parts
36 | bin
37 | var
38 | sdist
39 | develop-eggs
40 | .installed.cfg
41 | lib
42 | lib64
43 |
44 | # Logs
45 | *.log
46 | pip-log.txt
47 |
48 | # Unit test / coverage reports
49 | .coverage
50 | .tox
51 | nosetests.xml
52 | htmlcov/
53 |
54 | # Translations
55 | *.mo
56 |
57 | # Mr Developer
58 | .mr.developer.cfg
59 | .project
60 | .pydevproject
61 |
62 | .python-version
63 |
64 | # emacs auto backups
65 | *~
66 | *#
67 | *.orig
68 |
69 | # venv stuff
70 | pyvenv.cfg
71 | pip-selfcheck.json
72 | venv
73 | .venv
74 | Pipfile*
75 | share/*
76 |
77 | # vimmy stuff
78 | *.swp
79 | *.swo
80 |
81 | ctags.tmp
82 |
83 | # Visual Studio Code
84 | .vscode/
85 | !.vscode/tasks.json
86 |
87 | # Built docs
88 | docs/build
89 |
90 | # Windows Explorer
91 | desktop.ini
92 | /.vs/*
93 |
94 | # mypy
95 | /.mypy_cache/*
96 |
97 | # cython generated source
98 | *.c
99 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | exclude: '^aioesphomeapi/api.*$'
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-added-large-files
10 | - repo: https://github.com/asottile/pyupgrade
11 | rev: v3.20.0
12 | hooks:
13 | - id: pyupgrade
14 | args: [--py39-plus]
15 | - repo: https://github.com/astral-sh/ruff-pre-commit
16 | rev: v0.11.13
17 | hooks:
18 | - id: ruff
19 | args: [--fix]
20 | - id: ruff-format
21 | - repo: https://github.com/cdce8p/python-typing-update
22 | rev: v0.7.2
23 | hooks:
24 | - id: python-typing-update
25 | stages: [manual]
26 | args:
27 | - --py39-plus
28 | - --force
29 | - --keep-updates
30 | files: ^(aioesphomeapi)/.+\.py$
31 | - repo: https://github.com/MarcoGorelli/cython-lint
32 | rev: v0.16.6
33 | hooks:
34 | - id: cython-lint
35 | - id: double-quote-cython-strings
36 | - repo: https://github.com/pre-commit/mirrors-mypy
37 | rev: v1.16.0
38 | hooks:
39 | - id: mypy
40 | additional_dependencies: ["aiohappyeyeballs>=2.3.0", "noiseprotocol>=0.3.1,<1.0", "cryptography>=43.0.0", "zeroconf>=0.143.0,<1.0"]
41 | files: ^((aioesphomeapi)/.+)?[^/]+\.(py)$
42 |
--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "2.0.0",
3 | "tasks": [
4 | {
5 | "label": "Generate proto files",
6 | "type": "shell",
7 | "command": "script/gen-protoc",
8 | "group": {
9 | "kind": "build",
10 | "isDefault": true
11 | },
12 | "presentation": {
13 | "reveal": "never",
14 | "close": true,
15 | "panel": "new"
16 | },
17 | "problemMatcher": []
18 | }
19 | ]
20 | }
21 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10
2 |
3 | ENV PROTOC_VERSION=25.5
4 |
5 | ARG TARGETARCH
6 |
7 | RUN if [ "$TARGETARCH" = "amd64" ]; then \
8 | arch="x86_64"; \
9 | elif [ "$TARGETARCH" = "arm64" ]; then \
10 | arch="aarch_64"; \
11 | else \
12 | exit 1; \
13 | fi \
14 | && curl -o /tmp/protoc.zip -L \
15 | https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${arch}.zip \
16 | && unzip /tmp/protoc.zip -d /usr -x readme.txt \
17 | && rm /tmp/protoc.zip
18 |
19 | RUN useradd -ms /bin/bash esphome
20 |
21 | USER esphome
22 |
23 | WORKDIR /aioesphomeapi
24 |
25 | COPY requirements/test.txt ./requirements_test.txt
26 |
27 | RUN pip3 install -r requirements_test.txt
28 |
29 | CMD ["script/gen-protoc"]
30 |
31 | LABEL \
32 | org.opencontainers.image.title="aioesphomeapi protobuf generator" \
33 | org.opencontainers.image.description="An image to help with ESPHome's aioesphomeapi protobuf generation" \
34 | org.opencontainers.image.vendor="ESPHome" \
35 | org.opencontainers.image.licenses="MIT" \
36 | org.opencontainers.image.url="https://esphome.io" \
37 | org.opencontainers.image.source="https://github.com/esphome/aioesphomeapi" \
38 | org.opencontainers.image.documentation="https://github.com/esphome/aioesphomeapi/blob/main/README.rst"
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Otto Winter
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MAINTAINERS.md:
--------------------------------------------------------------------------------
1 | # Maintaining Notes
2 |
3 | Releases are automatically drafted by [release-drafter](https://github.com/release-drafter/release-drafter). The next version number is automatically computed from the labels of the PRs included in the release.
4 |
5 | See also `.github/release-drafter.yml`. If any PR in the release carries one of the labels below, the version is bumped accordingly:
6 |
7 | - major release (+1.0.0): breaking-change, major
8 | - minor release (+0.1.0): minor, new-feature
9 | - patch (+0.0.1): this is the default release type
10 |
11 | Before creating a release, check that the latest commit passes continuous integration.
12 |
13 | When the release button on the draft is clicked, GitHub Actions will publish the release to PyPI.
14 |
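15 | On every push to the main branch and on every pull request, the CI workflow regenerates the Python protobuf files with the `ghcr.io/esphome/aioesphomeapi-proto-builder` Docker image and fails if they differ from the committed files. This ensures the committed files are always generated with the pinned protoc version, so a contributor who has a newer protoc installed than the protobuf Python package supports won't cause any issues. To regenerate the files locally, run `docker run -v "$PWD":/aioesphomeapi ghcr.io/esphome/aioesphomeapi-proto-builder:latest`.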
16 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include requirements/base.txt
3 | include aioesphomeapi/py.typed
4 | global-exclude *.c
5 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | aioesphomeapi
2 | =============
3 |
4 | .. image:: https://github.com/esphome/aioesphomeapi/workflows/CI/badge.svg
5 | :target: https://github.com/esphome/aioesphomeapi?query=workflow%3ACI+branch%3Amain
6 |
7 | .. image:: https://img.shields.io/pypi/v/aioesphomeapi.svg
8 | :target: https://pypi.python.org/pypi/aioesphomeapi
9 |
10 | .. image:: https://codecov.io/gh/esphome/aioesphomeapi/branch/main/graph/badge.svg
11 | :target: https://app.codecov.io/gh/esphome/aioesphomeapi/tree/main
12 |
13 | .. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json
14 | :target: https://codspeed.io/esphome/aioesphomeapi
15 |
16 | ``aioesphomeapi`` allows you to interact with devices flashed with `ESPHome <https://esphome.io/>`_.
17 |
18 | Installation
19 | ------------
20 |
21 | The module is available from the `Python Package Index <https://pypi.python.org/pypi/aioesphomeapi>`_.
22 |
23 | .. code:: bash
24 |
25 | $ pip3 install aioesphomeapi
26 |
27 | An optional Cython extension is available for better performance; the module will try to build it automatically.
28 |
29 | The extension requires a C compiler and Python development headers. The module will fall back to the pure Python implementation if they are unavailable.
30 |
31 | Building the extension can be forcefully disabled by setting the environment variable ``SKIP_CYTHON`` to ``1``.
32 |
33 | Usage
34 | -----
35 |
36 | You must enable the `Native API <https://esphome.io/components/api.html>`_ component on the device.
37 |
38 | .. code:: yaml
39 |
40 | # Example configuration entry
41 | api:
42 | password: 'MyPassword'
43 |
44 | Check the device's log output for its local address, or use the ``name:`` under ``esphome:`` from the device configuration.
45 |
46 | .. code:: bash
47 |
48 | [17:56:38][C][api:095]: API Server:
49 | [17:56:38][C][api:096]: Address: api_test.local:6053
50 |
51 |
52 | The sample code below will connect to the device and retrieve details.
53 |
54 | .. code:: python
55 |
56 | import aioesphomeapi
57 | import asyncio
58 |
59 | async def main():
60 | """Connect to an ESPHome device and get details."""
61 |
62 | # Establish connection
63 | api = aioesphomeapi.APIClient("api_test.local", 6053, "MyPassword")
64 | await api.connect(login=True)
65 |
66 | # Get API version of the device's firmware
67 | print(api.api_version)
68 |
69 | # Show device details
70 | device_info = await api.device_info()
71 | print(device_info)
72 |
73 | # List all entities of the device
74 | entities = await api.list_entities_services()
75 | print(entities)
76 |
77 | loop = asyncio.get_event_loop()
78 | loop.run_until_complete(main())
79 |
80 | Subscribe to state changes of an ESPHome device.
81 |
82 | .. code:: python
83 |
84 | import aioesphomeapi
85 | import asyncio
86 |
87 | async def main():
88 | """Connect to an ESPHome device and wait for state changes."""
89 | cli = aioesphomeapi.APIClient("api_test.local", 6053, "MyPassword")
90 |
91 | await cli.connect(login=True)
92 |
93 | def change_callback(state):
94 | """Print the state changes of the device."""
95 | print(state)
96 |
97 | # Subscribe to the state changes
98 | cli.subscribe_states(change_callback)
99 |
100 | loop = asyncio.get_event_loop()
101 | try:
102 | asyncio.ensure_future(main())
103 | loop.run_forever()
104 | except KeyboardInterrupt:
105 | pass
106 | finally:
107 | loop.close()
108 |
109 | Other examples:
110 |
111 | - `Camera `_
112 | - `Async print `_
113 | - `Simple print `_
114 | - `InfluxDB `_
115 |
116 | Development
117 | -----------
118 |
119 | For development, it is recommended to use a Python virtual environment (``venv``).
120 |
121 | .. code:: bash
122 |
123 | # Setup virtualenv (optional)
124 | $ python3 -m venv .
125 | $ source bin/activate
126 | # Install aioesphomeapi and development dependencies
127 | $ pip3 install -e .
128 | $ pip3 install -r requirements/test.txt
129 |
130 | # Run linters & test
131 | $ script/lint
132 | # Update protobuf _pb2.py definitions (requires a protobuf compiler installation)
133 | $ script/gen-protoc
134 |
135 | A CLI tool is also available for watching logs:
136 |
137 | .. code:: bash
138 |
139 | aioesphomeapi-logs --help
140 |
141 | A CLI tool is also available to discover devices:
142 |
143 | .. code:: bash
144 |
145 | aioesphomeapi-discover --help
146 |
147 |
148 | License
149 | -------
150 |
151 | ``aioesphomeapi`` is licensed under MIT, for more details check LICENSE.
152 |
--------------------------------------------------------------------------------
/aioesphomeapi/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .api_pb2 import ( # type: ignore[attr-defined] # noqa: F401
3 | BluetoothLERawAdvertisement,
4 | BluetoothLERawAdvertisementsResponse,
5 | )
6 | from .ble_defs import ESP_CONNECTION_ERROR_DESCRIPTION, BLEConnectionError
7 | from .client import APIClient
8 | from .connection import APIConnection, ConnectionParams
9 | from .core import (
10 | ESPHOME_GATT_ERRORS,
11 | MESSAGE_TYPE_TO_PROTO,
12 | APIConnectionError,
13 | EncryptionPlaintextAPIError,
14 | BadNameAPIError,
15 | BluetoothConnectionDroppedError,
16 | HandshakeAPIError,
17 | InvalidAuthAPIError,
18 | InvalidEncryptionKeyAPIError,
19 | ProtocolAPIError,
20 | RequiresEncryptionAPIError,
21 | ResolveAPIError,
22 | EncryptionHelloAPIError,
23 | SocketAPIError,
24 | BadMACAddressAPIError,
25 | )
26 | from .model import *
27 | from .reconnect_logic import ReconnectLogic
28 | from .log_parser import parse_log_message, LogParser
29 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/base.pxd:
--------------------------------------------------------------------------------
1 |
2 | import cython
3 |
4 | from ..connection cimport APIConnection
5 |
6 |
# NOTE(review): presumably declared as a C boolean so the
# ``if TYPE_CHECKING:`` blocks in base.py compile to a cheap C test; confirm.
cdef bint TYPE_CHECKING

# C-level declaration of the pure-Python APIFrameHelper in base.py; the
# attribute list mirrors that class's __slots__.
cdef class APIFrameHelper:

    cdef object _loop
    cdef APIConnection _connection
    cdef object _transport
    cdef public object _writelines
    cdef public object ready_future
    # Receive buffer, number of buffered bytes, and read cursor for _read().
    cdef bytes _buffer
    cdef unsigned int _buffer_len
    cdef unsigned int _pos
    cdef object _client_info
    cdef str _log_name

    cpdef set_log_name(self, str log_name)

    # cstr is a raw C pointer to the buffer; base.py slices it with explicit
    # bounds so Cython does not stop at embedded NUL bytes.
    @cython.locals(
        original_pos="unsigned int",
        new_pos="unsigned int",
        cstr="const unsigned char *"
    )
    cdef bytes _read(self, int length)

    @cython.locals(bytes_data=bytes)
    cdef void _add_to_buffer(self, object data) except *

    @cython.locals(end_of_frame_pos="unsigned int", cstr="const unsigned char *")
    cdef void _remove_from_buffer(self) except *

    cpdef void write_packets(self, list packets, bint debug_enabled) except *

    cdef void _write_bytes(self, object data, bint debug_enabled) except *

    cdef void _handle_error_and_close(self, Exception exc) except *
42 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import abstractmethod
4 | import asyncio
5 | from collections.abc import Iterable
6 | import logging
7 | from typing import TYPE_CHECKING, Callable, cast
8 |
9 | from ..core import SocketClosedAPIError
10 |
11 | if TYPE_CHECKING:
12 | from ..connection import APIConnection
13 |
_LOGGER = logging.getLogger(__name__)

# Exception types grouped together as socket-level failures.
# NOTE(review): ConnectionResetError and TimeoutError are already OSError
# subclasses; they are listed explicitly, presumably for readability.
SOCKET_ERRORS = (
    ConnectionResetError, asyncio.IncompleteReadError, OSError, TimeoutError
)

# Aliases of the builtin types, used in annotations in this module.
_int, _bytes = int, bytes
26 |
27 |
class APIFrameHelper:
    """Helper class to handle the API frame protocol.

    Base class for the concrete frame helpers (for example the noise
    helper). It owns the receive buffer and its read cursor, the transport
    and its ``writelines`` method, the ``ready_future`` that subclasses
    resolve once the helper can be used, and the shared error/close
    handling. Subclasses implement ``write_packets`` and the parsing of
    received data.
    """

    __slots__ = (
        "_buffer",
        "_buffer_len",
        "_client_info",
        "_connection",
        "_log_name",
        "_loop",
        "_pos",
        "_transport",
        "_writelines",
        "ready_future",
    )

    def __init__(
        self,
        connection: APIConnection,
        client_info: str,
        log_name: str,
    ) -> None:
        """Initialize the API frame helper.

        Must be called while an event loop is running, because
        ``ready_future`` is created on ``asyncio.get_running_loop()``.

        :param connection: connection notified via ``report_fatal_error``.
        :param client_info: client description, used in error messages.
        :param log_name: prefix for log and error messages.
        """
        loop = asyncio.get_running_loop()
        self._loop = loop
        self._connection = connection
        self._transport: asyncio.Transport | None = None
        self._writelines: (
            None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
        ) = None
        self.ready_future = self._loop.create_future()
        # Receive buffer; reset to None whenever it is fully consumed.
        self._buffer: bytes | None = None
        self._buffer_len = 0
        # Read cursor into _buffer, advanced by _read().
        self._pos = 0
        self._client_info = client_info
        self._log_name = log_name

    def set_log_name(self, log_name: str) -> None:
        """Set the log name."""
        self._log_name = log_name

    def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
        """Fail ``ready_future`` with ``exc`` unless it is already done."""
        if not self.ready_future.done():
            self.ready_future.set_exception(exc)

    def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
        """Add data to the buffer."""
        # Protractor sends a bytearray, so we need to convert it to bytes
        # https://github.com/esphome/issues/issues/5117
        # type(data) should not be isinstance(data, bytes) because we want to
        # to explicitly check for bytes and not for subclasses of bytes
        bytes_data = bytes(data) if type(data) is not bytes else data
        if self._buffer_len == 0:
            # This is the best case scenario, we don't have to copy the data
            # and can just use the buffer directly. This is the most common
            # case as well.
            self._buffer = bytes_data
        else:
            if TYPE_CHECKING:
                assert self._buffer is not None, "Buffer should be set"
            # This is the worst case scenario, we have to copy the bytes_data
            # and can't just use the buffer directly. This is also very
            # uncommon since we usually read the entire frame at once.
            self._buffer += bytes_data
        self._buffer_len += len(bytes_data)

    def _remove_from_buffer(self) -> None:
        """Remove data from the buffer.

        Drops everything before the read cursor ``self._pos`` (the frame
        that was just consumed). The cursor itself is not reset here.
        """
        end_of_frame_pos = self._pos
        self._buffer_len -= end_of_frame_pos
        if self._buffer_len == 0:
            # This is the best case scenario, we can just set the buffer to None
            # and don't have to copy the data. This is the most common case as well.
            self._buffer = None
            return
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        # This is the worst case scenario, we have to copy the data
        # and can't just use the buffer directly. This should only happen
        # when we read multiple frames at once because the event loop
        # is blocked and we cannot pull the data out of the buffer fast enough.
        cstr = self._buffer
        # Important: we must use the explicit length for the slice
        # since Cython will stop at any '\0' character if we don't
        self._buffer = cstr[end_of_frame_pos : self._buffer_len + end_of_frame_pos]

    def _read(self, length: _int) -> bytes | None:
        """Read exactly length bytes from the buffer or None if all the bytes are not yet available.

        The read cursor only advances when the full ``length`` bytes are
        available.
        """
        new_pos = self._pos + length
        if self._buffer_len < new_pos:
            return None
        original_pos = self._pos
        self._pos = new_pos
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        cstr = self._buffer
        # Important: we must keep the bounds check (self._buffer_len < new_pos)
        # above to verify we never try to read past the end of the buffer
        return cstr[original_pos:new_pos]

    # NOTE(review): the class does not use ABCMeta, so @abstractmethod is
    # not enforced when instantiating; it documents the subclass contract.
    @abstractmethod
    def write_packets(
        self, packets: list[tuple[int, bytes]], debug_enabled: bool
    ) -> None:
        """Write a packets to the socket.

        Packets are in the format of tuple[protobuf_type, protobuf_data]
        """

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Handle a new connection by storing the transport and its writelines."""
        self._transport = cast(asyncio.Transport, transport)
        self._writelines = self._transport.writelines

    def _handle_error_and_close(self, exc: Exception) -> None:
        """Handle an error and close the connection.

        May not be overridden by subclasses.
        """
        self._handle_error(exc)
        self.close()

    def _handle_error(self, exc: Exception) -> None:
        """Handle an error.

        Fails ``ready_future`` (if still pending) and reports ``exc`` to the
        connection as fatal.

        May be overridden by subclasses.
        """
        self._set_ready_future_exception(exc)
        self._connection.report_fatal_error(exc)

    def connection_lost(self, exc: Exception | None) -> None:
        """Handle the connection being lost."""
        self._handle_error(
            exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
        )

    def eof_received(self) -> bool | None:
        """Handle EOF received.

        Returning False tells asyncio to close the transport.
        """
        self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
        return False

    def close(self) -> None:
        """Close the transport (if any) and drop the references to it."""
        if self._transport:
            self._transport.close()
            self._transport = None
        self._writelines = None

    def pause_writing(self) -> None:
        """Stub."""

    def resume_writing(self) -> None:
        """Stub."""

    def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
        """Write bytes to the socket.

        :param data: chunks passed to ``transport.writelines``.
        :param debug_enabled: when true, the joined chunks are also logged
            as hex before being written.
        """
        if debug_enabled:
            # Materialize first: joining a one-shot iterator (e.g. a
            # generator) for the log line would otherwise exhaust it and
            # nothing would reach the transport below. tuple() of a tuple
            # is a no-op, and this only runs when debug logging is on.
            data = tuple(data)
            _LOGGER.debug(
                "%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
            )

        if TYPE_CHECKING:
            assert self._writelines is not None, "Writer is not set"

        self._writelines(data)
193 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/noise.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 |
3 | from ..connection cimport APIConnection
4 | from .base cimport APIFrameHelper
5 | from .noise_encryption cimport EncryptCipher, DecryptCipher
6 | from .packets cimport make_noise_packets
7 |
cdef bint TYPE_CHECKING

# Connection states of the noise frame helper; values are assigned in noise.py.
cdef unsigned int NOISE_STATE_HELLO
cdef unsigned int NOISE_STATE_HANDSHAKE
cdef unsigned int NOISE_STATE_READY
cdef unsigned int NOISE_STATE_CLOSED

cdef bytes NOISE_HELLO
cdef object InvalidTag
cdef object ESPHOME_NOISE_BACKEND

# C-level declaration of APINoiseFrameHelper in noise.py; attributes mirror
# that class's __slots__.
cdef class APINoiseFrameHelper(APIFrameHelper):

    cdef object _noise_psk
    cdef str _expected_name
    cdef str _expected_mac
    cdef unsigned int _state
    cdef str _server_mac
    cdef str _server_name
    cdef object _proto
    cdef EncryptCipher _encrypt_cipher
    cdef DecryptCipher _decrypt_cipher

    # NOTE(review): ``header`` is declared twice below (as bytes and as
    # const unsigned char *); confirm which declaration Cython applies and
    # drop the other.
    @cython.locals(
        header=bytes,
        preamble="unsigned char",
        header="const unsigned char *"
    )
    cpdef void data_received(self, object data) except *

    @cython.locals(
        msg=bytes,
        msg_type="unsigned int",
        payload=bytes,
        msg_length=Py_ssize_t,
        msg_cstr="const unsigned char *",
    )
    cdef void _handle_frame(self, bytes frame) except *

    @cython.locals(
        chosen_proto=char,
        server_name_i=int,
        mac_address_i=int,
        mac_address=str,
        server_name=str,
    )
    cdef void _handle_hello(self, bytes server_hello) except *

    cdef void _handle_handshake(self, bytes msg) except *

    cdef void _handle_closed(self, bytes frame) except *

    @cython.locals(handshake_frame=bytearray, frame_len="unsigned int")
    cdef void _send_hello_handshake(self) except *

    cdef void _setup_proto(self) except *

    @cython.locals(psk_bytes=bytes)
    cdef _decode_noise_psk(self)

    cpdef void write_packets(self, list packets, bint debug_enabled) except *

    cdef _error_on_incorrect_preamble(self, bytes msg)
71 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/noise.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | import binascii
5 | import logging
6 | from typing import TYPE_CHECKING
7 |
8 | from cryptography.exceptions import InvalidTag
9 | from noise.connection import NoiseConnection
10 |
11 | from ..core import (
12 | APIConnectionError,
13 | BadMACAddressAPIError,
14 | BadNameAPIError,
15 | EncryptionErrorAPIError,
16 | EncryptionHelloAPIError,
17 | EncryptionPlaintextAPIError,
18 | HandshakeAPIError,
19 | InvalidEncryptionKeyAPIError,
20 | ProtocolAPIError,
21 | )
22 | from .base import _LOGGER, APIFrameHelper
23 | from .noise_encryption import ESPHOME_NOISE_BACKEND, DecryptCipher, EncryptCipher
24 | from .packets import make_noise_packets
25 |
26 | if TYPE_CHECKING:
27 | from ..connection import APIConnection
28 |
29 |
# Connection states of the noise frame helper. They are plain integers
# instead of an Enum: data_received dispatches on them with a simple
# comparison chain, and each state gets its own cdef handler so every state
# can be tested separately on incoming data. That matters because the
# protractor event loop delivers a bytearray instead of bytes, which was
# once not handled correctly.
NOISE_STATE_HELLO = 1
NOISE_STATE_HANDSHAKE = 2
NOISE_STATE_READY = 3
NOISE_STATE_CLOSED = 4


# An empty noise hello frame: preamble 0x01 followed by a 16-bit length of 0.
NOISE_HELLO = bytes((0x01, 0x00, 0x00))

int_ = int
46 |
47 |
class APINoiseFrameHelper(APIFrameHelper):
    """Frame helper for noise encrypted connections.

    Runs the client side of the ``Noise_NNpsk0_25519_ChaChaPoly_SHA256``
    handshake:
    - On connect it sends a hello frame followed by the first handshake
      message.
    - It validates the ServerHello, checking the server name and, when
      sent, the MAC address.
    - It completes the handshake and resolves ``ready_future``.
    - Afterwards it encrypts outgoing and decrypts incoming frames.
    """

    __slots__ = (
        "_decrypt_cipher",
        "_encrypt_cipher",
        "_expected_mac",
        "_expected_name",
        "_noise_psk",
        "_proto",
        "_server_mac",
        "_server_name",
        "_state",
    )

    def __init__(
        self,
        connection: APIConnection,
        noise_psk: str,
        expected_name: str | None,
        expected_mac: str | None,
        client_info: str,
        log_name: str,
    ) -> None:
        """Initialize the API frame helper.

        :param connection: receives decoded packets and fatal errors.
        :param noise_psk: base64-encoded 32-byte pre-shared key.
        :param expected_name: if not None, the name the server must report
            in its ServerHello.
        :param expected_mac: if not None, the MAC address the server must
            report in its ServerHello.
        :param client_info: client description, used in error messages.
        :param log_name: prefix for log and error messages.
        :raises InvalidEncryptionKeyAPIError: if ``noise_psk`` is not valid
            base64 or does not decode to exactly 32 bytes.
        """
        super().__init__(connection, client_info, log_name)
        self._noise_psk = noise_psk
        self._expected_mac = expected_mac
        self._expected_name = expected_name
        self._state = NOISE_STATE_HELLO
        self._server_name: str | None = None
        self._server_mac: str | None = None
        self._encrypt_cipher: EncryptCipher | None = None
        self._decrypt_cipher: DecryptCipher | None = None
        self._setup_proto()

    def close(self) -> None:
        """Close the connection."""
        # Make sure we set the ready event if its not already set
        # so that we don't block forever on the ready event if we
        # are waiting for the handshake to complete.
        self._set_ready_future_exception(
            APIConnectionError(f"{self._log_name}: Connection closed")
        )
        self._state = NOISE_STATE_CLOSED
        super().close()

    def _handle_error(self, exc: Exception) -> None:
        """Handle an error, and provide a good message when during hello."""
        if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):
            original_exc: Exception = exc
            exc = EncryptionHelloAPIError(
                f"{self._log_name}: The connection dropped immediately after encrypted hello; "
                "Try enabling encryption on the device or turning off "
                f"encryption on the client ({self._client_info})"
            )
            exc.__cause__ = original_exc
        super()._handle_error(exc)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Handle a new connection by starting the noise handshake."""
        super().connection_made(transport)
        self._send_hello_handshake()

    def data_received(self, data: bytes | bytearray | memoryview) -> None:
        """Buffer incoming bytes and dispatch every complete frame.

        Each frame is a 3-byte header (preamble 0x01, then a big-endian
        16-bit body length) followed by the body, which is routed to the
        handler for the current state.
        """
        self._add_to_buffer(data)
        # Message header is 3 bytes
        while self._buffer_len >= 3:
            if TYPE_CHECKING:
                assert self._buffer is not None, "Buffer should be set"
            # Position the read cursor just past the header so _read()
            # returns the frame body.
            self._pos = 3
            header = self._buffer
            preamble = header[0]
            if preamble != 0x01:
                if preamble == 0x00:
                    self._handle_error_and_close(
                        EncryptionPlaintextAPIError(
                            f"{self._log_name}: The device is using plaintext protocol; "
                            "Try enabling encryption on the device or turning off "
                            f"encryption on the client ({self._client_info})"
                        )
                    )
                else:
                    self._handle_error_and_close(
                        ProtocolAPIError(
                            f"{self._log_name}: Marker byte invalid: {preamble}"
                        )
                    )
                return
            if (frame := self._read((header[1] << 8) | header[2])) is None:
                # The complete frame is not yet available, wait for more data
                # to arrive before continuing, since callback_packet has not
                # been called yet the buffer will not be cleared and the next
                # call to data_received will continue processing the packet
                # at the start of the frame.
                return

            # asyncio already runs data_received in a try block
            # which will call connection_lost if an exception is raised
            if self._state == NOISE_STATE_READY:
                self._handle_frame(frame)
            elif self._state == NOISE_STATE_HELLO:
                self._handle_hello(frame)
            elif self._state == NOISE_STATE_HANDSHAKE:
                self._handle_handshake(frame)
            else:
                self._handle_closed(frame)

            self._remove_from_buffer()

    def _send_hello_handshake(self) -> None:
        """Send a ClientHello to the server.

        Writes the empty hello frame followed by a handshake frame whose
        body is a 0x00 byte plus the first noise handshake message.
        """
        handshake_frame = self._proto.write_message()
        # +1 for the leading 0x00 byte written before the handshake message.
        frame_len = len(handshake_frame) + 1
        header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
        self._write_bytes(
            (NOISE_HELLO, header, b"\x00", handshake_frame),
            _LOGGER.isEnabledFor(logging.DEBUG),
        )

    def _handle_hello(self, server_hello: bytes) -> None:
        """Perform the handshake with the server."""
        if not server_hello:
            self._handle_error_and_close(
                HandshakeAPIError(f"{self._log_name}: ServerHello is empty")
            )
            return

        # First byte of server hello is the protocol the server chose
        # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
        # exists.
        chosen_proto = server_hello[0]
        if chosen_proto != 0x01:
            # NOTE(review): the protocol byte is chosen by the server even
            # though the message says "client"; wording kept unchanged in
            # case callers or tests match on it.
            self._handle_error_and_close(
                HandshakeAPIError(
                    f"{self._log_name}: Unknown protocol selected by client {chosen_proto}"
                )
            )
            return

        # Check name matches expected name (for noise sessions, this is done
        # during hello phase before a connection is set up)
        # Server name is encoded as a string followed by a zero byte after the chosen proto byte
        server_name_i = server_hello.find(b"\0", 1)
        if server_name_i != -1:
            # server name found, this extension was added in 2022.2
            server_name = server_hello[1:server_name_i].decode()
            self._server_name = server_name

            if self._expected_name is not None and self._expected_name != server_name:
                self._handle_error_and_close(
                    BadNameAPIError(
                        f"{self._log_name}: Server sent a different name '{server_name}'",
                        server_name,
                    )
                )
                return

            mac_address_i = server_hello.find(b"\0", server_name_i + 1)
            if mac_address_i != -1:
                # mac address found, this extension was added in 2025.4
                mac_address = server_hello[server_name_i + 1 : mac_address_i].decode()
                self._server_mac = mac_address
                if self._expected_mac is not None and self._expected_mac != mac_address:
                    self._handle_error_and_close(
                        BadMACAddressAPIError(
                            f"{self._log_name}: Server sent a different mac '{mac_address}'",
                            server_name,
                            mac_address,
                        )
                    )
                    return

        self._state = NOISE_STATE_HANDSHAKE

    def _decode_noise_psk(self) -> bytes:
        """Decode the given noise psk from base64 format to raw bytes.

        :raises InvalidEncryptionKeyAPIError: if the PSK is not valid base64
            or does not decode to exactly 32 bytes.
        """
        psk = self._noise_psk
        server_name = self._server_name
        server_mac = self._server_mac
        try:
            psk_bytes = binascii.a2b_base64(psk)
        except ValueError as err:
            raise InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Malformed PSK `{psk}`, expected "
                "base64-encoded value",
                server_name,
                server_mac,
            ) from err
        if len(psk_bytes) != 32:
            raise InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Malformed PSK `{psk}`, expected"
                " 32-bytes of base64 data",
                server_name,
                server_mac,
            )
        return psk_bytes

    def _setup_proto(self) -> None:
        """Set up the noise protocol as handshake initiator with the PSK."""
        proto = NoiseConnection.from_name(
            b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
        )
        proto.set_as_initiator()
        proto.set_psks(self._decode_noise_psk())
        proto.set_prologue(b"NoiseAPIInit\x00\x00")
        proto.start_handshake()
        self._proto = proto

    def _error_on_incorrect_preamble(self, msg: bytes) -> None:
        """Handle an incorrect preamble.

        ``msg[1:]`` carries the server's explanation. "Handshake MAC
        failure" is reported as an invalid encryption key; anything else
        is reported as a generic handshake failure.
        """
        # errors="replace": a non-UTF-8 explanation must not turn the
        # handshake error into an unrelated UnicodeDecodeError.
        explanation = msg[1:].decode(errors="replace")
        if explanation != "Handshake MAC failure":
            exc = HandshakeAPIError(
                f"{self._log_name}: Handshake failure: {explanation}"
            )
        else:
            exc = InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Invalid encryption key",
                self._server_name,
                self._server_mac,
            )
        self._handle_error_and_close(exc)

    def _handle_handshake(self, msg: bytes) -> None:
        """Process the server's handshake response and finish the handshake.

        The first byte is a status byte. A value of 0 means success, and the
        rest of the message is the noise handshake message. Any other value
        means the server rejected the handshake.
        """
        if not msg:
            # Guard the msg[0] access below: an empty frame would otherwise
            # surface as a bare IndexError instead of a handshake error.
            self._handle_error_and_close(
                HandshakeAPIError(f"{self._log_name}: Handshake message is empty")
            )
            return
        if msg[0] != 0:
            self._error_on_incorrect_preamble(msg)
            return
        self._proto.read_message(msg[1:])
        self._state = NOISE_STATE_READY
        noise_protocol = self._proto.noise_protocol
        self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt)  # pylint: disable=no-member
        self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt)  # pylint: disable=no-member
        self.ready_future.set_result(None)

    def write_packets(
        self, packets: list[tuple[int, bytes]], debug_enabled: bool
    ) -> None:
        """Write a packets to the socket.

        Packets are in the format of tuple[protobuf_type, protobuf_data]
        """
        if TYPE_CHECKING:
            assert self._encrypt_cipher is not None, "Handshake should be complete"
        self._write_bytes(
            make_noise_packets(packets, self._encrypt_cipher), debug_enabled
        )

    def _handle_frame(self, frame: bytes) -> None:
        """Handle an incoming frame."""
        if TYPE_CHECKING:
            assert self._decrypt_cipher is not None, "Handshake should be complete"
        try:
            msg = self._decrypt_cipher.decrypt(frame)
        except InvalidTag:
            # This shouldn't happen since we already checked the tag during handshake
            # but it could happen if the server sends a bad frame see
            # issue https://github.com/esphome/aioesphomeapi/issues/1044
            self._handle_error_and_close(
                EncryptionErrorAPIError(
                    f"{self._log_name}: Encryption error", self._server_name
                )
            )
            return
        msg_length = len(msg)
        msg_cstr = msg
        if msg_length < 4:
            # Important: we must bound check msg_length to ensure we
            # do not read past the end of the message in the payload
            # slicing below
            self._handle_error_and_close(
                ProtocolAPIError(
                    f"{self._log_name}: Decrypted message too short: {msg_length} bytes"
                )
            )
            return
        # Message layout is
        # 2 bytes: message type (0:type_high, 1:type_low)
        # 2 bytes: message length (2:length_high, 3:length_low)
        #   - We ignore the message length field because we do not
        #     trust the remote end to send the correct length
        # N bytes: message data (4:...)
        msg_type = (msg_cstr[0] << 8) | msg_cstr[1]
        # Important: we must explicitly use msg_length here since msg_cstr
        # is a cstring and Cython will stop at the first null byte if we
        # do not use msg_length
        payload = msg_cstr[4:msg_length]
        self._connection.process_packet(msg_type, payload)

    def _handle_closed(self, frame: bytes) -> None:  # pylint: disable=unused-argument
        """Handle a closed frame."""
        self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
340 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/noise_encryption.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 | from libc.stdint cimport uint64_t
3 | from .pack cimport fast_pack_nonce
4 |
cdef object PACK_NONCE

# C-level declaration of EncryptCipher in noise_encryption.py.
cdef class EncryptCipher:

    # Message counter; packed into the 12-byte nonce by fast_pack_nonce.
    cdef uint64_t _nonce
    cdef object _encrypt

    # NOTE(review): the argument is named ``frame`` here but ``data`` in
    # noise_encryption.py; confirm Cython is fine with the mismatch.
    cpdef bytes encrypt(self, object frame)

# C-level declaration of DecryptCipher in noise_encryption.py.
cdef class DecryptCipher:

    cdef uint64_t _nonce
    cdef object _decrypt

    # cdef (not cpdef): only callable from Cython code once compiled.
    cdef bytes decrypt(self, object frame)
20 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/noise_encryption.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 | from struct import Struct
5 | from typing import Any
6 |
7 | from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
8 | from noise.backends.default import DefaultNoiseBackend
9 | from noise.backends.default.ciphers import ChaCha20Cipher, CryptographyCipher
10 | from noise.state import CipherState
11 |
# Aliases of the builtin types, used in the annotations of the cipher classes.
_bytes, _int = bytes, int
14 |
15 | PACK_NONCE = partial(Struct(" type[ChaCha20Poly1305Reusable]:
30 | return ChaCha20Poly1305Reusable # type: ignore[no-any-return, unused-ignore]
31 |
32 |
class ESPHomeNoiseBackend(DefaultNoiseBackend):  # type: ignore[misc]
    """Default noise backend with the ChaChaPoly cipher entry replaced.

    ``ChaCha20CipherReuseable`` is registered under the ``"ChaChaPoly"``
    name, so noise connections built with this backend use it instead of
    the backend's default entry.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        cipher_registry = self.ciphers
        cipher_registry["ChaChaPoly"] = ChaCha20CipherReuseable


# Shared module-level backend instance.
ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
40 |
41 |
class EncryptCipher:
    """Encrypt outgoing frames with ChaCha20-Poly1305, tracking the nonce.

    Wraps the cipher held by a noise ``CipherState``. Each call to
    ``encrypt`` uses the current counter as the nonce and then advances it.
    """

    __slots__ = ("_encrypt", "_nonce")

    def __init__(self, cipher_state: CipherState) -> None:
        """Take the starting nonce and the bound encrypt method from ``cipher_state``."""
        aead: ChaCha20Poly1305Reusable = cipher_state.cipher.cipher
        self._nonce: _int = cipher_state.n
        self._encrypt = aead.encrypt

    def encrypt(self, data: _bytes) -> bytes:
        """Encrypt one frame and return the ciphertext.

        The nonce counter only advances after a successful encryption.
        """
        packed_nonce = fast_pack_nonce(self._nonce)
        ciphertext = self._encrypt(packed_nonce, data, None)
        self._nonce += 1
        return ciphertext  # type: ignore[no-any-return, unused-ignore]
59 |
60 |
class DecryptCipher:
    """Decrypt incoming frames with ChaCha20-Poly1305, tracking the nonce.

    Wraps the cipher held by a noise ``CipherState``. Each call to
    ``decrypt`` uses the current counter as the nonce and then advances it.
    """

    __slots__ = ("_decrypt", "_nonce")

    def __init__(self, cipher_state: CipherState) -> None:
        """Take the starting nonce and the bound decrypt method from ``cipher_state``."""
        aead: ChaCha20Poly1305Reusable = cipher_state.cipher.cipher
        self._nonce: _int = cipher_state.n
        self._decrypt = aead.decrypt

    def decrypt(self, data: _bytes) -> bytes:
        """Decrypt one frame and return the plaintext.

        Any exception from the underlying cipher propagates unchanged, and
        in that case the nonce counter is not advanced.
        """
        packed_nonce = fast_pack_nonce(self._nonce)
        plaintext = self._decrypt(packed_nonce, data, None)
        self._nonce += 1
        return plaintext  # type: ignore[no-any-return, unused-ignore]
78 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/pack.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 | from libc.stdint cimport uint64_t
3 |
# Packs the 64-bit message counter into the 12-byte nonce used by the noise
# ciphers (implemented in pack.pyx).
cpdef bytes fast_pack_nonce(uint64_t q)
5 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/pack.pyx:
--------------------------------------------------------------------------------
1 | from cpython.bytes cimport PyBytes_FromStringAndSize
2 | from libc.stdint cimport uint64_t
3 |
cpdef bytes fast_pack_nonce(uint64_t q):
    """Return the 12-byte nonce for counter *q*.

    Layout: four zero bytes followed by *q* as 8 little-endian bytes.
    """
    cdef:
        char nonce[12]
        int i

    # The leading 32 bits are always zero.
    for i in range(4):
        nonce[i] = 0

    # The trailing 64 bits carry the counter, least significant byte first.
    for i in range(8):
        nonce[4 + i] = <char>((q >> (8 * i)) & 0xFF)

    return PyBytes_FromStringAndSize(nonce, 12)
23 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/packets.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 |
3 | from .noise_encryption cimport EncryptCipher
4 |
# Module-level alias of the lru_cache-wrapped encoder from packets.py; it is
# a Python-level callable, so it is typed as object.
cdef object varuint_to_bytes

cpdef _varuint_to_bytes(int value)


# `type_` is declared once, as a Python object: it is only passed on to the
# Python-level varuint_to_bytes cache, so a C integer type buys nothing.
# (A keyword may appear only once in the locals() call.)
@cython.locals(
    data=bytes,
    packet=tuple,
    type_=object
)
cpdef list make_plain_text_packets(list packets) except *


# C-typed locals so the 2-byte header arithmetic in make_noise_packets
# compiles to plain C integer operations.
@cython.locals(
    type_="unsigned int",
    data=bytes,
    data_header=bytes,
    packet=tuple,
    data_len=Py_ssize_t,
    frame=bytes,
    frame_len=Py_ssize_t,
)
cpdef list make_noise_packets(list packets, EncryptCipher encrypt_cipher) except *
29 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/packets.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | from .noise_encryption import EncryptCipher
4 |
5 | _int = int
6 |
7 |
8 | def _varuint_to_bytes(value: _int) -> bytes:
9 | """Convert a varuint to bytes."""
10 | if value <= 0x7F:
11 | return bytes((value,))
12 |
13 | result = bytearray()
14 | while value:
15 | temp = value & 0x7F
16 | value >>= 7
17 | if value:
18 | result.append(temp | 0x80)
19 | else:
20 | result.append(temp)
21 |
22 | return bytes(result)
23 |
24 |
25 | _cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
26 | varuint_to_bytes = _cached_varuint_to_bytes
27 |
28 |
def make_plain_text_packets(packets: list[tuple[int, bytes]]) -> list[bytes]:
    """Serialize ``(message_type, payload)`` pairs into plaintext frame chunks.

    Each frame contributes a 0x00 preamble, the varuint payload length, the
    varuint message type and, when non-empty, the payload itself. The chunks
    are returned as a flat list ready to be written in one go.
    """
    out: list[bytes] = []
    for packet in packets:
        type_: int = packet[0]
        data: bytes = packet[1]
        out.extend((b"\0", varuint_to_bytes(len(data)), varuint_to_bytes(type_)))
        # An empty payload adds no chunk; the length header already says 0.
        if data:
            out.append(data)
    return out
41 |
42 |
def make_noise_packets(
    packets: list[tuple[int, bytes]], encrypt_cipher: EncryptCipher
) -> list[bytes]:
    """Serialize and encrypt ``(message_type, payload)`` pairs as noise frames.

    For every packet the plaintext is a 4-byte big-endian header (2-byte
    message type, 2-byte payload length) followed by the payload. It is
    encrypted with *encrypt_cipher*, and the result is preceded by a 3-byte
    header: 0x01 and the 2-byte big-endian encrypted length.

    Raises:
        ValueError: if a payload is too large for the 16-bit length fields.
            The check runs before encryption, so the cipher's nonce is not
            advanced for a frame that is never sent.
    """
    out: list[bytes] = []
    for packet in packets:
        type_: int = packet[0]
        data: bytes = packet[1]
        data_len = len(data)
        # The encrypted frame (4-byte inner header + payload + 16-byte
        # ChaCha20-Poly1305 tag) must fit the 16-bit outer length field.
        # Without this check the `& 0xFF` masks below would silently
        # truncate the lengths and put a corrupt frame on the wire.
        if data_len > 0xFFFF - 4 - 16:
            raise ValueError(
                f"Noise packet payload too large: {data_len} bytes "
                f"(message type {type_})"
            )
        data_header = bytes(
            (
                (type_ >> 8) & 0xFF,
                type_ & 0xFF,
                (data_len >> 8) & 0xFF,
                data_len & 0xFF,
            )
        )
        frame = encrypt_cipher.encrypt(data_header + data)
        frame_len = len(frame)
        header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
        out.append(header)
        out.append(frame)
    return out
66 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/plain_text.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 |
3 | from ..connection cimport APIConnection
4 | from .base cimport APIFrameHelper
5 | from .packets cimport make_plain_text_packets
6 |
# C-level declaration of the TYPE_CHECKING flag that plain_text.py imports
# from typing.
cdef bint TYPE_CHECKING


# Cython declaration of the plaintext frame helper implemented in plain_text.py.
cdef class APIPlaintextFrameHelper(APIFrameHelper):

    cpdef void data_received(self, object data) except *

    cdef void _error_on_incorrect_preamble(self, int preamble) except *

    # C-typed locals for the varuint decoding loop.
    # NOTE(review): `result` is unsigned int while the return type is int;
    # varuints above INT_MAX would not round-trip. Confirm this cannot occur.
    @cython.locals(
        result="unsigned int",
        bitpos="unsigned int",
        cstr="const unsigned char *",
        val="unsigned char",
        current_pos="unsigned int"
    )
    cdef int _read_varuint(self)

    cpdef void write_packets(self, list packets, bint debug_enabled) except *
26 |
--------------------------------------------------------------------------------
/aioesphomeapi/_frame_helper/plain_text.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from typing import TYPE_CHECKING
5 |
6 | from ..core import ProtocolAPIError, RequiresEncryptionAPIError
7 | from .base import APIFrameHelper
8 | from .packets import make_plain_text_packets
9 |
10 | _int = int
11 |
12 |
class APIPlaintextFrameHelper(APIFrameHelper):
    """Frame helper for plaintext API connections.

    Each frame parsed by ``data_received`` consists of a ``0x00`` preamble,
    a varuint payload length, a varuint message type, then ``length``
    payload bytes.
    """

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Handle a new connection."""
        super().connection_made(transport)
        # No handshake step for plaintext: mark the helper ready right away.
        self.ready_future.set_result(None)

    def write_packets(
        self, packets: list[tuple[int, bytes]], debug_enabled: bool
    ) -> None:
        """Write packets to the socket.

        Packets are in the format of tuple[protobuf_type, protobuf_data]

        The entire packet must be written in a single call.
        """
        self._write_bytes(make_plain_text_packets(packets), debug_enabled)

    def _read_varuint(self) -> _int:
        """Read a varuint from the buffer or -1 if the buffer runs out of bytes.

        Every byte consumed advances ``self._pos``, including on the -1 path.
        """
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        result = 0
        bitpos = 0
        cstr = self._buffer
        while self._buffer_len > self._pos:
            val = cstr[self._pos]
            self._pos += 1
            # Low 7 bits are payload; a clear high bit marks the last byte.
            result |= (val & 0x7F) << bitpos
            if (val & 0x80) == 0:
                return result
            bitpos += 7
        return -1

    def data_received(self, data: bytes | bytearray | memoryview) -> None:
        """Buffer *data* and dispatch every complete frame it contains.

        Parsing of each frame starts again at buffer offset 0. A frame that
        is not complete yet stays in the buffer until more data arrives.
        """
        self._add_to_buffer(data)
        # Message header is at least 3 bytes, empty length allowed
        while self._buffer_len >= 3:
            self._pos = 0
            # Read preamble, which should always 0x00
            if (preamble := self._read_varuint()) != 0x00:
                self._error_on_incorrect_preamble(preamble)
                return
            # -1 means the header is incomplete; wait for more data.
            if (length := self._read_varuint()) == -1:
                return
            if (msg_type := self._read_varuint()) == -1:
                return

            if length == 0:
                self._remove_from_buffer()
                self._connection.process_packet(msg_type, b"")
                continue

            # The packet data is not yet available, wait for more data
            # to arrive before continuing, since callback_packet has not
            # been called yet the buffer will not be cleared and the next
            # call to data_received will continue processing the packet
            # at the start of the frame.
            if (packet_data := self._read(length)) is None:
                return
            self._remove_from_buffer()
            self._connection.process_packet(msg_type, packet_data)
            # If we have more data, continue processing

    def _error_on_incorrect_preamble(self, preamble: _int) -> None:
        """Close the connection because of an unexpected preamble.

        A 0x01 preamble is reported as RequiresEncryptionAPIError, and any
        other value as ProtocolAPIError. Either way the error is passed to
        ``_handle_error_and_close``.
        """
        if preamble == 0x01:
            self._handle_error_and_close(
                RequiresEncryptionAPIError(
                    f"{self._log_name}: Connection requires encryption"
                )
            )
            return
        self._handle_error_and_close(
            ProtocolAPIError(f"{self._log_name}: Invalid preamble {preamble:02x}")
        )
90 |
--------------------------------------------------------------------------------
/aioesphomeapi/api_options.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | import "google/protobuf/descriptor.proto";
3 |
4 |
// Which side of the API connection sends a message. Presumably SERVER is the
// ESPHome device and CLIENT is the API consumer; confirm against api.proto.
enum APISourceType {
  SOURCE_BOTH = 0;
  SOURCE_SERVER = 1;
  SOURCE_CLIENT = 2;
}

// Empty message type.
message void {}

// Custom options for service methods (both default to true).
extend google.protobuf.MethodOptions {
  optional bool needs_setup_connection = 1038 [default=true];
  optional bool needs_authentication = 1039 [default=true];
}

// Custom options for messages. Extension numbers only have to be unique per
// extended options type, so 1038/1039 do not clash with the MethodOptions
// extensions above.
extend google.protobuf.MessageOptions {
  // NOTE(review): presumably the numeric message id used on the wire.
  optional uint32 id = 1036 [default=0];
  optional APISourceType source = 1037 [default=SOURCE_BOTH];
  optional string ifdef = 1038;
  optional bool log = 1039 [default=true];
  optional bool no_delay = 1040 [default=false];
}
25 |
--------------------------------------------------------------------------------
/aioesphomeapi/api_options_pb2.py:
--------------------------------------------------------------------------------
# type: ignore
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: api_options.proto
# Protobuf Python Version: 4.25.5
"""Generated protocol buffer code."""
# NOTE(review): generated from api_options.proto; regenerate (presumably via
# script/gen-protoc) instead of editing this file by hand.
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'api_options_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  DESCRIPTOR._options = None
  _globals['_APISOURCETYPE']._serialized_start=63
  _globals['_APISOURCETYPE']._serialized_end=133
  _globals['_VOID']._serialized_start=55
  _globals['_VOID']._serialized_end=61
# @@protoc_insertion_point(module_scope)
31 |
--------------------------------------------------------------------------------
/aioesphomeapi/ble_defs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from enum import IntEnum
4 |
5 |
class BLEConnectionError(IntEnum):
    """BLE Connection Error.

    Integer codes reported for failed or terminated BLE connections. The
    names follow ESP-IDF's ``ESP_GATT_*`` constants.
    """

    # Member order (not numeric order) defines the enum's iteration order.
    ESP_GATT_CONN_UNKNOWN = 0x00
    ESP_GATT_CONN_L2C_FAILURE = 0x01
    ESP_GATT_CONN_TIMEOUT = 0x08
    ESP_GATT_CONN_TERMINATE_PEER_USER = 0x13
    ESP_GATT_CONN_TERMINATE_LOCAL_HOST = 0x16
    ESP_GATT_CONN_FAIL_ESTABLISH = 0x3E
    ESP_GATT_CONN_LMP_TIMEOUT = 0x22
    ESP_GATT_ERROR = 0x85
    ESP_GATT_CONN_CONN_CANCEL = 0x0100
    ESP_GATT_CONN_NONE = 0x0101
19 |
20 |
# Human-readable text for each BLEConnectionError member (all ten members
# of the enum have an entry).
ESP_CONNECTION_ERROR_DESCRIPTION = {
    BLEConnectionError.ESP_GATT_CONN_UNKNOWN: "Connection failed for unknown reason",
    BLEConnectionError.ESP_GATT_CONN_L2C_FAILURE: "Connection failed due to L2CAP failure",
    BLEConnectionError.ESP_GATT_CONN_TIMEOUT: "Connection failed due to timeout",
    BLEConnectionError.ESP_GATT_CONN_TERMINATE_PEER_USER: "Connection terminated by peer user",
    BLEConnectionError.ESP_GATT_CONN_TERMINATE_LOCAL_HOST: "Connection terminated by local host",
    BLEConnectionError.ESP_GATT_CONN_FAIL_ESTABLISH: "Connection failed to establish",
    BLEConnectionError.ESP_GATT_CONN_LMP_TIMEOUT: "Connection failed due to LMP response timeout",
    BLEConnectionError.ESP_GATT_ERROR: "Connection failed due to GATT operation error",
    BLEConnectionError.ESP_GATT_CONN_CONN_CANCEL: "Connection cancelled",
    BLEConnectionError.ESP_GATT_CONN_NONE: "No connection to cancel",
}
33 |
--------------------------------------------------------------------------------
/aioesphomeapi/client_base.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 |
3 | from ._frame_helper.base cimport APIFrameHelper
4 | from ._frame_helper.noise cimport APINoiseFrameHelper
5 | from ._frame_helper.plain_text cimport APIPlaintextFrameHelper
6 | from .connection cimport APIConnection, ConnectionParams
7 |
8 |
# C-level declarations for module-level names used by client_base.py.
cdef object create_eager_task
cdef object APIConnectionError

cdef dict SUBSCRIBE_STATES_RESPONSE_TYPES

cdef bint TYPE_CHECKING

cdef object CameraImageResponse, CameraState

cdef object HomeassistantServiceCall

cdef object BluetoothLEAdvertisement

cdef object BluetoothDeviceConnectionResponse

# `cdef`: callable from Cython code only, not exported to Python.
cdef str _stringify_or_none(object value)

# `public` makes each attribute readable and writable from Python code.
cdef class APIClientBase:

    cdef public set _background_tasks
    cdef public APIConnection _connection
    cdef public bint _debug_enabled
    cdef public object _loop
    cdef public ConnectionParams _params
    cdef public str cached_name
    cdef public str log_name

    cpdef _set_log_name(self)

    cpdef APIConnection _get_connection(self)
39 |
--------------------------------------------------------------------------------
/aioesphomeapi/client_base.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=unidiomatic-typecheck
2 | from __future__ import annotations
3 |
4 | import asyncio
5 | from collections.abc import Coroutine
6 | import logging
7 | from typing import TYPE_CHECKING, Any, Callable
8 |
9 | from google.protobuf import message
10 |
11 | from ._frame_helper.base import APIFrameHelper # noqa: F401
12 | from ._frame_helper.noise import APINoiseFrameHelper # noqa: F401
13 | from ._frame_helper.plain_text import APIPlaintextFrameHelper # noqa: F401
14 | from .api_pb2 import ( # type: ignore
15 | BluetoothConnectionsFreeResponse,
16 | BluetoothDeviceConnectionResponse,
17 | BluetoothGATTErrorResponse,
18 | BluetoothGATTGetServicesDoneResponse,
19 | BluetoothGATTGetServicesResponse,
20 | BluetoothGATTNotifyDataResponse,
21 | BluetoothGATTNotifyResponse,
22 | BluetoothGATTReadResponse,
23 | BluetoothGATTWriteResponse,
24 | BluetoothLEAdvertisementResponse,
25 | BluetoothScannerStateResponse,
26 | CameraImageResponse,
27 | HomeassistantServiceResponse,
28 | SubscribeHomeAssistantStateResponse,
29 | )
30 | from .connection import ConnectionParams
31 | from .core import APIConnectionError
32 | from .model import (
33 | APIVersion,
34 | BluetoothLEAdvertisement,
35 | BluetoothScannerStateResponse as BluetoothScannerStateResponseModel,
36 | CameraState,
37 | EntityState,
38 | HomeassistantServiceCall,
39 | )
40 | from .model_conversions import SUBSCRIBE_STATES_RESPONSE_TYPES
41 | from .util import build_log_name, create_eager_task
42 | from .zeroconf import ZeroconfInstanceType, ZeroconfManager
43 |
44 | if TYPE_CHECKING:
45 | from .connection import APIConnection
46 |
_LOGGER = logging.getLogger(__name__)

# We send a ping every 20 seconds, and the timeout ratio is 4.5x the
# ping interval. This means that if we don't hear back from the device
# within 90.0 seconds, we'll consider the connection dead and reconnect.
#
# This was chosen because 20s is around the expected time for a
# device to reboot and reconnect to WiFi, and 90 seconds is the absolute
# maximum time a device can take to respond when it's behind and the WiFi
# connection is poor.
KEEP_ALIVE_FREQUENCY = 20.0
58 |
59 |
def on_state_msg(
    on_state: Callable[[EntityState], None],
    image_stream: dict[int, list[bytes]],
    msg: message.Message,
) -> None:
    """Translate a state message and pass the result to *on_state*.

    Entity state responses are converted with the matching model's
    ``from_pb``. Camera images arrive in chunks: chunks are collected per
    key in *image_stream*, and one CameraState is emitted when a chunk with
    ``done`` set arrives.
    """
    msg_type = type(msg)
    model_cls = SUBSCRIBE_STATES_RESPONSE_TYPES.get(msg_type)
    if model_cls is not None:
        on_state(model_cls.from_pb(msg))
        return
    if msg_type is not CameraImageResponse:
        return
    if TYPE_CHECKING:
        assert isinstance(msg, CameraImageResponse)
    key = msg.key
    chunks: list[bytes] | None = image_stream.get(key)
    if not chunks:
        chunks = image_stream[key] = []
    chunks.append(msg.data)
    if not msg.done:
        return
    # Last chunk: join everything and drop the buffer for this key.
    image_data = b"".join(chunks)
    del image_stream[key]
    on_state(CameraState(key=msg.key, data=image_data))  # type: ignore[call-arg]
84 |
85 |
def on_home_assistant_service_response(
    on_service_call: Callable[[HomeassistantServiceCall], None],
    msg: HomeassistantServiceResponse,
) -> None:
    """Convert a HomeassistantServiceResponse and hand it to *on_service_call*."""
    service_call = HomeassistantServiceCall.from_pb(msg)
    on_service_call(service_call)
91 |
92 |
def on_bluetooth_le_advertising_response(
    on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None],
    msg: BluetoothLEAdvertisementResponse,
) -> None:
    """Convert a BLE advertisement response to its model and report it."""
    on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg))  # type: ignore[misc]
98 |
99 |
def on_bluetooth_connections_free_response(
    on_bluetooth_connections_free_update: Callable[[int, int, list[int]], None],
    msg: BluetoothConnectionsFreeResponse,
) -> None:
    """Report free/limit counts and the allocated addresses as a plain list."""
    allocated = list(msg.allocated)
    on_bluetooth_connections_free_update(msg.free, msg.limit, allocated)
105 |
106 |
def on_bluetooth_gatt_notify_data_response(
    address: int,
    handle: int,
    on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
    msg: BluetoothGATTNotifyDataResponse,
) -> None:
    """Handle a BluetoothGATTNotifyDataResponse message.

    Messages for a different address or handle are ignored. Matching data is
    passed to the callback as a bytearray. Exceptions raised by the callback
    are logged and not propagated.
    """
    if address != msg.address or handle != msg.handle:
        return
    try:
        on_bluetooth_gatt_notify(handle, bytearray(msg.data))
    except Exception:
        _LOGGER.exception(
            "Unexpected error in Bluetooth GATT notify callback for address %s, handle %s",
            address,
            handle,
        )
123 |
124 |
def on_bluetooth_scanner_state_response(
    on_bluetooth_scanner_state: Callable[[BluetoothScannerStateResponseModel], None],
    msg: BluetoothScannerStateResponse,
) -> None:
    """Convert a scanner state response to its model and report it."""
    scanner_state = BluetoothScannerStateResponseModel.from_pb(msg)
    on_bluetooth_scanner_state(scanner_state)
130 |
131 |
def on_subscribe_home_assistant_state_response(
    on_state_sub: Callable[[str, str | None], None],
    on_state_request: Callable[[str, str | None], None] | None,
    msg: SubscribeHomeAssistantStateResponse,
) -> None:
    """Route a Home Assistant state subscription message.

    One-shot requests (``msg.once``) go to *on_state_request* when that
    callback is provided; everything else goes to *on_state_sub*.
    """
    if on_state_request and msg.once:
        handler = on_state_request
    else:
        handler = on_state_sub
    handler(msg.entity_id, msg.attribute)
141 |
142 |
def on_bluetooth_device_connection_response(
    connect_future: asyncio.Future[None],
    address: int,
    on_bluetooth_connection_state: Callable[[bool, int, int], None],
    msg: BluetoothDeviceConnectionResponse,
) -> None:
    """Handle a BluetoothDeviceConnectionResponse message.

    Messages for other addresses are ignored. For a matching address the
    ``(connected, mtu, error)`` state is reported to the callback, and
    *connect_future* is resolved with ``None`` if it is still pending.
    """
    if address == msg.address:
        on_bluetooth_connection_state(msg.connected, msg.mtu, msg.error)
        # Resolve on ANY connection state since we do not want
        # to wait the whole timeout if the device disconnects
        # or we get an error.
        if not connect_future.done():
            connect_future.set_result(None)
157 |
158 |
def on_bluetooth_handle_message(
    address: int,
    handle: int,
    msg: (
        BluetoothGATTErrorResponse
        | BluetoothGATTNotifyResponse
        | BluetoothGATTReadResponse
        | BluetoothGATTWriteResponse
        | BluetoothDeviceConnectionResponse
    ),
) -> bool:
    """Return True if *msg* targets *address* and, where relevant, *handle*.

    A BluetoothDeviceConnectionResponse is matched on address alone; every
    other message type must match both address and handle.
    """
    if type(msg) is not BluetoothDeviceConnectionResponse:
        return bool(msg.address == address and msg.handle == handle)
    return bool(msg.address == address)
174 |
175 |
def on_bluetooth_message_types(
    address: int,
    msg_types: tuple[type[message.Message], ...],
    msg: (
        BluetoothGATTErrorResponse
        | BluetoothGATTNotifyResponse
        | BluetoothGATTReadResponse
        | BluetoothGATTWriteResponse
        | BluetoothDeviceConnectionResponse
        | BluetoothGATTGetServicesResponse
        | BluetoothGATTGetServicesDoneResponse
    ),
) -> bool:
    """Filter Bluetooth messages of a specific type and address.

    Returns True only when ``type(msg)`` is exactly one of *msg_types*
    (subclasses do not match) and ``msg.address`` equals *address*.
    """
    return type(msg) in msg_types and bool(msg.address == address)
192 |
193 |
194 | str_ = str
195 |
196 |
197 | def _stringify_or_none(value: str_ | None) -> str | None:
198 | """Convert a string like object to a str or None.
199 |
200 | The noise_psk is sometimes passed into
201 | the client as an Estr, but we want to pass it
202 | to the API as a string or None.
203 | """
204 | return None if value is None else str(value)
205 |
206 |
class APIClientBase:
    """Base client for ESPHome API clients.

    Holds the connection parameters, the active APIConnection (if any), the
    device name cached from DeviceInfo, and the log name derived from them.
    """

    __slots__ = (
        "_background_tasks",
        "_connection",
        "_debug_enabled",
        "_loop",
        "_params",
        "cached_name",
        "log_name",
    )

    def __init__(
        self,
        address: str_,  # allow subclass str
        port: int,
        password: str_ | None,
        *,
        client_info: str_ = "aioesphomeapi",
        keepalive: float = KEEP_ALIVE_FREQUENCY,
        zeroconf_instance: ZeroconfInstanceType | None = None,
        noise_psk: str_ | None = None,
        expected_name: str_ | None = None,
        addresses: list[str_] | None = None,
        expected_mac: str_ | None = None,
    ) -> None:
        """Create a client; a single instance is reused across sessions.

        :param address: Address to connect to, e.g. an IP address or a
            .local name resolved via mDNS.
        :param port: Port to connect to.
        :param password: Optional password sent to the device for
            authentication.
        :param client_info: User Agent string to send.
        :param keepalive: Keepalive time in seconds (ping interval) used to
            detect stale connections. A ping is sent every keepalive seconds,
            and the connection is closed if no pong is received.
        :param zeroconf_instance: Zeroconf instance to use if an mDNS lookup
            is necessary.
        :param noise_psk: Encryption pre-shared key for noise-encrypted
            sessions. An empty string is treated as no key.
        :param expected_name: Require the device's name to match this value.
            Prevents accidentally connecting to a different device when an
            IP was passed as address and DHCP reassigned it.
        :param addresses: Optional list of IP addresses to connect to, which
            takes precedence over ``address``. Mostly useful for dual-stack
            IPv4/IPv6 devices when it is unclear which one to use.
        :param expected_mac: Optional MAC address to check against the device,
            lower case without ``:`` or ``-`` separators.
            Example: 00:aa:22:33:44:55 -> 00aa22334455
        """
        self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
        if not addresses:
            addresses = [str(address)]
        # Empty strings are treated like missing values (as with password).
        self._params = ConnectionParams(
            addresses=addresses,
            port=port,
            password=password,
            client_info=client_info,
            keepalive=keepalive,
            zeroconf_manager=ZeroconfManager(zeroconf_instance),
            noise_psk=_stringify_or_none(noise_psk) or None,
            expected_name=_stringify_or_none(expected_name) or None,
            expected_mac=_stringify_or_none(expected_mac) or None,
        )
        self._connection: APIConnection | None = None
        self.cached_name: str | None = None
        self._background_tasks: set[asyncio.Task[Any]] = set()
        self._loop = asyncio.get_running_loop()
        self._set_log_name()

    def set_debug(self, enabled: bool) -> None:
        """Turn debug logging on or off, including on any live connection."""
        self._debug_enabled = enabled
        connection = self._connection
        if connection is not None:
            connection.set_debug(enabled)

    @property
    def zeroconf_manager(self) -> ZeroconfManager:
        """Zeroconf manager stored in the connection parameters."""
        return self._params.zeroconf_manager

    @property
    def expected_name(self) -> str | None:
        """Device name the connection must match, if any."""
        return self._params.expected_name

    @expected_name.setter
    def expected_name(self, value: str | None) -> None:
        self._params.expected_name = value

    @property
    def address(self) -> str:
        """First configured address."""
        return self._params.addresses[0]

    @property
    def api_version(self) -> APIVersion | None:
        """API version of the current connection, or None without one."""
        connection = self._connection
        return None if connection is None else connection.api_version

    def _set_log_name(self) -> None:
        """Rebuild ``log_name`` and push it to the active connection, if any."""
        connection = self._connection
        connected_address: str | None = None
        if connection is not None and connection.connected_address:
            connected_address = connection.connected_address
        self.log_name = build_log_name(
            self.cached_name,
            self._params.addresses,
            connected_address,
        )
        if connection is not None:
            connection.set_log_name(self.log_name)

    def _set_name_from_device(self, name: str_) -> None:
        """Cache the device name (from DeviceInfo) and refresh the log name."""
        self.cached_name = str(name)  # May be Estr from esphome
        self._set_log_name()

    def set_cached_name_if_unset(self, name: str_) -> None:
        """Cache *name* only if no name has been cached yet."""
        if self.cached_name:
            return
        self._set_name_from_device(name)

    def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> None:
        """Start *coro* eagerly and hold a reference until it completes."""
        task = create_eager_task(coro)
        tasks = self._background_tasks
        tasks.add(task)
        task.add_done_callback(tasks.discard)

    def _get_connection(self) -> APIConnection:
        """Return the active connection, or raise if it is missing or not ready."""
        connection = self._connection
        if connection is None:
            raise APIConnectionError(f"Not connected to {self.log_name}!")
        if connection.is_connected:
            return connection
        raise APIConnectionError(
            f"Authenticated connection not ready yet for {self.log_name}; "
            f"current state is {connection.connection_state}!"
        )
341 |
--------------------------------------------------------------------------------
/aioesphomeapi/connection.pxd:
--------------------------------------------------------------------------------
1 | import cython
2 |
3 | from ._frame_helper.base cimport APIFrameHelper
4 |
5 |
# C-level type declarations for module-level names defined in connection.py.
# Declaring them lets Cython access these globals without generic lookups.

# Message-type lookup tables.
cdef dict MESSAGE_TYPE_TO_PROTO
cdef dict PROTO_TO_MESSAGE_TYPE

cdef set OPEN_STATES

# Keep-alive / handshake timing.
cdef float KEEP_ALIVE_TIMEOUT_RATIO
cdef object HANDSHAKE_TIMEOUT

cdef bint TYPE_CHECKING
cdef bint _WIN32

cdef object WRITE_EXCEPTIONS

# Prebuilt messages and message-type tuples.
cdef object DISCONNECT_REQUEST_MESSAGE
cdef tuple DISCONNECT_RESPONSE_MESSAGES
cdef tuple PING_REQUEST_MESSAGES
cdef tuple PING_RESPONSE_MESSAGES
cdef object NO_PASSWORD_CONNECT_REQUEST

cdef object asyncio_timeout
cdef object CancelledError
cdef object asyncio_TimeoutError

# Protobuf message classes.
cdef object ConnectRequest, ConnectResponse
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest, GetTimeResponse
cdef object HelloRequest, HelloResponse

cdef object APIVersion

cdef object partial

cdef object hr

cdef object CONNECT_AND_SETUP_TIMEOUT, CONNECT_REQUEST_TIMEOUT

# Exception classes.
cdef object APIConnectionError
cdef object BadNameAPIError
cdef object HandshakeAPIError
cdef object PingFailedAPIError
cdef object ReadFailedAPIError
cdef object TimeoutAPIError
cdef object SocketAPIError
cdef object InvalidAuthAPIError
cdef object SocketClosedAPIError

cdef object astuple

# Connection state values.
cdef object CONNECTION_STATE_INITIALIZED
cdef object CONNECTION_STATE_HOST_RESOLVED
cdef object CONNECTION_STATE_SOCKET_OPENED
cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
cdef object CONNECTION_STATE_CONNECTED
cdef object CONNECTION_STATE_CLOSED

cdef object make_hello_request

cpdef void handle_timeout(object fut)
cpdef void handle_complex_message(
    object fut,
    list responses,
    object do_append,
    object do_stop,
    object resp,
)

cdef object _handle_timeout
cdef object _handle_complex_message

cdef tuple MESSAGE_NUMBER_TO_PROTO
77 |
78 |
# Connection settings, compiled as a Cython dataclass. Every field is
# `public`, so it can be read and written from Python.
@cython.dataclasses.dataclass
cdef class ConnectionParams:

    cdef public list addresses
    cdef public object port
    cdef public object password
    cdef public object client_info
    cdef public object keepalive
    cdef public object zeroconf_manager
    cdef public object noise_psk
    cdef public object expected_name
    cdef public object expected_mac
91 |
92 |
# Cython declaration of APIConnection (implemented in connection.py).
# `public` attributes are visible from Python; the rest are C-level only.
# `cdef` methods are callable only from Cython, `cpdef` ones from Python too.
# `except *` on void methods makes callers check for and propagate exceptions.
cdef class APIConnection:

    cdef ConnectionParams _params
    cdef public object on_stop
    cdef public object _socket
    cdef public APIFrameHelper _frame_helper
    cdef public object api_version
    cdef public object connection_state
    cdef public dict _message_handlers
    cdef public str log_name
    cdef set _read_exception_futures
    cdef object _ping_timer
    cdef object _pong_timer
    # C `float` (single precision) keep-alive timings.
    cdef float _keep_alive_interval
    cdef float _keep_alive_timeout
    cdef object _resolve_host_future
    cdef object _start_connect_future
    cdef object _finish_connect_future
    cdef public Exception _fatal_exception
    cdef bint _expected_disconnect
    cdef object _loop
    cdef bint _send_pending_ping
    cdef public bint is_connected
    cdef bint _handshake_complete
    cdef bint _debug_enabled
    cdef public str received_name
    cdef public str connected_address
    cdef list _addrs_info

    cpdef void send_message(self, object msg) except *

    @cython.locals(msg_type=tuple)
    cdef void send_messages(self, tuple messages) except *

    @cython.locals(handlers=set, handlers_copy=set, klass_merge=tuple)
    cpdef void process_packet(
        self,
        unsigned int msg_type_proto,
        object data
    ) except *

    cdef void _async_cancel_pong_timer(self) except *

    cdef void _async_schedule_keep_alive(self, object now) except *

    cdef void _cleanup(self) except *

    cpdef set_log_name(self, str name)

    cdef _make_connect_request(self)

    cdef void _process_hello_resp(self, object resp) except *

    cdef void _process_login_response(self, object hello_response) except *

    cdef void _set_connection_state(self, object state) except *

    cpdef void report_fatal_error(self, Exception err) except *

    @cython.locals(handlers=set)
    cdef void _add_message_callback_without_remove(
        self,
        object on_message,
        tuple msg_types
    ) except *

    cpdef add_message_callback(self, object on_message, tuple msg_types)

    @cython.locals(handlers=set)
    cpdef void _remove_message_callback(
        self,
        object on_message,
        tuple msg_types
    ) except *

    cpdef void _handle_disconnect_request_internal(self, object msg) except *

    cpdef void _handle_ping_request_internal(self, object msg) except *

    cpdef void _handle_get_time_request_internal(self, object msg) except *

    cdef void _set_fatal_exception_if_unset(self, Exception err) except *

    cdef void _register_internal_message_handlers(self) except *

    cdef void _increase_recv_buffer_size(self) except *

    cdef void _set_start_connect_future(self) except *

    cdef void _set_finish_connect_future(self) except *
183 |
--------------------------------------------------------------------------------
/aioesphomeapi/core.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | import re
5 |
6 | from aioesphomeapi.model import BluetoothGATTError
7 |
8 | from .api_pb2 import ( # type: ignore
9 | AlarmControlPanelCommandRequest,
10 | AlarmControlPanelStateResponse,
11 | BinarySensorStateResponse,
12 | BluetoothConnectionsFreeResponse,
13 | BluetoothDeviceClearCacheResponse,
14 | BluetoothDeviceConnectionResponse,
15 | BluetoothDevicePairingResponse,
16 | BluetoothDeviceRequest,
17 | BluetoothDeviceUnpairingResponse,
18 | BluetoothGATTErrorResponse,
19 | BluetoothGATTGetServicesDoneResponse,
20 | BluetoothGATTGetServicesRequest,
21 | BluetoothGATTGetServicesResponse,
22 | BluetoothGATTNotifyDataResponse,
23 | BluetoothGATTNotifyRequest,
24 | BluetoothGATTNotifyResponse,
25 | BluetoothGATTReadDescriptorRequest,
26 | BluetoothGATTReadRequest,
27 | BluetoothGATTReadResponse,
28 | BluetoothGATTWriteDescriptorRequest,
29 | BluetoothGATTWriteRequest,
30 | BluetoothGATTWriteResponse,
31 | BluetoothLEAdvertisementResponse,
32 | BluetoothLERawAdvertisementsResponse,
33 | BluetoothScannerSetModeRequest,
34 | BluetoothScannerStateResponse,
35 | ButtonCommandRequest,
36 | CameraImageRequest,
37 | CameraImageResponse,
38 | ClimateCommandRequest,
39 | ClimateStateResponse,
40 | ConnectRequest,
41 | ConnectResponse,
42 | CoverCommandRequest,
43 | CoverStateResponse,
44 | DateCommandRequest,
45 | DateStateResponse,
46 | DateTimeCommandRequest,
47 | DateTimeStateResponse,
48 | DeviceInfoRequest,
49 | DeviceInfoResponse,
50 | DisconnectRequest,
51 | DisconnectResponse,
52 | EventResponse,
53 | ExecuteServiceRequest,
54 | FanCommandRequest,
55 | FanStateResponse,
56 | GetTimeRequest,
57 | GetTimeResponse,
58 | HelloRequest,
59 | HelloResponse,
60 | HomeassistantServiceResponse,
61 | HomeAssistantStateResponse,
62 | LightCommandRequest,
63 | LightStateResponse,
64 | ListEntitiesAlarmControlPanelResponse,
65 | ListEntitiesBinarySensorResponse,
66 | ListEntitiesButtonResponse,
67 | ListEntitiesCameraResponse,
68 | ListEntitiesClimateResponse,
69 | ListEntitiesCoverResponse,
70 | ListEntitiesDateResponse,
71 | ListEntitiesDateTimeResponse,
72 | ListEntitiesDoneResponse,
73 | ListEntitiesEventResponse,
74 | ListEntitiesFanResponse,
75 | ListEntitiesLightResponse,
76 | ListEntitiesLockResponse,
77 | ListEntitiesMediaPlayerResponse,
78 | ListEntitiesNumberResponse,
79 | ListEntitiesRequest,
80 | ListEntitiesSelectResponse,
81 | ListEntitiesSensorResponse,
82 | ListEntitiesServicesResponse,
83 | ListEntitiesSirenResponse,
84 | ListEntitiesSwitchResponse,
85 | ListEntitiesTextResponse,
86 | ListEntitiesTextSensorResponse,
87 | ListEntitiesTimeResponse,
88 | ListEntitiesUpdateResponse,
89 | ListEntitiesValveResponse,
90 | LockCommandRequest,
91 | LockStateResponse,
92 | MediaPlayerCommandRequest,
93 | MediaPlayerStateResponse,
94 | NoiseEncryptionSetKeyRequest,
95 | NoiseEncryptionSetKeyResponse,
96 | NumberCommandRequest,
97 | NumberStateResponse,
98 | PingRequest,
99 | PingResponse,
100 | SelectCommandRequest,
101 | SelectStateResponse,
102 | SensorStateResponse,
103 | SirenCommandRequest,
104 | SirenStateResponse,
105 | SubscribeBluetoothConnectionsFreeRequest,
106 | SubscribeBluetoothLEAdvertisementsRequest,
107 | SubscribeHomeassistantServicesRequest,
108 | SubscribeHomeAssistantStateResponse,
109 | SubscribeHomeAssistantStatesRequest,
110 | SubscribeLogsRequest,
111 | SubscribeLogsResponse,
112 | SubscribeStatesRequest,
113 | SubscribeVoiceAssistantRequest,
114 | SwitchCommandRequest,
115 | SwitchStateResponse,
116 | TextCommandRequest,
117 | TextSensorStateResponse,
118 | TextStateResponse,
119 | TimeCommandRequest,
120 | TimeStateResponse,
121 | UnsubscribeBluetoothLEAdvertisementsRequest,
122 | UpdateCommandRequest,
123 | UpdateStateResponse,
124 | ValveCommandRequest,
125 | ValveStateResponse,
126 | VoiceAssistantAnnounceFinished,
127 | VoiceAssistantAnnounceRequest,
128 | VoiceAssistantAudio,
129 | VoiceAssistantConfigurationRequest,
130 | VoiceAssistantConfigurationResponse,
131 | VoiceAssistantEventResponse,
132 | VoiceAssistantRequest,
133 | VoiceAssistantResponse,
134 | VoiceAssistantSetConfiguration,
135 | VoiceAssistantTimerEventResponse,
136 | )
137 |
# Matches consecutive two-character chunks; used by to_human_readable_address
# to split a hex string into octets.
TWO_CHAR = re.compile(r".{2}")

# Taken from esp_gatt_status_t in esp_gatt_defs.h
# Maps a numeric GATT status code to a short description. Codes not listed
# here are reported as "Unknown error" by to_human_readable_gatt_error.
ESPHOME_GATT_ERRORS = {
    -1: "Not connected",  # Custom ESPHome error
    1: "Invalid handle",
    2: "Read not permitted",
    3: "Write not permitted",
    4: "Invalid PDU",
    5: "Insufficient authentication",
    6: "Request not supported",
    7: "Invalid offset",
    8: "Insufficient authorization",
    9: "Prepare queue full",
    10: "Attribute not found",
    11: "Attribute not long",
    12: "Insufficient key size",
    13: "Invalid attribute length",
    14: "Unlikely error",
    15: "Insufficient encryption",
    16: "Unsupported group type",
    17: "Insufficient resources",
    128: "Application error",
    129: "Internal error",
    130: "Wrong state",
    131: "Database full",
    132: "Busy",
    133: "Error",
    134: "Command started",
    135: "Illegal parameter",
    136: "Pending",
    137: "Auth fail",
    138: "More",
    139: "Invalid configuration",
    140: "Service started",
    141: "Encrypted no mitm",
    142: "Not encrypted",
    143: "Congested",
    144: "Duplicate registration",
    145: "Already open",
    146: "Cancel",
    224: "Stack RSP",
    225: "App RSP",
    239: "Unknown error",
    253: "CCC config error",
    254: "Procedure already in progress",
    255: "Out of range",
}
186 |
187 |
188 | class APIConnectionError(Exception):
189 | pass
190 |
191 |
192 | class APIConnectionCancelledError(APIConnectionError):
193 | pass
194 |
195 |
196 | class InvalidAuthAPIError(APIConnectionError):
197 | pass
198 |
199 |
200 | class ResolveAPIError(APIConnectionError):
201 | """Raised when a resolve error occurs."""
202 |
203 |
204 | class ResolveTimeoutAPIError(ResolveAPIError, asyncio.TimeoutError):
205 | """Raised when a resolve timeout occurs."""
206 |
207 |
208 | class ProtocolAPIError(APIConnectionError):
209 | pass
210 |
211 |
212 | class RequiresEncryptionAPIError(ProtocolAPIError):
213 | pass
214 |
215 |
216 | class SocketAPIError(APIConnectionError):
217 | pass
218 |
219 |
220 | class SocketClosedAPIError(SocketAPIError):
221 | pass
222 |
223 |
224 | class HandshakeAPIError(APIConnectionError):
225 | pass
226 |
227 |
228 | class ConnectionNotEstablishedAPIError(APIConnectionError):
229 | pass
230 |
231 |
232 | class BadNameAPIError(APIConnectionError):
233 | """Raised when a name received from the remote but does not much the expected name."""
234 |
235 | def __init__(self, msg: str, received_name: str) -> None:
236 | super().__init__(f"{msg}: received_name={received_name}")
237 | self.received_name = received_name
238 |
239 |
240 | class BadMACAddressAPIError(APIConnectionError):
241 | """Raised when a MAC address received from the remote but does not much the expected MAC address."""
242 |
243 | def __init__(self, msg: str, received_name: str, received_mac: str) -> None:
244 | super().__init__(
245 | f"{msg}: received_name={received_name}, received_mac={received_mac}"
246 | )
247 | self.received_name = received_name
248 | self.received_mac = received_mac
249 |
250 |
251 | class InvalidEncryptionKeyAPIError(HandshakeAPIError):
252 | """Raised when the encryption key is invalid."""
253 |
254 | def __init__(
255 | self,
256 | msg: str | None = None,
257 | received_name: str | None = None,
258 | received_mac: str | None = None,
259 | ) -> None:
260 | super().__init__(
261 | f"{msg}: received_name={received_name}, received_mac={received_mac}"
262 | )
263 | self.received_name = received_name
264 | self.received_mac = received_mac
265 |
266 |
267 | class EncryptionErrorAPIError(InvalidEncryptionKeyAPIError):
268 | """Raised when an encryption error occurs after handshake."""
269 |
270 |
271 | class EncryptionHelloAPIError(HandshakeAPIError):
272 | """Raised when an encryption error occurs during hello."""
273 |
274 |
275 | class EncryptionPlaintextAPIError(HandshakeAPIError):
276 | """Raised when the ESP is using plaintext during noise handshake."""
277 |
278 |
279 | class PingFailedAPIError(APIConnectionError):
280 | """Raised when a ping fails."""
281 |
282 |
283 | class TimeoutAPIError(APIConnectionError):
284 | """Raised when a timeout occurs."""
285 |
286 |
287 | class ReadFailedAPIError(APIConnectionError):
288 | """Raised when a read fails."""
289 |
290 |
291 | class UnhandledAPIConnectionError(APIConnectionError):
292 | """Raised when an unhandled error occurs."""
293 |
294 |
295 | class BluetoothConnectionDroppedError(APIConnectionError):
296 | """Raised when a Bluetooth connection is dropped."""
297 |
298 |
def to_human_readable_address(address: int) -> str:
    """Render an integer MAC address as colon-separated uppercase hex pairs.

    The value is zero-padded to at least 12 hex digits first, so
    ``0xAABBCCDDEEFF`` becomes ``"AA:BB:CC:DD:EE:FF"``.
    """
    hex_digits = f"{address:012X}"
    octets = TWO_CHAR.findall(hex_digits)
    return ":".join(octets)
302 |
303 |
def to_human_readable_gatt_error(error: int) -> str:
    """Look up the description of a GATT status code.

    Codes absent from ``ESPHOME_GATT_ERRORS`` yield ``"Unknown error"``.
    """
    try:
        return ESPHOME_GATT_ERRORS[error]
    except KeyError:
        return "Unknown error"
307 |
308 |
class BluetoothGATTAPIError(APIConnectionError):
    """Exception wrapping a ``BluetoothGATTError``.

    The message lists the formatted address, handle, numeric error and its
    description; the wrapped object is kept on ``self.error``.
    """

    def __init__(self, error: BluetoothGATTError) -> None:
        readable_address = to_human_readable_address(error.address)
        description = to_human_readable_gatt_error(error.error)
        message = (
            "Bluetooth GATT Error "
            f"address={readable_address} "
            f"handle={error.handle} "
            f"error={error.error} "
            f"description={description}"
        )
        super().__init__(message)
        self.error = error
319 |
320 |
# Maps the numeric message type used on the wire to its protobuf message class.
# Keys are written here as a contiguous run starting at 1, in ascending order.
MESSAGE_TYPE_TO_PROTO = {
    1: HelloRequest,
    2: HelloResponse,
    3: ConnectRequest,
    4: ConnectResponse,
    5: DisconnectRequest,
    6: DisconnectResponse,
    7: PingRequest,
    8: PingResponse,
    9: DeviceInfoRequest,
    10: DeviceInfoResponse,
    11: ListEntitiesRequest,
    12: ListEntitiesBinarySensorResponse,
    13: ListEntitiesCoverResponse,
    14: ListEntitiesFanResponse,
    15: ListEntitiesLightResponse,
    16: ListEntitiesSensorResponse,
    17: ListEntitiesSwitchResponse,
    18: ListEntitiesTextSensorResponse,
    19: ListEntitiesDoneResponse,
    20: SubscribeStatesRequest,
    21: BinarySensorStateResponse,
    22: CoverStateResponse,
    23: FanStateResponse,
    24: LightStateResponse,
    25: SensorStateResponse,
    26: SwitchStateResponse,
    27: TextSensorStateResponse,
    28: SubscribeLogsRequest,
    29: SubscribeLogsResponse,
    30: CoverCommandRequest,
    31: FanCommandRequest,
    32: LightCommandRequest,
    33: SwitchCommandRequest,
    34: SubscribeHomeassistantServicesRequest,
    35: HomeassistantServiceResponse,
    36: GetTimeRequest,
    37: GetTimeResponse,
    38: SubscribeHomeAssistantStatesRequest,
    39: SubscribeHomeAssistantStateResponse,
    40: HomeAssistantStateResponse,
    41: ListEntitiesServicesResponse,
    42: ExecuteServiceRequest,
    43: ListEntitiesCameraResponse,
    44: CameraImageResponse,
    45: CameraImageRequest,
    46: ListEntitiesClimateResponse,
    47: ClimateStateResponse,
    48: ClimateCommandRequest,
    49: ListEntitiesNumberResponse,
    50: NumberStateResponse,
    51: NumberCommandRequest,
    52: ListEntitiesSelectResponse,
    53: SelectStateResponse,
    54: SelectCommandRequest,
    55: ListEntitiesSirenResponse,
    56: SirenStateResponse,
    57: SirenCommandRequest,
    58: ListEntitiesLockResponse,
    59: LockStateResponse,
    60: LockCommandRequest,
    61: ListEntitiesButtonResponse,
    62: ButtonCommandRequest,
    63: ListEntitiesMediaPlayerResponse,
    64: MediaPlayerStateResponse,
    65: MediaPlayerCommandRequest,
    66: SubscribeBluetoothLEAdvertisementsRequest,
    67: BluetoothLEAdvertisementResponse,
    68: BluetoothDeviceRequest,
    69: BluetoothDeviceConnectionResponse,
    70: BluetoothGATTGetServicesRequest,
    71: BluetoothGATTGetServicesResponse,
    72: BluetoothGATTGetServicesDoneResponse,
    73: BluetoothGATTReadRequest,
    74: BluetoothGATTReadResponse,
    75: BluetoothGATTWriteRequest,
    76: BluetoothGATTReadDescriptorRequest,
    77: BluetoothGATTWriteDescriptorRequest,
    78: BluetoothGATTNotifyRequest,
    79: BluetoothGATTNotifyDataResponse,
    80: SubscribeBluetoothConnectionsFreeRequest,
    81: BluetoothConnectionsFreeResponse,
    82: BluetoothGATTErrorResponse,
    83: BluetoothGATTWriteResponse,
    84: BluetoothGATTNotifyResponse,
    85: BluetoothDevicePairingResponse,
    86: BluetoothDeviceUnpairingResponse,
    87: UnsubscribeBluetoothLEAdvertisementsRequest,
    88: BluetoothDeviceClearCacheResponse,
    89: SubscribeVoiceAssistantRequest,
    90: VoiceAssistantRequest,
    91: VoiceAssistantResponse,
    92: VoiceAssistantEventResponse,
    93: BluetoothLERawAdvertisementsResponse,
    94: ListEntitiesAlarmControlPanelResponse,
    95: AlarmControlPanelStateResponse,
    96: AlarmControlPanelCommandRequest,
    97: ListEntitiesTextResponse,
    98: TextStateResponse,
    99: TextCommandRequest,
    100: ListEntitiesDateResponse,
    101: DateStateResponse,
    102: DateCommandRequest,
    103: ListEntitiesTimeResponse,
    104: TimeStateResponse,
    105: TimeCommandRequest,
    106: VoiceAssistantAudio,
    107: ListEntitiesEventResponse,
    108: EventResponse,
    109: ListEntitiesValveResponse,
    110: ValveStateResponse,
    111: ValveCommandRequest,
    112: ListEntitiesDateTimeResponse,
    113: DateTimeStateResponse,
    114: DateTimeCommandRequest,
    115: VoiceAssistantTimerEventResponse,
    116: ListEntitiesUpdateResponse,
    117: UpdateStateResponse,
    118: UpdateCommandRequest,
    119: VoiceAssistantAnnounceRequest,
    120: VoiceAssistantAnnounceFinished,
    121: VoiceAssistantConfigurationRequest,
    122: VoiceAssistantConfigurationResponse,
    123: VoiceAssistantSetConfiguration,
    124: NoiseEncryptionSetKeyRequest,
    125: NoiseEncryptionSetKeyResponse,
    126: BluetoothScannerStateResponse,
    127: BluetoothScannerSetModeRequest,
}

# Positional view of the mapping above: index i holds the class for message
# type i + 1. This is only valid while the dict keys stay contiguous from 1
# and in ascending insertion order; adding a gap would shift every later entry.
MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values())
452 |
--------------------------------------------------------------------------------
/aioesphomeapi/discover.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 |
5 | # Helper script and aioesphomeapi to discover api devices
6 | import asyncio
7 | import contextlib
8 | import logging
9 | import sys
10 |
11 | from zeroconf import IPVersion, ServiceStateChange, Zeroconf
12 | from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
13 |
# Row template for the discovery table: one left-aligned, fixed-width field per
# entry of COLUMN_NAMES (widths 7, 32, 15, 12, 16, 10 and 32 characters).
FORMAT = "{: <7}|{: <32}|{: <15}|{: <12}|{: <16}|{: <10}|{: <32}"
COLUMN_NAMES = ("Status", "Name", "Address", "MAC", "Version", "Platform", "Board")
# Placeholder printed when a TXT property is missing.
UNKNOWN = "unknown"
17 |
18 |
19 | def decode_bytes_or_unknown(data: str | bytes | None) -> str:
20 | """Decode bytes or return unknown."""
21 | if data is None:
22 | return UNKNOWN
23 | if isinstance(data, bytes):
24 | return data.decode()
25 | return data
26 |
27 |
def async_service_update(
    zeroconf: Zeroconf,
    service_type: str,
    name: str,
    state_change: ServiceStateChange,
) -> None:
    """Print one table row for an ESPHome service whose state changed.

    The row is built from whatever ``AsyncServiceInfo.load_from_cache`` finds;
    properties that are not present are shown as ``UNKNOWN`` and a missing
    IPv4 address is shown as an empty column.
    """
    status = "OFFLINE" if state_change is ServiceStateChange.Removed else "ONLINE"
    device_name, _, _ = name.partition(".")

    service_info = AsyncServiceInfo(service_type, name)
    service_info.load_from_cache(zeroconf)
    props = service_info.properties
    mac, version, platform, board = (
        decode_bytes_or_unknown(props.get(key))
        for key in (b"mac", b"version", b"platform", b"board")
    )

    ipv4_addresses = service_info.ip_addresses_by_version(IPVersion.V4Only)
    address = str(ipv4_addresses[0]) if ipv4_addresses else ""

    print(FORMAT.format(status, device_name, address, mac, version, platform, board))
49 |
50 |
async def main(argv: list[str]) -> None:
    """Browse for ESPHome devices over mDNS and print state changes forever.

    Args:
        argv: Full argument vector; ``argv[0]`` (the program name) is skipped.

    The zeroconf instance is always closed on exit, including when creating
    the service browser fails.
    """
    parser = argparse.ArgumentParser("aioesphomeapi-discover")
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args(argv[1:])
    logging.basicConfig(
        format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    if args.verbose:
        logging.getLogger("zeroconf").setLevel(logging.DEBUG)

    aiozc = AsyncZeroconf()
    try:
        browser = AsyncServiceBrowser(
            aiozc.zeroconf, "_esphomelib._tcp.local.", handlers=[async_service_update]
        )
        try:
            print(FORMAT.format(*COLUMN_NAMES))
            print("-" * 120)
            # Wait forever; the browser handler prints every update.
            await asyncio.Event().wait()
        finally:
            await browser.async_cancel()
    finally:
        # Outer finally so the sockets are released even if the browser
        # could not be created or failed to cancel cleanly.
        await aiozc.async_close()
75 |
76 |
def cli_entry_point() -> None:
    """Console entry point: run ``main`` with ``sys.argv`` until Ctrl+C."""
    try:
        asyncio.run(main(sys.argv))
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop discovery; exit quietly.
        pass
81 |
82 |
if __name__ == "__main__":
    # Allow running this module directly as a script.
    cli_entry_point()
    sys.exit(0)
86 |
--------------------------------------------------------------------------------
/aioesphomeapi/host_resolver.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from collections import defaultdict
5 | from collections.abc import Coroutine
6 | from contextlib import suppress
7 | from dataclasses import dataclass
8 | from ipaddress import IPv4Address, IPv6Address, ip_address
9 | import itertools
10 | import logging
11 | import socket
12 | from typing import TYPE_CHECKING, Any, cast
13 |
14 | from zeroconf import IPVersion
15 | from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
16 |
17 | from .core import ResolveAPIError, ResolveTimeoutAPIError
18 | from .util import (
19 | address_is_local,
20 | asyncio_timeout,
21 | create_eager_task,
22 | host_is_name_part,
23 | )
24 | from .zeroconf import ZeroconfManager
25 |
_LOGGER = logging.getLogger(__name__)


# mDNS service type used to build the ESPHome service-info queries below.
SERVICE_TYPE = "_esphomelib._tcp.local."
# Default overall resolution timeout (seconds) for async_resolve_host.
RESOLVE_TIMEOUT = 30.0
31 |
32 |
@dataclass(frozen=True)
class Sockaddr:
    """Base socket address."""

    address: str  # textual IP address
    port: int


@dataclass(frozen=True)
class IPv4Sockaddr(Sockaddr):
    """IPv4 socket address."""


@dataclass(frozen=True)
class IPv6Sockaddr(Sockaddr):
    """IPv6 socket address."""

    # Correspond to the 3rd and 4th items of an AF_INET6 sockaddr tuple.
    flowinfo: int
    scope_id: int


@dataclass(frozen=True)
class AddrInfo:
    """A resolved address, shaped like one entry of ``getaddrinfo`` output.

    ``family``, ``type`` and ``proto`` hold ``socket.AF_*``,
    ``socket.SOCK_*`` and ``socket.IPPROTO_*`` values respectively.
    """

    family: int
    type: int
    proto: int
    sockaddr: IPv4Sockaddr | IPv6Sockaddr
60 |
61 |
async def _async_zeroconf_get_service_info(
    aiozc: AsyncZeroconf,
    short_host: str,
    timeout: float,
) -> AsyncServiceInfo:
    """Query mDNS for the ESPHome service record of *short_host*.

    *timeout* is given in seconds and passed on in milliseconds. Any failure
    of the request is re-raised as ``ResolveAPIError`` chained to the cause.
    """
    service_info = _make_service_info_for_short_host(short_host)
    try:
        await service_info.async_request(aiozc.zeroconf, int(timeout * 1000))
    except Exception as err:
        raise ResolveAPIError(
            f"Error resolving mDNS {short_host} via mDNS: {err}"
        ) from err
    return service_info
75 |
76 |
77 | def _scope_id_to_int(value: str | None) -> int:
78 | """Convert a scope id to int if possible."""
79 | if value is None:
80 | return 0
81 | try:
82 | return int(value)
83 | except ValueError:
84 | return 0
85 |
86 |
def _make_service_info_for_short_host(host: str) -> AsyncServiceInfo:
    """Build the ESPHome service-info query object for bare hostname *host*."""
    return AsyncServiceInfo(
        SERVICE_TYPE,
        f"{host}.{SERVICE_TYPE}",
        server=f"{host}.local.",
    )
92 |
93 |
async def _async_resolve_short_host_zeroconf(
    aiozc: AsyncZeroconf,
    short_host: str,
    port: int,
    *,
    timeout: float = 3.0,
) -> list[AddrInfo]:
    """Resolve *short_host* over mDNS and pair each address with *port*."""
    _LOGGER.debug("Resolving host %s via mDNS", short_host)
    return service_info_to_addr_info(
        await _async_zeroconf_get_service_info(aiozc, short_host, timeout), port
    )
104 |
105 |
def service_info_to_addr_info(info: AsyncServiceInfo, port: int) -> list[AddrInfo]:
    """Convert every address on *info* into an ``AddrInfo`` for *port*.

    IPv6 addresses come first, followed by IPv4 addresses.
    """
    addr_infos: list[AddrInfo] = []
    for ip_version in (IPVersion.V6Only, IPVersion.V4Only):
        addr_infos.extend(
            _async_ip_address_to_addrinfo(ip, port)
            for ip in info.ip_addresses_by_version(ip_version)
        )
    return addr_infos
112 |
113 |
async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo]:
    """Resolve *host* with the event loop's ``getaddrinfo``.

    Only TCP / ``SOCK_STREAM`` results are requested; entries whose family is
    neither ``AF_INET`` nor ``AF_INET6`` are skipped.

    Raises:
        ResolveAPIError: if ``getaddrinfo`` fails. The ``OSError`` is chained
            as ``__cause__``, matching the mDNS resolution path.
    """
    loop = asyncio.get_running_loop()
    try:
        # Limit to TCP IP protocol and SOCK_STREAM
        res = await loop.getaddrinfo(
            host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
        )
    except OSError as err:
        raise ResolveAPIError(f"Error resolving {host} to IP address: {err}") from err

    addrs: list[AddrInfo] = []
    for family, type_, proto, _, raw in res:
        sockaddr: IPv4Sockaddr | IPv6Sockaddr
        # Unpack into a separate name so the ``port`` parameter is not shadowed.
        if family == socket.AF_INET:
            address, resolved_port = cast(tuple[str, int], raw)
            sockaddr = IPv4Sockaddr(address=address, port=resolved_port)
        elif family == socket.AF_INET6:
            address, resolved_port, flowinfo, scope_id = cast(
                tuple[str, int, int, int], raw
            )
            sockaddr = IPv6Sockaddr(
                address=address,
                port=resolved_port,
                flowinfo=flowinfo,
                scope_id=scope_id,
            )
        else:
            # Unknown family
            continue

        addrs.append(
            AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr)
        )
    return addrs
145 |
146 |
def _async_ip_address_to_addrinfo(ip: IPv4Address | IPv6Address, port: int) -> AddrInfo:
    """Wrap an already-parsed IP address and *port* in a TCP ``AddrInfo``.

    For IPv6 the textual ``%scope`` suffix is removed from the address string
    and the scope is carried in ``scope_id`` instead (0 if absent/non-numeric).
    """
    if ip.version != 6:
        return AddrInfo(
            family=socket.AF_INET,
            type=socket.SOCK_STREAM,
            proto=socket.IPPROTO_TCP,
            sockaddr=IPv4Sockaddr(address=str(ip), port=port),
        )

    if TYPE_CHECKING:
        assert isinstance(ip, IPv6Address)
    bare_address, _, _ = str(ip).partition("%")
    return AddrInfo(
        family=socket.AF_INET6,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
        sockaddr=IPv6Sockaddr(
            address=bare_address,
            port=port,
            flowinfo=0,
            scope_id=_scope_id_to_int(ip.scope_id),
        ),
    )
172 |
173 |
def host_is_local_name(host: str) -> bool:
    """Return whether *host* qualifies for mDNS lookup.

    True when either ``host_is_name_part`` or ``address_is_local`` accepts it.
    """
    is_name_part = host_is_name_part(host)
    return is_name_part or address_is_local(host)
177 |
178 |
async def async_resolve_host(
    hosts: list[str],
    port: int,
    zeroconf_manager: ZeroconfManager | None = None,
    timeout: float = RESOLVE_TIMEOUT,
) -> list[AddrInfo]:
    """Resolve hosts in parallel.

    We will try to resolve the host in the following order:
    - If the host is an IP address, we will return that and skip
      trying to resolve it at all.

    - If the host is a local name, we will try to resolve it via mDNS
    - Otherwise, we will use getaddrinfo to resolve it as well

    Once we know which hosts to resolve and which methods, all
    resolution runs in parallel and we will return the first
    result we get for each host.

    Only hosts that were not already answered by an IP literal or the
    zeroconf cache are sent to the network resolvers, and a host listed more
    than once in *hosts* is resolved only once.

    Raises:
        ResolveTimeoutAPIError: if network resolution exceeds *timeout* seconds.
        ResolveAPIError: if no address was found; the message joins the
            collected resolver errors when there are any.
    """
    manager: ZeroconfManager | None = None
    had_zeroconf_instance: bool = False
    resolve_results: defaultdict[str, list[AddrInfo]] = defaultdict(list)
    aiozc: AsyncZeroconf | None = None
    tried_to_create_zeroconf: bool = False
    exceptions: list[BaseException] = []
    # Each distinct host once, keeping the caller's order.
    unique_hosts = list(dict.fromkeys(hosts))

    # First try to handle the cases where we do not need to
    # do any network calls at all.
    # - If the host is an IP address, we can just return that
    # - If we have a zeroconf manager and the host is in the cache
    #   we can return that as well
    for host in unique_hosts:
        # If its an IP address, we can convert it to an AddrInfo
        # and we are done with this host
        try:
            ip_addr_info = _async_ip_address_to_addrinfo(ip_address(host), port)
        except ValueError:
            pass
        else:
            resolve_results[host].append(ip_addr_info)
            continue

        if not host_is_local_name(host):
            continue

        # If its a local name, we can try to fetch it from the zeroconf cache.
        # The zeroconf instance is obtained lazily, at most once.
        if not tried_to_create_zeroconf:
            tried_to_create_zeroconf = True
            manager = zeroconf_manager or ZeroconfManager()
            # Only close the instance afterwards if we were the ones to open it.
            had_zeroconf_instance = manager.has_instance
            try:
                aiozc = manager.get_async_zeroconf()
            except Exception as original_exc:
                new_exc = ResolveAPIError(
                    f"Cannot start mDNS sockets while resolving {host}: "
                    f"{original_exc}, is this a docker container "
                    "without host network mode? "
                )
                new_exc.__cause__ = original_exc
                exceptions.append(new_exc)

        if aiozc:
            short_host = host.partition(".")[0]
            service_info = _make_service_info_for_short_host(short_host)
            if service_info.load_from_cache(aiozc.zeroconf) and (
                addr_infos := service_info_to_addr_info(service_info, port)
            ):
                resolve_results[host].extend(addr_infos)

    # Keys only exist for hosts that produced at least one address, so these
    # are exactly the hosts still needing network resolution. Re-resolving the
    # others would just append duplicate addresses.
    unresolved_hosts = [host for host in unique_hosts if host not in resolve_results]

    try:
        if unresolved_hosts:
            try:
                async with asyncio_timeout(timeout):
                    await _async_resolve_host(
                        unresolved_hosts,
                        port,
                        resolve_results,
                        exceptions,
                        aiozc,
                        timeout,
                    )
            except asyncio.TimeoutError as err:
                raise ResolveTimeoutAPIError(
                    f"Timeout while resolving IP address for {hosts}"
                ) from err
    finally:
        if manager and not had_zeroconf_instance:
            # Shield so the close still completes if we are being cancelled.
            await asyncio.shield(create_eager_task(manager.async_close()))

    if addrs := list(itertools.chain.from_iterable(resolve_results.values())):
        return addrs

    if exceptions:
        raise ResolveAPIError(", ".join(str(exc) for exc in exceptions))
    raise ResolveAPIError(f"Could not resolve host {hosts} - got no results from OS")
271 |
272 |
async def _async_resolve_host(
    hosts: list[str],
    port: int,
    resolve_results: defaultdict[str, list[AddrInfo]],
    exceptions: list[BaseException],
    aiozc: AsyncZeroconf | None,
    timeout: float,
) -> None:
    """Resolve hosts in parallel.

    As soon as we get a result for a host, we will cancel
    all other tasks trying to resolve that host.

    This function will resolve hosts in parallel using
    both mDNS and getaddrinfo.

    This function is also designed to be cancellable, so
    if we get cancelled, we will cancel all tasks, and
    clean up after ourselves.

    Results are written into *resolve_results* (keyed by host) and
    exceptions from individual resolver tasks are appended to *exceptions*
    instead of being raised; nothing is returned.
    """
    # Every still-pending task -> the host it is resolving.
    resolve_task_to_host: dict[asyncio.Task[list[AddrInfo]], str] = {}
    # Pending tasks grouped per host, so siblings can be cancelled together.
    host_tasks: defaultdict[str, set[asyncio.Task[list[AddrInfo]]]] = defaultdict(set)

    try:
        for host in hosts:
            coros: list[Coroutine[Any, Any, list[AddrInfo]]] = []
            # mDNS is only attempted for local names and only when a
            # zeroconf instance is available; getaddrinfo is always tried.
            if aiozc and host_is_local_name(host):
                short_host = host.partition(".")[0]
                coros.append(
                    _async_resolve_short_host_zeroconf(
                        aiozc, short_host, port, timeout=timeout
                    )
                )

            coros.append(_async_resolve_host_getaddrinfo(host, port))

            for coro in coros:
                task = create_eager_task(coro)
                # The task may already be finished here (eager start); harvest
                # it immediately and do not track it.
                # NOTE(review): an eagerly-finished result does not cancel the
                # host's sibling tasks, so its addresses may be appended twice.
                if task.done():
                    if exc := task.exception():
                        exceptions.append(exc)
                    else:
                        resolve_results[host].extend(task.result())
                else:
                    resolve_task_to_host[task] = host
                    host_tasks[host].add(task)

        while resolve_task_to_host:
            done, _ = await asyncio.wait(
                resolve_task_to_host,
                return_when=asyncio.FIRST_COMPLETED,
            )
            finished_hosts: set[str] = set()
            for task in done:
                host = resolve_task_to_host.pop(task)
                host_tasks[host].discard(task)
                if exc := task.exception():
                    exceptions.append(exc)
                # An empty result does not finish the host; its other
                # resolver (if any) keeps running.
                elif result := task.result():
                    resolve_results[host].extend(result)
                    finished_hosts.add(host)

            # We got a result for a host, cancel
            # any other tasks trying to resolve
            # it as we are done with that host
            for host in finished_hosts:
                for task in host_tasks.pop(host, ()):
                    resolve_task_to_host.pop(task, None)
                    task.cancel()
                    with suppress(asyncio.CancelledError):
                        await task
    finally:
        # We likely get here if we get cancelled
        # because of a timeout
        for task in resolve_task_to_host:
            task.cancel()

        # Await all remaining tasks only after cancelling
        # them in case we get cancelled ourselves
        for task in resolve_task_to_host:
            with suppress(asyncio.CancelledError):
                await task
355 |
--------------------------------------------------------------------------------
/aioesphomeapi/log_parser.py:
--------------------------------------------------------------------------------
1 | """Log parser for ESPHome log messages with ANSI color support."""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Iterable
6 | import re
7 |
# Pre-compiled regex for ANSI escape sequences
ANSI_ESCAPE = re.compile(
    r"(?:\x1B[@-Z\\-_]|[\x80-\x9A\x9C-\x9F]|(?:\x1B\[|\x9B)[0-?]*[ -/]*[@-~])"
)

# ANSI reset sequences a line may already end with. "\033[0m" is the explicit
# SGR reset; "\033[m" is the equivalent form with the parameter omitted
# (an omitted SGR parameter means 0). The previous tuple listed "\033[0m"
# twice under two spellings ("\x1b" is the same character as "\033").
ANSI_RESET_CODES = ("\033[0m", "\033[m")
# Reset sequence appended by this module when one is needed.
ANSI_RESET = "\033[0m"
16 |
17 |
18 | def _extract_prefix_and_color(line: str, strip_ansi: bool) -> tuple[str, str, str]:
19 | """Extract ESPHome prefix and ANSI color code from line.
20 |
21 | Returns:
22 | Tuple of (prefix, color_code, line_without_color)
23 | """
24 | color_code = ""
25 | line_no_color = line
26 |
27 | # Extract ANSI color code at the beginning if present
28 | if not strip_ansi and (color_match := ANSI_ESCAPE.match(line)):
29 | color_code = color_match.group(0)
30 | line_no_color = line[len(color_code) :]
31 |
32 | # Find the ESPHome prefix
33 | bracket_colon = line_no_color.find("]:")
34 | prefix = line_no_color[: bracket_colon + 2] if bracket_colon != -1 else ""
35 |
36 | return prefix, color_code, line_no_color
37 |
38 |
def _needs_reset(line: str) -> bool:
    """Return True if *line* contains an ANSI CSI sequence but no trailing reset.

    Callers append ``ANSI_RESET`` in that case so colour does not bleed into
    the next printed line.
    """
    # "\033[" and "\x1b[" are the same two characters (ESC + "["), so a single
    # membership test replaces the previous duplicated check.
    return bool(line) and "\x1b[" in line and not line.endswith(ANSI_RESET_CODES)
46 |
47 |
48 | def _format_continuation_line(
49 | timestamp: str,
50 | prefix: str,
51 | line: str,
52 | color_code: str = "",
53 | strip_ansi: bool = False,
54 | ) -> str:
55 | """Format a continuation line with prefix and optional color."""
56 | line_content = f"{prefix} {line}" if prefix else line
57 |
58 | if color_code and not strip_ansi:
59 | reset = "" if line.endswith(ANSI_RESET_CODES) else ANSI_RESET
60 | return f"{timestamp}{color_code}{line_content}{reset}"
61 |
62 | return f"{timestamp}{line_content}"
63 |
64 |
class LogParser:
    """Stateful parser for processing log messages one line at a time.

    This parser is designed for streaming input where log messages come
    line by line rather than in complete multi-line blocks. The ``...]:``
    prefix and leading colour code of the most recent non-indented line are
    remembered and re-applied to the indented (continuation) lines after it.
    """

    def __init__(self, strip_ansi_escapes: bool = False) -> None:
        """Initialize the parser.

        Args:
            strip_ansi_escapes: If True, remove all ANSI escape sequences from output
        """
        self.strip_ansi_escapes = strip_ansi_escapes
        # Prefix and colour of the current entry, reused for continuation lines.
        self._current_prefix = ""
        self._current_color_code = ""

    def parse_line(self, line: str, timestamp: str) -> str:
        """Parse a single line and return formatted output.

        Args:
            line: A single line of log text (without newline)
            timestamp: The timestamp string to prepend (e.g., "[08:00:00.000]")

        Returns:
            Formatted line ready to be printed, or ``""`` for an empty or
            whitespace-only line.
        """
        # Strip any trailing newline if present
        line = line.rstrip("\n\r")

        # Strip ANSI escapes if requested
        if self.strip_ansi_escapes:
            line = ANSI_ESCAPE.sub("", line)

        if not line:
            return ""

        # A line that does not start with whitespace begins a new log entry.
        if not line[0].isspace():
            # Capture prefix and colour for any continuation lines. The line is
            # known to be non-empty and non-indented here, and the extraction
            # always assigns both fields, so no prior reset or guard is needed.
            self._current_prefix, self._current_color_code, _ = (
                _extract_prefix_and_color(line, self.strip_ansi_escapes)
            )

            output = f"{timestamp}{line}"
            # Add reset if line has color but no reset at end
            if not self.strip_ansi_escapes and _needs_reset(line):
                output += ANSI_RESET
            return output

        # Whitespace-only continuation lines produce no output.
        if not line.strip():
            return ""

        return _format_continuation_line(
            timestamp,
            self._current_prefix,
            line,
            self._current_color_code,
            self.strip_ansi_escapes,
        )
137 |
138 |
def parse_log_message(
    text: str, timestamp: str, *, strip_ansi_escapes: bool = False
) -> Iterable[str]:
    """Parse a log message and format it with timestamps and color preservation.

    Each non-indented line starts a log entry. Its prefix and colour are
    applied to the indented continuation lines that follow it. This matches
    how :class:`LogParser` treats streamed lines.

    Args:
        text: The log message text, potentially with ANSI codes and newlines
        timestamp: The timestamp string to prepend (e.g., "[08:00:00.000]")
        strip_ansi_escapes: If True, remove all ANSI escape sequences from output

    Returns:
        Iterable of formatted lines ready to be printed.
        For single-line logs, returns a tuple for efficiency.
        For multi-line logs, returns a list.
    """
    if strip_ansi_escapes:
        text = ANSI_ESCAPE.sub("", text)

    # Fast path for single line (most common case)
    if "\n" not in text:
        return (f"{timestamp}{text}",)

    # Splitting text that contains "\n" always yields at least two items, so
    # dropping one trailing item never empties the list.
    lines = text.split("\n")
    if lines[-1] == "" or lines[-1] in ANSI_RESET_CODES:
        lines.pop()

    result: list[str] = []
    prefix = ""
    color_code = ""

    for index, line in enumerate(lines):
        if index:
            if not line.strip():
                # Blank or whitespace-only line inside the message
                result.append("")
                continue
            if line[0].isspace():
                # Indented line: continuation of the most recent entry
                result.append(
                    _format_continuation_line(
                        timestamp, prefix, line, color_code, strip_ansi_escapes
                    )
                )
                continue

        # The first line, or a new log entry embedded in the same message
        output = f"{timestamp}{line}"
        # Close an unterminated colour so it does not bleed into later lines
        if not strip_ansi_escapes and _needs_reset(line):
            output += ANSI_RESET
        result.append(output)

        # Track this entry's prefix/colour for its continuation lines. An
        # empty or indented first line leaves them empty.
        if line and not line[0].isspace():
            prefix, color_code, _ = _extract_prefix_and_color(
                line, strip_ansi_escapes
            )

    return result
208 |
--------------------------------------------------------------------------------
/aioesphomeapi/log_reader.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | # Helper script and aioesphomeapi to view logs from an esphome device
4 | import argparse
5 | import asyncio
6 | import contextlib
7 | from datetime import datetime
8 | import logging
9 | import sys
10 |
11 | from .api_pb2 import SubscribeLogsResponse # type: ignore
12 | from .client import APIClient
13 | from .log_parser import parse_log_message
14 | from .log_runner import async_run
15 |
16 |
async def main(argv: list[str]) -> None:
    """Connect to the ESPHome device given on the command line and print its logs.

    Args:
        argv: The full argument vector including the program name
            (as in ``sys.argv``).

    Runs until cancelled, then stops the log subscription.
    """
    parser = argparse.ArgumentParser("aioesphomeapi-logs")
    parser.add_argument("--port", type=int, default=6053)
    parser.add_argument("--password", type=str)
    parser.add_argument("--noise-psk", type=str)
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("address")
    args = parser.parse_args(argv[1:])

    logging.basicConfig(
        format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    cli = APIClient(
        args.address,
        args.port,
        args.password or "",
        noise_psk=args.noise_psk,
        keepalive=10,
    )

    def on_log(msg: SubscribeLogsResponse) -> None:
        """Print one log message, prefixed with its local receive time."""
        time_ = datetime.now()
        message: bytes = msg.message
        # Undecodable bytes are shown as backslash escapes instead of raising
        text = message.decode("utf8", "backslashreplace")
        # datetime exposes microseconds; the timestamp shows milliseconds
        milliseconds = time_.microsecond // 1000
        timestamp = (
            f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{milliseconds:03}]"
        )

        # Parse and print the log message
        for line in parse_log_message(text, timestamp):
            print(line)

    stop = await async_run(cli, on_log)
    try:
        # Nothing ever sets this event: wait here until the task is cancelled
        await asyncio.Event().wait()
    finally:
        await stop()
58 |
59 |
def cli_entry_point() -> None:
    """Console-script entry point: run ``main`` with the process arguments.

    A KeyboardInterrupt (Ctrl+C) ends the program quietly, without a traceback.
    """
    argv = sys.argv
    with contextlib.suppress(KeyboardInterrupt):
        asyncio.run(main(argv))
64 |
65 |
if __name__ == "__main__":
    cli_entry_point()
    # Always report success once the CLI has returned
    raise SystemExit(0)
69 |
--------------------------------------------------------------------------------
/aioesphomeapi/log_runner.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Coroutine
4 | import logging
5 | from typing import Any, Callable
6 |
7 | from zeroconf.asyncio import AsyncZeroconf
8 |
9 | from .api_pb2 import SubscribeLogsResponse # type: ignore
10 | from .client import APIClient
11 | from .core import APIConnectionError
12 | from .model import LogLevel
13 | from .reconnect_logic import ReconnectLogic
14 |
15 | _LOGGER = logging.getLogger(__name__)
16 |
17 |
async def async_run(
    cli: APIClient,
    on_log: Callable[[SubscribeLogsResponse], None],
    log_level: LogLevel = LogLevel.LOG_LEVEL_VERY_VERBOSE,
    aio_zeroconf_instance: AsyncZeroconf | None = None,
    dump_config: bool = True,
    name: str | None = None,
) -> Callable[[], Coroutine[Any, Any, None]]:
    """Run logs until canceled.

    A ReconnectLogic is started for *cli*. Its connect callback subscribes
    *on_log* at *log_level*. The device's config dump is requested until one
    subscription succeeds, and never when *dump_config* is False. If
    subscribing raises APIConnectionError, the client is disconnected.

    Returns:
        A coroutine function. Await it to stop the reconnect logic and
        disconnect the client.
    """
    config_dump_done = not dump_config

    async def on_connect() -> None:
        """Subscribe to logs; drop the connection if that fails."""
        nonlocal config_dump_done
        try:
            cli.subscribe_logs(
                on_log,
                log_level=log_level,
                dump_config=not config_dump_done,
            )
        except APIConnectionError:
            await cli.disconnect()
        else:
            config_dump_done = True

    async def on_disconnect(  # pylint: disable=unused-argument
        expected_disconnect: bool,
    ) -> None:
        """Log a warning on every disconnect, expected or not."""
        _LOGGER.warning("Disconnected from API")

    reconnect_logic = ReconnectLogic(
        client=cli,
        on_connect=on_connect,
        on_disconnect=on_disconnect,
        zeroconf_instance=aio_zeroconf_instance,
        name=name,
    )
    await reconnect_logic.start()

    async def _stop() -> None:
        """Stop the reconnect logic first, then close the connection."""
        await reconnect_logic.stop()
        await cli.disconnect()

    return _stop
64 |
--------------------------------------------------------------------------------
/aioesphomeapi/model_conversions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from .api_pb2 import ( # type: ignore
6 | AlarmControlPanelStateResponse,
7 | BinarySensorStateResponse,
8 | ClimateStateResponse,
9 | CoverStateResponse,
10 | DateStateResponse,
11 | DateTimeStateResponse,
12 | EventResponse,
13 | FanStateResponse,
14 | LightStateResponse,
15 | ListEntitiesAlarmControlPanelResponse,
16 | ListEntitiesBinarySensorResponse,
17 | ListEntitiesButtonResponse,
18 | ListEntitiesCameraResponse,
19 | ListEntitiesClimateResponse,
20 | ListEntitiesCoverResponse,
21 | ListEntitiesDateResponse,
22 | ListEntitiesDateTimeResponse,
23 | ListEntitiesEventResponse,
24 | ListEntitiesFanResponse,
25 | ListEntitiesLightResponse,
26 | ListEntitiesLockResponse,
27 | ListEntitiesMediaPlayerResponse,
28 | ListEntitiesNumberResponse,
29 | ListEntitiesSelectResponse,
30 | ListEntitiesSensorResponse,
31 | ListEntitiesServicesResponse,
32 | ListEntitiesSirenResponse,
33 | ListEntitiesSwitchResponse,
34 | ListEntitiesTextResponse,
35 | ListEntitiesTextSensorResponse,
36 | ListEntitiesTimeResponse,
37 | ListEntitiesUpdateResponse,
38 | ListEntitiesValveResponse,
39 | LockStateResponse,
40 | MediaPlayerStateResponse,
41 | NumberStateResponse,
42 | SelectStateResponse,
43 | SensorStateResponse,
44 | SirenStateResponse,
45 | SwitchStateResponse,
46 | TextSensorStateResponse,
47 | TextStateResponse,
48 | TimeStateResponse,
49 | UpdateStateResponse,
50 | ValveStateResponse,
51 | )
52 | from .model import (
53 | AlarmControlPanelEntityState,
54 | AlarmControlPanelInfo,
55 | BinarySensorInfo,
56 | BinarySensorState,
57 | ButtonInfo,
58 | CameraInfo,
59 | ClimateInfo,
60 | ClimateState,
61 | CoverInfo,
62 | CoverState,
63 | DateInfo,
64 | DateState,
65 | DateTimeInfo,
66 | DateTimeState,
67 | EntityInfo,
68 | EntityState,
69 | Event,
70 | EventInfo,
71 | FanInfo,
72 | FanState,
73 | LightInfo,
74 | LightState,
75 | LockEntityState,
76 | LockInfo,
77 | MediaPlayerEntityState,
78 | MediaPlayerInfo,
79 | NumberInfo,
80 | NumberState,
81 | SelectInfo,
82 | SelectState,
83 | SensorInfo,
84 | SensorState,
85 | SirenInfo,
86 | SirenState,
87 | SwitchInfo,
88 | SwitchState,
89 | TextInfo,
90 | TextSensorInfo,
91 | TextSensorState,
92 | TextState,
93 | TimeInfo,
94 | TimeState,
95 | UpdateInfo,
96 | UpdateState,
97 | ValveInfo,
98 | ValveState,
99 | )
100 |
# Maps each protobuf state-response message class (imported from api_pb2) to
# the model EntityState subclass that represents it.
# NOTE(review): presumably used to convert incoming state messages into model
# objects; confirm against the consumer of this table.
SUBSCRIBE_STATES_RESPONSE_TYPES: dict[Any, type[EntityState]] = {
    AlarmControlPanelStateResponse: AlarmControlPanelEntityState,
    BinarySensorStateResponse: BinarySensorState,
    ClimateStateResponse: ClimateState,
    CoverStateResponse: CoverState,
    DateStateResponse: DateState,
    DateTimeStateResponse: DateTimeState,
    EventResponse: Event,
    FanStateResponse: FanState,
    LightStateResponse: LightState,
    LockStateResponse: LockEntityState,
    MediaPlayerStateResponse: MediaPlayerEntityState,
    NumberStateResponse: NumberState,
    SelectStateResponse: SelectState,
    SensorStateResponse: SensorState,
    SirenStateResponse: SirenState,
    SwitchStateResponse: SwitchState,
    TextSensorStateResponse: TextSensorState,
    TextStateResponse: TextState,
    TimeStateResponse: TimeState,
    UpdateStateResponse: UpdateState,
    ValveStateResponse: ValveState,
}
124 |
# Maps each protobuf list-entities response message class (imported from
# api_pb2) to the model EntityInfo subclass that describes that entity type.
# ListEntitiesServicesResponse maps to None because it has no EntityInfo
# model here. NOTE(review): presumably services are handled separately by
# the caller; confirm.
LIST_ENTITIES_SERVICES_RESPONSE_TYPES: dict[Any, type[EntityInfo] | None] = {
    ListEntitiesAlarmControlPanelResponse: AlarmControlPanelInfo,
    ListEntitiesBinarySensorResponse: BinarySensorInfo,
    ListEntitiesButtonResponse: ButtonInfo,
    ListEntitiesCameraResponse: CameraInfo,
    ListEntitiesClimateResponse: ClimateInfo,
    ListEntitiesCoverResponse: CoverInfo,
    ListEntitiesDateResponse: DateInfo,
    ListEntitiesDateTimeResponse: DateTimeInfo,
    ListEntitiesEventResponse: EventInfo,
    ListEntitiesFanResponse: FanInfo,
    ListEntitiesLightResponse: LightInfo,
    ListEntitiesLockResponse: LockInfo,
    ListEntitiesMediaPlayerResponse: MediaPlayerInfo,
    ListEntitiesNumberResponse: NumberInfo,
    ListEntitiesSelectResponse: SelectInfo,
    ListEntitiesSensorResponse: SensorInfo,
    ListEntitiesServicesResponse: None,
    ListEntitiesSirenResponse: SirenInfo,
    ListEntitiesSwitchResponse: SwitchInfo,
    ListEntitiesTextResponse: TextInfo,
    ListEntitiesTextSensorResponse: TextSensorInfo,
    ListEntitiesTimeResponse: TimeInfo,
    ListEntitiesUpdateResponse: UpdateInfo,
    ListEntitiesValveResponse: ValveInfo,
}
151 |
--------------------------------------------------------------------------------
/aioesphomeapi/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/esphome/aioesphomeapi/a74676a3176f3d34d7a3345367f24eba0f160b61/aioesphomeapi/py.typed
--------------------------------------------------------------------------------
/aioesphomeapi/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from asyncio import AbstractEventLoop, Task, get_running_loop
4 | from collections.abc import Coroutine
5 | import math
6 | import sys
7 | from typing import Any, TypeVar
8 |
9 | _T = TypeVar("_T")
10 |
11 |
12 | if sys.version_info[:2] < (3, 11):
13 | from async_timeout import timeout as asyncio_timeout
14 | else:
15 | from asyncio import timeout as asyncio_timeout
16 |
17 |
def fix_float_single_double_conversion(value: float) -> float:
    """Undo single-to-double widening noise and return the intended float.

    ESPHome works with single-precision floats internally for performance.
    Protobuf widens them to Python's double precision losslessly, but the
    result is the double nearest the *single*, not the double nearest the
    intended decimal: a single 0.1 arrives as 0.10000000149011612.

    About seven significant decimal digits are kept, based on the value's
    magnitude. Zero, infinities and NaN are returned unchanged.
    """
    if value == 0 or math.isinf(value) or math.isnan(value):
        return value
    # Number of digits before the decimal point (may be <= 0 for |value| < 1)
    magnitude = math.ceil(math.log10(abs(value)))
    return round(value, 7 - magnitude)
39 |
40 |
def host_is_name_part(address: str) -> bool:
    """Return True if *address* is a bare host name (contains no '.' or ':')."""
    return not any(separator in address for separator in ".:")
44 |
45 |
def address_is_local(address: str) -> bool:
    """Return True if *address* lies in the ``.local`` domain.

    A single trailing dot (fully-qualified form) is ignored.
    """
    if address.endswith("."):
        address = address[:-1]
    return address.endswith(".local")
49 |
50 |
def build_log_name(
    name: str | None, addresses: list[str], connected_address: str | None
) -> str:
    """Return a log name for a connection.

    The device name is *name*. If *name* is empty, the name is derived from
    the first ``.local`` address. A bare host name in *addresses* (no dots
    or colons) always replaces it. The address shown is *connected_address*
    or, if that is empty, the first address not used to derive the name.

    Returns:
        ``"<name> @ <address>"`` when both are known and the address is
        neither the name itself nor ``"<name>.<anything>"``. Otherwise the
        address when known, else the name, else the first raw address. With
        no name, no connected address and no addresses, an empty string is
        returned instead of raising IndexError.
    """
    preferred_address = connected_address
    for address in addresses:
        if (not name and address_is_local(address)) or host_is_name_part(address):
            name = address.partition(".")[0]
        elif not preferred_address:
            preferred_address = address
    if not preferred_address:
        # Guard the empty-list case: addresses[0] would raise IndexError
        return name or (addresses[0] if addresses else "")
    if (
        name
        and name != preferred_address
        and not preferred_address.startswith(f"{name}.")
    ):
        return f"{name} @ {preferred_address}"
    return preferred_address
70 |
71 |
if sys.version_info < (3, 12, 0):

    def create_eager_task(
        coro: Coroutine[Any, Any, _T],
        *,
        name: str | None = None,
        loop: AbstractEventLoop | None = None,
    ) -> Task[_T]:
        """Create a task from a coroutine.

        ``eager_start`` is unavailable before Python 3.12, so the task is
        scheduled the ordinary way. *loop* defaults to the running loop.
        """
        return Task(coro, loop=loop or get_running_loop(), name=name)

else:

    def create_eager_task(
        coro: Coroutine[Any, Any, _T],
        *,
        name: str | None = None,
        loop: AbstractEventLoop | None = None,
    ) -> Task[_T]:
        """Create a task from a coroutine and schedule it to run immediately.

        The task is constructed with ``eager_start=True``. *loop* defaults to
        the running loop.
        """
        target_loop = loop or get_running_loop()
        return Task(
            coro,
            loop=target_loop,
            name=name,
            eager_start=True,  # type: ignore[call-arg]
        )
98 |
99 |
# Names exported by ``from aioesphomeapi.util import *``
__all__ = (
    "address_is_local",
    "asyncio_timeout",
    "build_log_name",
    "create_eager_task",
    "fix_float_single_double_conversion",
    "host_is_name_part",
)
108 |
--------------------------------------------------------------------------------
/aioesphomeapi/zeroconf.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import TYPE_CHECKING, Union
5 |
6 | from zeroconf import Zeroconf
7 | from zeroconf.asyncio import AsyncZeroconf
8 |
# Either flavour of zeroconf instance a caller may hand to ZeroconfManager
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf]

_LOGGER = logging.getLogger(__name__)
12 |
13 |
class ZeroconfManager:
    """Hold the AsyncZeroconf instance used by a client.

    An instance supplied by the caller is borrowed, and ``async_close`` does
    not close it. If none is supplied, one is created lazily on the first
    call to ``get_async_zeroconf``. Only such a self-created instance is
    closed by ``async_close``.
    """

    def __init__(self, zeroconf: ZeroconfInstanceType | None = None) -> None:
        """Initialize the ZeroconfManager, optionally with an existing instance."""
        # True only when this manager created _aiozc itself and must close it
        self._created = False
        self._aiozc: AsyncZeroconf | None = None
        if zeroconf is not None:
            self.set_instance(zeroconf)

    @property
    def has_instance(self) -> bool:
        """Return True if an AsyncZeroconf instance is currently held."""
        return self._aiozc is not None

    def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None:
        """Set the AsyncZeroconf instance, wrapping a plain Zeroconf if needed.

        Raises:
            RuntimeError: if an instance backed by a different Zeroconf is
                already set.
        """
        current = self._aiozc
        if not current:
            self._aiozc = zc if isinstance(zc, AsyncZeroconf) else AsyncZeroconf(zc=zc)
            return
        if isinstance(zc, AsyncZeroconf):
            # Same underlying Zeroconf: keep what we already have
            if current.zeroconf is zc.zeroconf:
                return
        elif isinstance(zc, Zeroconf) and current.zeroconf is zc:
            self._aiozc = AsyncZeroconf(zc=zc)
            return
        raise RuntimeError("Zeroconf instance already set to a different instance")

    def _create_async_zeroconf(self) -> None:
        """Create a new AsyncZeroconf instance and mark it as owned."""
        _LOGGER.debug("Creating new AsyncZeroconf instance")
        self._aiozc = AsyncZeroconf()
        self._created = True

    def get_async_zeroconf(self) -> AsyncZeroconf:
        """Return the AsyncZeroconf instance, creating one on first use."""
        if not self._aiozc:
            self._create_async_zeroconf()
        if TYPE_CHECKING:
            assert self._aiozc is not None
        return self._aiozc

    async def async_close(self) -> None:
        """Close the AsyncZeroconf instance, but only if this manager created it."""
        aiozc = self._aiozc
        if not self._created or not aiozc:
            return
        await aiozc.async_close()
        self._aiozc = None
        self._created = False
66 |
--------------------------------------------------------------------------------
/bench/raw_ble_plain_text.py:
--------------------------------------------------------------------------------
1 | import timeit
2 |
3 | from aioesphomeapi import APIConnection
4 | from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
5 | from aioesphomeapi._frame_helper.packets import _cached_varuint_to_bytes
6 | from aioesphomeapi.api_pb2 import (
7 | BluetoothLERawAdvertisement,
8 | BluetoothLERawAdvertisementsResponse,
9 | )
10 |
11 | # cythonize -X language_level=3 -a -i aioesphomeapi/_frame_helper/plain_text.py
12 | # cythonize -X language_level=3 -a -i aioesphomeapi/_frame_helper/base.py
13 | # cythonize -X language_level=3 -a -i aioesphomeapi/connection.py
14 |
# Build a raw-advertisements response carrying five identical advertisements,
# then frame it for the plaintext frame helper: a zero byte, the varint
# payload length, the varint message type and the serialized payload.
adv = BluetoothLERawAdvertisementsResponse()
fake_adv = BluetoothLERawAdvertisement(
    address=1,
    rssi=-86,
    address_type=2,
    data=(
        b"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75"
        b"657a2f686369302f64656c04010134000000e25389019500000001016f002500"
        b"00002f6f72672f626c75657a2f686369302f6465"
    ),
)
for _ in range(5):
    adv.advertisements.append(fake_adv)

type_ = 93
data = adv.SerializeToString()
data = b"".join(
    (
        b"\0",
        _cached_varuint_to_bytes(len(data)),
        _cached_varuint_to_bytes(type_),
        data,
    )
)
34 |
35 |
class MockConnection(APIConnection):
    """Stand-in connection so the frame helper can be benchmarked on its own."""

    def __init__(self, *args, **kwargs):
        """Skip APIConnection initialisation; no connection state is needed."""

    def process_packet(self, type_: int, data: bytes):
        """Discard every decoded packet."""

    def report_fatal_error(self, exc: Exception):
        """Re-raise frame-helper errors instead of swallowing them."""
        raise exc
45 |
46 |
connection = MockConnection()

helper = APIPlaintextFrameHelper(
    connection=connection, client_info="my client", log_name="test"
)


def process_incoming_msg():
    """Feed one complete, pre-framed packet to the plaintext frame helper."""
    helper.data_received(data)


count = 3_000_000
elapsed = timeit.Timer(process_incoming_msg).timeit(count)
print(f"Processed {count} bluetooth messages took {elapsed} seconds")
61 |
--------------------------------------------------------------------------------
/bench/raw_ble_plain_text_with_callback.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import timeit
3 |
4 | from aioesphomeapi import APIConnection
5 | from aioesphomeapi.api_pb2 import (
6 | BluetoothLERawAdvertisement,
7 | BluetoothLERawAdvertisementsResponse,
8 | )
9 | from aioesphomeapi.client import APIClient
10 | from aioesphomeapi.client_base import on_ble_raw_advertisement_response
11 |
12 | # cythonize -X language_level=3 -a -i aioesphomeapi/client_base.py
13 | # cythonize -X language_level=3 -a -i aioesphomeapi/connection.py
14 |
15 |
class MockConnection(APIConnection):
    """Plain subclass of APIConnection used to drive process_packet directly."""
18 |
19 |
# Only the client's connection params are used; nothing in this script ever
# connects to the (fake) address.
client = APIClient("fake.address", 6052, None)
# NOTE(review): positional args follow APIConnection's constructor (params,
# a no-op callback taking expected_disconnect, then False and None); confirm
# their meaning against connection.py.
connection = MockConnection(
    client._params, lambda expected_disconnect: None, False, None
)
24 |
25 |
def process_incoming_msg():
    # Dispatch one pre-serialized message of type 93 (the type used for
    # BluetoothLERawAdvertisementsResponse in raw_ble_plain_text.py).
    # NOTE(review): the payload appears to hold five serialized
    # advertisements; confirm if it is regenerated.
    connection.process_packet(
        93,
        b'\n\xb2\x01\x08\x01\x10\xab\x01\x18\x02"\xa8\x016c04010134000000'
        b"e25389019500000001016f00250000002f6f72672f626c75657a2f686369302"
        b"f64656c04010134000000e25389019500000001016f00250000002f6f72672f"
        b"626c75657a2f686369302f6465\n\xb2\x01\x08\x01\x10\xab\x01\x18\x02"
        b'"\xa8\x016c04010134000000e25389019500000001016f00250000002f6f726'
        b"72f626c75657a2f686369302f64656c04010134000000e253890195000000010"
        b"16f00250000002f6f72672f626c75657a2f686369302f6465\n\xb2\x01\x08"
        b'\x01\x10\xab\x01\x18\x02"\xa8\x016c04010134000000e25389019500000'
        b"001016f00250000002f6f72672f626c75657a2f686369302f64656c040101340"
        b"00000e25389019500000001016f00250000002f6f72672f626c75657a2f68636"
        b'9302f6465\n\xb2\x01\x08\x01\x10\xab\x01\x18\x02"\xa8\x016c040101'
        b"34000000e25389019500000001016f00250000002f6f72672f626c75657a2f68"
        b"6369302f64656c04010134000000e25389019500000001016f00250000002f6f"
        b"72672f626c75657a2f686369302f6465\n\xb2\x01\x08\x01\x10\xab\x01"
        b'\x18\x02"\xa8\x016c04010134000000e25389019500000001016f002500000'
        b"02f6f72672f626c75657a2f686369302f64656c04010134000000e2538901950"
        b"0000001016f00250000002f6f72672f626c75657a2f686369302f6465",
    )
47 |
48 |
def on_advertisements(msgs: list[BluetoothLERawAdvertisement]):
    """Ignore the decoded advertisements; only dispatch cost is measured."""
51 |
52 |
# Route raw-advertisement responses through the real client_base handler
connection.add_message_callback(
    partial(on_ble_raw_advertisement_response, on_advertisements),
    (BluetoothLERawAdvertisementsResponse,),
)

count = 3_000_000
elapsed = timeit.Timer(process_incoming_msg).timeit(count)
print(f"Processed {count} bluetooth messages took {elapsed} seconds")
61 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.9
3 | show_error_codes = true
4 |
5 | strict = true
6 | warn_unreachable = true
7 |
8 | [mypy-async_timeout.*]
9 | ignore_missing_imports = True
10 |
11 | [mypy-noise.*]
12 | ignore_missing_imports = True
13 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | required-version = ">=0.5.0"
3 | exclude = [
4 | "aioesphomeapi/api_pb2.py",
5 | "aioesphomeapi/api_options_pb2.py",
6 | ]
7 |
8 | [tool.ruff.lint]
9 | select = [
10 | "ASYNC", # async rules
11 | "E", # pycodestyle
12 | "F", # pyflakes/autoflake
13 | "FLY", # flynt
14 | "FURB", # refurb
15 | "G", # flake8-logging-format
16 | "I", # isort
17 | "PERF", # Perflint
18 | "PIE", # flake8-pie
19 | "PL", # pylint
20 | "UP", # pyupgrade
21 | "RUF", # ruff
22 | "SIM", # flake8-SIM
23 | "SLOT", # flake8-slots
24 | "TID", # Tidy imports
25 | "TRY", # try rules
26 | "PERF", # performance
27 | ]
28 |
29 | ignore = [
30 | "E501", # line too long
31 | "E721", # We want type() check for protobuf messages
32 | "PLR0911", # Too many return statements ({returns} > {max_returns})
33 | "PLR0912", # Too many branches ({branches} > {max_branches})
34 | "PLR0913", # Too many arguments to function call ({c_args} > {max_args})
35 | "PLR0915", # Too many statements ({statements} > {max_statements})
36 | "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
37 | "PLW2901", # Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target
38 | "TRY003", # Too many to fix - Avoid specifying long messages outside the exception class
39 | "TID252", # Prefer absolute imports over relative imports from parent modules
40 | ]
41 |
42 | [tool.ruff.lint.isort]
43 | force-sort-within-sections = true
44 | known-first-party = [
45 | "aioesphomeapi", "tests"
46 | ]
47 | combine-as-imports = true
48 | split-on-trailing-comma = false
49 |
50 | [build-system]
51 | requires = ['setuptools>=65.4.1', 'wheel', 'Cython>=3.0.2']
52 |
53 | [tool.pytest.ini_options]
54 | asyncio_mode = "auto"
55 |
--------------------------------------------------------------------------------
/requirements/base.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs>=2.3.0
2 | async-interrupt>=1.2.0
3 | protobuf>=4
4 | zeroconf>=0.143.0,<1.0
5 | chacha20poly1305-reuseable>=0.13.2
6 | cryptography>=43.0.0
7 | noiseprotocol>=0.3.1,<1.0
8 | async-timeout>=4.0;python_version<'3.11'
9 |
--------------------------------------------------------------------------------
/requirements/test.txt:
--------------------------------------------------------------------------------
1 | pylint==3.3.7
2 | ruff==0.11.13
3 | flake8==7.2.0
4 | isort==6.0.1
5 | mypy==1.16.0
6 | types-protobuf==6.30.2.20250516
7 | pytest>=6.2.4,<9
8 | pytest-asyncio==0.26.0
9 | pytest-codspeed==3.2.0
10 | pytest-cov>=4.1.0
11 | pytest-timeout==2.4.0
12 |
--------------------------------------------------------------------------------
/script/gen-protoc:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | from pathlib import Path
5 | from subprocess import check_call
6 |
root_dir = Path(__file__).absolute().parent.parent
os.chdir(root_dir)

# Regenerate the Python protobuf modules from the .proto sources
check_call(
    [
        "protoc",
        "--python_out=aioesphomeapi",
        "-I",
        "aioesphomeapi",
        "aioesphomeapi/api.proto",
        "aioesphomeapi/api_options.proto",
    ]
)

# protoc writes a top-level import of api_options_pb2 into api_pb2.py; make it
# package-relative so it resolves inside the aioesphomeapi package.
# Encoding is explicit so the result does not depend on the platform locale.
# https://github.com/protocolbuffers/protobuf/issues/1491
api_file = root_dir / "aioesphomeapi" / "api_pb2.py"
content = api_file.read_text(encoding="utf-8").replace(
    "import api_options_pb2 as api__options__pb2",
    "from . import api_options_pb2 as api__options__pb2",
)
api_file.write_text(content, encoding="utf-8")

# A leading "# type: ignore" makes mypy skip the generated modules
for fname in ["api_pb2.py", "api_options_pb2.py"]:
    file = root_dir / "aioesphomeapi" / fname
    content = "# type: ignore\n" + file.read_text(encoding="utf-8")
    file.write_text(content, encoding="utf-8")
33 |
--------------------------------------------------------------------------------
/script/lint:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
# Enable strict mode first, so that a failing cd aborts the script instead of
# running the linters from the wrong directory.
set -euxo pipefail
cd "$(dirname "$0")/.."

black --safe aioesphomeapi tests
pylint aioesphomeapi
flake8 aioesphomeapi
isort aioesphomeapi tests
mypy aioesphomeapi
pytest tests
12 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | # Following 4 for black compatibility
4 | # E501: line too long
5 | # W503: Line break occurred before a binary operator
6 | # E203: Whitespace before ':'
7 | # D202 No blank lines allowed after function docstring
8 |
9 | ignore =
10 | E501,
11 | W503,
12 | E203,
13 | D202,
14 |
15 | exclude = api_pb2.py, api_options_pb2.py
16 |
17 | [bdist_wheel]
18 | universal = 1
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """aioesphomeapi setup script."""
3 |
4 | import contextlib
5 | from distutils.command.build_ext import build_ext
6 | import os
7 | from typing import Any
8 |
9 | from setuptools import find_packages, setup
10 |
11 | try:
12 | from setuptools import Extension
13 | except ImportError:
14 | from distutils.core import Extension
15 |
# Source files turned into C extensions (see EXTENSIONS and
# cythonize_if_available) when Cython is available.
TO_CYTHONIZE = [
    "aioesphomeapi/client_base.py",
    "aioesphomeapi/connection.py",
    "aioesphomeapi/_frame_helper/base.py",
    "aioesphomeapi/_frame_helper/noise.py",
    "aioesphomeapi/_frame_helper/noise_encryption.py",
    "aioesphomeapi/_frame_helper/packets.py",
    "aioesphomeapi/_frame_helper/plain_text.py",
    "aioesphomeapi/_frame_helper/pack.pyx",
]
26 |
# One optimized C extension per source file. The dotted module name is the
# path with its .py/.pyx suffix removed and "/" replaced by ".".
EXTENSIONS = [
    Extension(
        name=source.removesuffix(".py").removesuffix(".pyx").replace("/", "."),
        sources=[source],
        language="c",
        extra_compile_args=["-O3", "-g0"],
    )
    for source in TO_CYTHONIZE
]
36 |
37 |
# Directory containing this setup script; files below are resolved against it
here = os.path.abspath(os.path.dirname(__file__))

# README.rst is used as the package's long description
with open(os.path.join(here, "README.rst"), encoding="utf-8") as readme_file:
    long_description = readme_file.read()
42 |
43 |
# Package metadata
VERSION = "32.2.2"
PROJECT_NAME = "aioesphomeapi"
PROJECT_PACKAGE_NAME = "aioesphomeapi"
PROJECT_LICENSE = "MIT"
PROJECT_AUTHOR = "Otto Winter"
# NOTE(review): the leading space suggests a symbol (e.g. "©") was lost here; confirm
PROJECT_COPYRIGHT = " 2019-2020, Otto Winter"
PROJECT_URL = "https://esphome.io/"
PROJECT_EMAIL = "esphome@nabucasa.com"

PROJECT_GITHUB_USERNAME = "esphome"
PROJECT_GITHUB_REPOSITORY = "aioesphomeapi"

PYPI_URL = f"https://pypi.python.org/pypi/{PROJECT_PACKAGE_NAME}"
GITHUB_PATH = f"{PROJECT_GITHUB_USERNAME}/{PROJECT_GITHUB_REPOSITORY}"
GITHUB_URL = f"https://github.com/{GITHUB_PATH}"

# Source archive of the tagged release on GitHub
DOWNLOAD_URL = f"{GITHUB_URL}/archive/{VERSION}.zip"
61 |
# Runtime dependencies come from requirements/base.txt. The file is read as
# UTF-8, like README.rst, and blank lines and comment lines are skipped so
# they never reach install_requires.
with open(
    os.path.join(here, "requirements/base.txt"), encoding="utf-8"
) as requirements_txt:
    REQUIRES = [
        line.strip()
        for line in requirements_txt.read().splitlines()
        if line.strip() and not line.lstrip().startswith("#")
    ]

pkgs = find_packages(exclude=["tests", "tests.*"])
66 |
# Arguments for setup(); cythonize_if_available may add ext_modules/cmdclass
setup_kwargs = {
    "name": PROJECT_PACKAGE_NAME,
    "version": VERSION,
    "url": PROJECT_URL,
    "download_url": DOWNLOAD_URL,
    "author": PROJECT_AUTHOR,
    "author_email": PROJECT_EMAIL,
    "description": "Python API for interacting with ESPHome devices.",
    "long_description": long_description,
    "license": PROJECT_LICENSE,
    "packages": pkgs,
    # Keep generated C sources out of the installed packages
    "exclude_package_data": {pkg: ["*.c"] for pkg in pkgs},
    "include_package_data": True,
    "zip_safe": False,
    "install_requires": REQUIRES,
    "python_requires": ">=3.9",
    "test_suite": "tests",
    # Command-line tools installed with the package
    "entry_points": {
        "console_scripts": [
            "aioesphomeapi-logs=aioesphomeapi.log_reader:cli_entry_point",
            "aioesphomeapi-discover=aioesphomeapi.discover:cli_entry_point",
        ],
    },
}
91 |
92 |
class OptionalBuildExt(build_ext):
    """``build_ext`` for which compiling the C extensions is optional.

    A compilation failure is reported on stderr and otherwise ignored, so the
    build carries on without the extensions. When the ``REQUIRE_CYTHON``
    environment variable is set, the failure is re-raised instead. This
    matches how cythonize_if_available treats that variable.
    """

    def build_extensions(self) -> None:
        try:
            super().build_extensions()
        except Exception as err:
            if os.environ.get("REQUIRE_CYTHON"):
                raise
            import sys

            # Best effort, but never silent: tell the user the speedups are missing
            print(
                f"WARNING: building optional C extensions failed ({err!r}); "
                "continuing without them",
                file=sys.stderr,
            )
97 |
98 |
def cythonize_if_available(setup_kwargs: dict[str, Any]) -> None:
    """Add Cython-compiled extensions to *setup_kwargs* when possible.

    Nothing is added when ``SKIP_CYTHON`` is set. If importing Cython or
    cythonizing fails, the error is ignored unless ``REQUIRE_CYTHON`` is set.
    """
    if os.environ.get("SKIP_CYTHON"):
        return
    try:
        from Cython.Build import cythonize

        ext_modules = cythonize(
            EXTENSIONS,
            compiler_directives={"language_level": "3"},  # Python 3
        )
        setup_kwargs.update(
            {
                "ext_modules": ext_modules,
                "cmdclass": {"build_ext": OptionalBuildExt},
            }
        )
    except Exception:
        if os.environ.get("REQUIRE_CYTHON"):
            raise
117 |
118 |
# Opportunistically attach the Cython extensions, then hand the assembled
# configuration to setuptools.
cythonize_if_available(setup_kwargs)

setup(**setup_kwargs)
122 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Init tests."""
2 |
3 | from __future__ import annotations
4 |
5 | import logging
6 |
7 | logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
8 |
--------------------------------------------------------------------------------
/tests/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
# Make sure debug logging is not enabled for benchmarks; this overrides the
# DEBUG level set when the parent ``tests`` package is imported.
logging.getLogger("aioesphomeapi").setLevel(logging.WARNING)
5 |
--------------------------------------------------------------------------------
/tests/benchmarks/conftest.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 |
5 |
@pytest.fixture(autouse=True)
def no_debug_logging():
    """Keep debug logging off while each benchmark runs.

    The ``aioesphomeapi`` logger is raised to WARNING for the duration of the
    test. Its previous level is put back afterwards.
    """
    library_logger = logging.getLogger("aioesphomeapi")
    saved_level = library_logger.level
    library_logger.setLevel(logging.WARNING)
    yield
    library_logger.setLevel(saved_level)
14 |
--------------------------------------------------------------------------------
/tests/benchmarks/test_bluetooth.py:
--------------------------------------------------------------------------------
1 | """Benchmarks."""
2 |
3 | from functools import partial
4 |
5 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped]
6 |
7 | from aioesphomeapi import APIConnection
8 | from aioesphomeapi._frame_helper.packets import _cached_varuint_to_bytes
9 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
10 | from aioesphomeapi.api_pb2 import (
11 | BluetoothLERawAdvertisement,
12 | BluetoothLERawAdvertisementsResponse,
13 | )
14 | from aioesphomeapi.client import APIClient
15 |
16 |
async def test_raw_ble_plain_text_with_callback(benchmark: BenchmarkFixture) -> None:
    """Benchmark dispatching a raw BLE advertisements packet to a callback.

    A recorded type-93 payload is handed straight to
    ``APIConnection.process_packet`` with a no-op message callback
    registered for ``BluetoothLERawAdvertisementsResponse``.
    """

    class MockConnection(APIConnection):
        pass

    api_client = APIClient("fake.address", 6052, None)
    mock_connection = MockConnection(
        api_client._params, lambda expected_disconnect: None, False, None
    )

    raw_payload = (
        b'\n$\x08\xe3\x8a\x83\xad\x9c\xa3\x1d\x10\xbd\x01\x18\x01"\x15\x02\x01\x1a'
        b"\x02\n\x06\x0e\xffL\x00\x0f\x05\x90\x00\xb5B\x9c\x10\x02)\x04\n!"
        b'\x08\x9e\x9a\xb1\xfc\x9e\x890\x10\xbf\x01"\x14\x02\x01\x06\x10\xff\xa9\x0b'
        b"\x01\x05\x00\x0b\x04\x18\n\x1cM\x8c\xefI\xc0\n.\x08\x9f\x89\x85\xe6"
        b'\xf3\xe8\x17\x10\x8d\x01\x18\x01"\x1f\x02\x01\x02\x14\xff\xa7'
        b"\x05\x06\x00\x12 %\x00\xca\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x02\n"
        b'\x0f\x03\x03\x07\xfe\n\x1f\x08\x9e\xbf\xb5\x87\x98\xce7\x10_\x18\x01"'
        b"\x11\x02\x01\x06\r\xffi\t\xdeq\x80\xed_\x9e\x0bC\x08 \n \x08\xc6\x8a\xa9"
        b'\xed\xb9\xc4>\x10\xab\x01\x18\x01"\x11\x02\x01\x06\x07\xff\t\x04\x8c\x01'
        b'a\x01\x05\tRZSS\n \x08\xd7\xc6\xe8\xe8\x91\xb85\x10\xa5\x01\x18\x01"'
        b"\x11\x02\x01\x06\x07\xff\t\x04\x8c\x01`\x01\x05\tRZSS\n-\x08\xca\xb0\x91"
        b'\xf4\xbc\xe6<\x10}\x18\x01"\x1f\x02\x01\x04\x03\x03\x07\xfe\x14\xff\xa7'
        b"\x05\x06\x00\x12 %\x00\xca\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x02\n"
        b'\x00\n)\x08\xf9\xdd\x95\xac\xb9\x95\r\x10\x87\x01"\x1c\x02\x01\x06\x03'
        b"\x03\x12\x18\x10\tLOOKin_98F330B4\x03\x19\xc1\x03"
    )

    def on_advertisements(msgs: list[BluetoothLERawAdvertisement]):
        """Ignore the decoded advertisements; only dispatch is measured."""

    mock_connection.add_message_callback(
        on_advertisements,
        (BluetoothLERawAdvertisementsResponse,),
    )

    # 93 is the message type the other benchmarks use for raw BLE adverts.
    benchmark(partial(mock_connection.process_packet, 93, raw_payload))
57 |
58 |
async def test_raw_ble_plain_text(benchmark: BenchmarkFixture) -> None:
    """Benchmark framing-layer parsing of one plaintext frame with 5 adverts.

    The connection's ``process_packet`` is a no-op, so only the plaintext
    frame helper's ``data_received`` path is measured.
    """
    fake_adv = BluetoothLERawAdvertisement(
        address=1,
        rssi=-86,
        address_type=2,
        data=(
            b"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75"
            b"657a2f686369302f64656c04010134000000e25389019500000001016f002500"
            b"00002f6f72672f626c75657a2f686369302f6465"
        ),
    )
    response = BluetoothLERawAdvertisementsResponse()
    response.advertisements.extend([fake_adv] * 5)

    msg_type = 93
    body = response.SerializeToString()
    # Plaintext frame: 0x00 preamble, varint length, varint type, payload.
    frame = b"".join(
        (
            b"\0",
            _cached_varuint_to_bytes(len(body)),
            _cached_varuint_to_bytes(msg_type),
            body,
        )
    )

    class MockConnection(APIConnection):
        def __init__(self, *args, **kwargs):
            """Initialize the connection."""

        def process_packet(self, type_: int, data: bytes):
            """Process a packet."""

        def report_fatal_error(self, exc: Exception):
            raise exc

    frame_helper = APIPlaintextFrameHelper(
        connection=MockConnection(), client_info="my client", log_name="test"
    )

    benchmark(partial(frame_helper.data_received, frame))
103 |
104 |
async def test_raw_ble_plain_text_different_advs(benchmark: BenchmarkFixture) -> None:
    """Benchmark framing-layer parsing of a single type-93 plaintext frame.

    The payload is an opaque recorded byte string. ``process_packet`` is a
    no-op, so it is never decoded.
    """
    body = (
        b"\x01\x01\x07\x98\xaa7\xd7\xc5s\xe2\xdd\xc2\x96aG\xb1\xac:\xd3\xde"
        b"\x18\xefz\x00\xca@\xa9\xc8\xeb-\xe6`}\xa1\x00=\xae\x0e\xee\xc4Iy\xd6\x95"
        b"c\xed\x12S\xed\x14 \xa4\x9c&VcE\x0c=\xa8?\xaa\xe851\xdc=\xd6\xeeg\xffb"
        b"\x9a\xf5\xc9\xf6\x0b\r\xb9~\x11\xe3p$\xd9\xa9k\xcd\x1f\x03\x87f\xb8\x0c!\xac"
        b"\xb8:\xf5\x15jC@&\xf1\x13\xca\x89\x96r\xf9\xbd\xf1\xfe\xa0-\xfa\x87\x0cP"
        b"\xa7J+\xbaD,/\xf6\xc3\xf7\\\x1d\xcb#\xda@\xe0\n\xa7\xe0\xf0a\x16\xfb"
        b'\xb5\xfc\\\xbd1\xfb\xd25\x04\x94\x1e/"E\x90,J\xfd\x0f\xbc\xe5>\x96\xba'
        b"\x1bc\xa8\x1eQ\xbd|\xd9\xef\xc1\xffr\x04\x15i7\xea\x8clm`\xaa\x034"
        b"\x0b\xe5\xfe\x06\xfc\xb9\x9fc\xddE\xc93\xc0\x13\xe3\xe3$\xb1\xf2\x93"
        b"\xdb\x1dJ\xbf\x08edi.|\x93\x18\x7f\x83\x7fx\xbe\x01I\x1b\x8c\xe9\xf2\x06"
        b"\x8e\x08\xbe\xb0R&^7[\x1f4\x8f\xe0\xa1jf\xefL\x1b\x1el\xbb\x1c\x99"
        b"\x0f\x94r\xc2=\x10"
    )

    msg_type = 93
    # Plaintext frame: 0x00 preamble, varint length, varint type, payload.
    frame = b"".join(
        (
            b"\0",
            _cached_varuint_to_bytes(len(body)),
            _cached_varuint_to_bytes(msg_type),
            body,
        )
    )

    class MockConnection(APIConnection):
        def __init__(self, *args, **kwargs):
            """Initialize the connection."""

        def process_packet(self, type_: int, data: bytes):
            """Process a packet."""

        def report_fatal_error(self, exc: Exception):
            raise exc

    frame_helper = APIPlaintextFrameHelper(
        connection=MockConnection(), client_info="my client", log_name="test"
    )

    benchmark(partial(frame_helper.data_received, frame))
149 |
150 |
async def test_multiple_ble_adv_messages_single_read(
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark parsing five back-to-back plaintext frames from one read.

    The same type-93 frame is repeated five times and fed to
    ``data_received`` in a single call. ``process_packet`` is a no-op.
    """
    body = (
        b"\x01\x01\x07\x98\xaa7\xd7\xc5s\xe2\xdd\xc2\x96aG\xb1\xac:\xd3\xde"
        b"\x18\xefz\x00\xca@\xa9\xc8\xeb-\xe6`}\xa1\x00=\xae\x0e\xee\xc4Iy\xd6\x95"
        b"c\xed\x12S\xed\x14 \xa4\x9c&VcE\x0c=\xa8?\xaa\xe851\xdc=\xd6\xeeg\xffb"
        b"\x9a\xf5\xc9\xf6\x0b\r\xb9~\x11\xe3p$\xd9\xa9k\xcd\x1f\x03\x87f\xb8\x0c!\xac"
        b"\xb8:\xf5\x15jC@&\xf1\x13\xca\x89\x96r\xf9\xbd\xf1\xfe\xa0-\xfa\x87\x0cP"
        b"\xa7J+\xbaD,/\xf6\xc3\xf7\\\x1d\xcb#\xda@\xe0\n\xa7\xe0\xf0a\x16\xfb"
        b'\xb5\xfc\\\xbd1\xfb\xd25\x04\x94\x1e/"E\x90,J\xfd\x0f\xbc\xe5>\x96\xba'
        b"\x1bc\xa8\x1eQ\xbd|\xd9\xef\xc1\xffr\x04\x15i7\xea\x8clm`\xaa\x034"
        b"\x0b\xe5\xfe\x06\xfc\xb9\x9fc\xddE\xc93\xc0\x13\xe3\xe3$\xb1\xf2\x93"
        b"\xdb\x1dJ\xbf\x08edi.|\x93\x18\x7f\x83\x7fx\xbe\x01I\x1b\x8c\xe9\xf2\x06"
        b"\x8e\x08\xbe\xb0R&^7[\x1f4\x8f\xe0\xa1jf\xefL\x1b\x1el\xbb\x1c\x99"
        b"\x0f\x94r\xc2=\x10"
    )

    msg_type = 93
    # Plaintext frame: 0x00 preamble, varint length, varint type, payload.
    frame = b"".join(
        (
            b"\0",
            _cached_varuint_to_bytes(len(body)),
            _cached_varuint_to_bytes(msg_type),
            body,
        )
    )

    class MockConnection(APIConnection):
        def __init__(self, *args, **kwargs):
            """Initialize the connection."""

        def process_packet(self, type_: int, data: bytes):
            """Process a packet."""

        def report_fatal_error(self, exc: Exception):
            raise exc

    frame_helper = APIPlaintextFrameHelper(
        connection=MockConnection(), client_info="my client", log_name="test"
    )

    five_frames = frame * 5
    benchmark(partial(frame_helper.data_received, five_frames))
197 |
--------------------------------------------------------------------------------
/tests/benchmarks/test_noise.py:
--------------------------------------------------------------------------------
1 | """Benchmarks for noise."""
2 |
3 | import asyncio
4 | import base64
5 | from collections.abc import Iterable
6 |
7 | import pytest
8 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped]
9 |
10 | from aioesphomeapi._frame_helper.noise_encryption import EncryptCipher
11 |
12 | from ..common import (
13 | MockAPINoiseFrameHelper,
14 | _extract_encrypted_payload_from_handshake,
15 | _make_encrypted_packet,
16 | _make_mock_connection,
17 | _make_noise_handshake_pkt,
18 | _make_noise_hello_pkt,
19 | _mock_responder_proto,
20 | mock_data_received,
21 | )
22 |
23 |
@pytest.mark.parametrize("payload_size", [0, 64, 128, 1024, 16 * 1024])
async def test_noise_messages(benchmark: BenchmarkFixture, payload_size: int) -> None:
    """Benchmark raw noise protocol.

    First completes a noise handshake between the client-side frame helper
    and a responder from ``_mock_responder_proto``. Then each benchmark round
    encrypts 100 frames of ``payload_size`` bytes on the test side and feeds
    them to ``helper.data_received``.
    """
    noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
    psk_bytes = base64.b64decode(noise_psk)
    # Captures each transport write as one joined bytes object.
    writes = []

    def _writelines(data: Iterable[bytes]):
        writes.append(b"".join(data))

    # NOTE(review): ``packets`` is never inspected by this benchmark.
    connection, packets = _make_mock_connection()

    helper = MockAPINoiseFrameHelper(
        connection=connection,
        noise_psk=noise_psk,
        expected_name="servicetest",
        expected_mac=None,
        client_info="my client",
        log_name="test",
        writer=_writelines,
    )

    # Device-side end of the handshake.
    proto = _mock_responder_proto(psk_bytes)

    await asyncio.sleep(0)  # let the task run to read the hello packet

    # The helper's first write must be the hello + handshake frame.
    assert len(writes) == 1
    handshake_pkt = writes.pop()

    encrypted_payload = _extract_encrypted_payload_from_handshake(handshake_pkt)
    decrypted = proto.read_message(encrypted_payload)
    assert decrypted == b""

    # Server hello naming "servicetest" (matches expected_name above), then
    # the responder's handshake reply.
    hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0")
    mock_data_received(helper, hello_pkt_with_header)

    handshake_with_header = _make_noise_handshake_pkt(proto)
    mock_data_received(helper, handshake_with_header)

    assert not writes

    await helper.ready_future
    # Sanity-check one outbound frame: 0x01 preamble, then a 16-bit
    # big-endian length covering the rest of the frame.
    helper.write_packets([(1, b"to device")], True)
    encrypted_packet = writes.pop()
    header = encrypted_packet[0:1]
    assert header == b"\x01"
    pkg_length_high = encrypted_packet[1]
    pkg_length_low = encrypted_packet[2]
    pkg_length = (pkg_length_high << 8) + pkg_length_low
    assert len(encrypted_packet) == 3 + pkg_length

    helper.write_packets([(1, b"to device")], True)

    def _empty_writelines(data: Iterable[bytes]):
        """Empty writelines."""

    # Discard any outbound writes during the measured section.
    helper._writelines = _empty_writelines

    payload = b"x" * payload_size
    encrypt_cipher = EncryptCipher(proto.noise_protocol.cipher_state_encrypt)

    # Measured section: encrypt and feed 100 inbound frames per round.
    @benchmark
    def process_encrypted_packets():
        for _ in range(100):
            helper.data_received(_make_encrypted_packet(encrypt_cipher, 42, payload))

    helper.close()
91 |
--------------------------------------------------------------------------------
/tests/benchmarks/test_requests.py:
--------------------------------------------------------------------------------
1 | """Benchmarks."""
2 |
3 | import asyncio
4 |
5 | from pytest_codspeed import BenchmarkFixture # type: ignore[import-untyped]
6 |
7 | from aioesphomeapi import APIConnection
8 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
9 | from aioesphomeapi.client import APIClient
10 |
11 |
def test_sending_light_command_request_with_bool(
    benchmark: BenchmarkFixture,
    api_client: tuple[
        APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
    ],
) -> None:
    """Benchmark ``light_command(1, True)`` on a connected client.

    Outgoing bytes are dropped, so only building and encoding the request is
    measured.
    """
    client, connection, _transport, _protocol = api_client

    def _discard(lines):
        return None

    connection._frame_helper._writelines = _discard

    def send_request():
        client.light_command(1, True)

    benchmark(send_request)
25 |
26 |
def test_sending_empty_light_command_request(
    benchmark: BenchmarkFixture,
    api_client: tuple[
        APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
    ],
) -> None:
    """Benchmark ``light_command(1)`` with no optional fields set.

    Outgoing bytes are dropped, so only building and encoding the request is
    measured.
    """
    client, connection, _transport, _protocol = api_client

    def _discard(lines):
        return None

    connection._frame_helper._writelines = _discard

    def send_request():
        client.light_command(1)

    benchmark(send_request)
40 |
--------------------------------------------------------------------------------
/tests/common.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from collections.abc import Awaitable
5 | from datetime import datetime, timezone
6 | from functools import partial
7 | import time
8 | from typing import Any, Callable
9 | from unittest.mock import AsyncMock, MagicMock, patch
10 |
11 | from google.protobuf import message
12 | from noise.connection import NoiseConnection # type: ignore[import-untyped]
13 | from zeroconf import Zeroconf
14 | from zeroconf.asyncio import AsyncZeroconf
15 |
16 | from aioesphomeapi import APIClient, APIConnection
17 | from aioesphomeapi._frame_helper.noise import APINoiseFrameHelper
18 | from aioesphomeapi._frame_helper.noise_encryption import (
19 | ESPHOME_NOISE_BACKEND,
20 | EncryptCipher,
21 | )
22 | from aioesphomeapi._frame_helper.packets import _cached_varuint_to_bytes
23 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
24 | from aioesphomeapi.api_pb2 import (
25 | ConnectResponse,
26 | HelloResponse,
27 | PingRequest,
28 | PingResponse,
29 | )
30 | from aioesphomeapi.client import ConnectionParams
31 | from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, SocketClosedAPIError
32 | from aioesphomeapi.zeroconf import ZeroconfManager
33 |
UTC = timezone.utc
# Resolution of the monotonic clock; async_fire_time_changed uses it as a small
# tolerance when deciding whether a scheduled timer is due.
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
# We use a partial here since it is implemented in native code
# and avoids the global lookup of UTC
utcnow: partial[datetime] = partial(datetime.now, UTC)
utcnow.__doc__ = "Get now in UTC time."

# Reverse of MESSAGE_TYPE_TO_PROTO: protobuf message class -> API message type.
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}


# First byte of a plaintext frame.
PREAMBLE = b"\x00"

# Noise frame header with preamble 0x01 and a zero length. It is expected at
# the start of the client's first write
# (see _extract_encrypted_payload_from_handshake).
NOISE_HELLO = b"\x01\x00\x00"
# Keepalive passed to the mock ConnectionParams (presumably seconds).
KEEP_ALIVE_INTERVAL = 15.0
48 |
49 |
def get_mock_connection_params() -> ConnectionParams:
    """Build ConnectionParams for a fake plaintext device.

    The device has no password, no noise PSK and no expected name/MAC. A new
    ZeroconfManager is created on every call.
    """
    settings: dict[str, Any] = {
        "addresses": ["fake.address"],
        "port": 6052,
        "password": None,
        "client_info": "Tests client",
        "keepalive": KEEP_ALIVE_INTERVAL,
        "zeroconf_manager": ZeroconfManager(),
        "noise_psk": None,
        "expected_name": None,
        "expected_mac": None,
    }
    return ConnectionParams(**settings)
62 |
63 |
def mock_data_received(
    protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes
) -> None:
    """Feed *data* to *protocol* as if it arrived from the transport.

    If ``data_received`` raises, the exception is not propagated. Instead,
    ``protocol.connection_lost(exc)`` is scheduled with ``call_soon`` on the
    running event loop.
    """
    try:
        protocol.data_received(data)
    except Exception as exc:  # pylint: disable=broad-except
        asyncio.get_running_loop().call_soon(protocol.connection_lost, exc)
76 |
77 |
def get_mock_zeroconf() -> Zeroconf:
    """Create a ``Zeroconf`` instance that does not start networking.

    ``Zeroconf.start`` is patched out during construction, and the instance's
    ``close`` is replaced with a ``MagicMock`` so tests can call and inspect
    it. The returned object is a real ``Zeroconf`` (not a ``MagicMock``), so
    the return annotation says so.
    """
    with patch("zeroconf.Zeroconf.start"):
        zc = Zeroconf()
    zc.close = MagicMock()
    return zc
83 |
84 |
def get_mock_async_zeroconf() -> AsyncZeroconf:
    """Wrap a non-networking mock ``Zeroconf`` in an ``AsyncZeroconf``.

    ``async_close`` is replaced with an ``AsyncMock`` so closing is a no-op.
    """
    async_wrapper = AsyncZeroconf(zc=get_mock_zeroconf())
    async_wrapper.async_close = AsyncMock()
    return async_wrapper
89 |
90 |
class Estr(str):
    """Slot-only ``str`` subclass (no instance ``__dict__``).

    Presumably used by tests to pass a string subclass where a plain ``str``
    is expected.
    """

    __slots__ = ()
95 |
96 |
def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
    """Encode *msg* as a plaintext frame, returned as its four parts.

    The parts are: the 0x00 preamble, the varint payload length, the varint
    message type, and the serialized payload. Joining them gives the full
    frame.
    """
    msg_type = PROTO_TO_MESSAGE_TYPE[msg.__class__]
    payload = msg.SerializeToString()
    header = [
        b"\0",
        _cached_varuint_to_bytes(len(payload)),
        _cached_varuint_to_bytes(msg_type),
    ]
    return [*header, payload]
106 |
107 |
def generate_plaintext_packet(msg: message.Message) -> bytes:
    """Return the complete plaintext frame for *msg* as one bytes object."""
    frame_parts = generate_split_plaintext_packet(msg)
    return b"".join(frame_parts)
110 |
111 |
112 | def as_utc(dattim: datetime) -> datetime:
113 | """Return a datetime as UTC time."""
114 | if dattim.tzinfo == UTC:
115 | return dattim
116 | return dattim.astimezone(UTC)
117 |
118 |
def async_fire_time_changed(
    datetime_: datetime | None = None, fire_all: bool = False
) -> None:
    """Run every pending timer that would be due at *datetime_* (default: now).

    Each live ``TimerHandle`` on the running loop is compared against the
    requested wall-clock time. Handles that are due, or all of them when
    *fire_all* is set, are run immediately and then cancelled so they do not
    fire again.

    This is a test-only approximation: a real event loop cannot hit an exact
    microsecond, so code must not depend on that precision.
    """
    loop = asyncio.get_running_loop()
    target = datetime.now(UTC) if datetime_ is None else as_utc(datetime_)
    target_ts = target.timestamp()

    for handle in list(loop._scheduled):
        if not isinstance(handle, asyncio.TimerHandle) or handle.cancelled():
            continue

        # How far ahead of the real clock the requested time is, versus how
        # far in the future this handle is scheduled (monotonic clock).
        seconds_ahead = target_ts - time.time()
        seconds_until_due = handle.when() - (loop.time() + _MONOTONIC_RESOLUTION)

        if not (fire_all or seconds_ahead >= seconds_until_due):
            continue
        handle._run()
        handle.cancel()
145 |
146 |
async def connect(conn: APIConnection, login: bool = True):
    """Drive *conn* through every connection phase.

    The phases run in order: host resolution, socket connection, then
    ``finish_connection`` with the given *login* flag.
    """
    await conn.start_resolve_host()  # phase 1: resolve
    await conn.start_connection()  # phase 2: open the connection
    await conn.finish_connection(login=login)  # phase 3: handshake / login
152 |
153 |
async def connect_client(
    client: APIClient,
    login: bool = True,
    on_stop: Callable[[bool], Awaitable[None]] | None = None,
) -> None:
    """Drive *client* through every connection phase.

    *on_stop* is passed to ``start_resolve_host``, and *login* is passed to
    ``finish_connection``.
    """
    await client.start_resolve_host(on_stop=on_stop)  # phase 1: resolve
    await client.start_connection()  # phase 2: open the connection
    await client.finish_connection(login=login)  # phase 3: handshake / login
163 |
164 |
def send_plaintext_hello(
    protocol: APIPlaintextFrameHelper,
    major: int | None = None,
    minor: int | None = None,
) -> None:
    """Deliver a HelloResponse from a device named "fake" to *protocol*.

    The API version defaults to 1.9. Pass *major* and/or *minor* to
    override either part.
    """
    hello_response: message.Message = HelloResponse()
    hello_response.api_version_major = major if major is not None else 1
    hello_response.api_version_minor = minor if minor is not None else 9
    hello_response.name = "fake"
    frame = generate_plaintext_packet(hello_response)
    protocol.data_received(frame)
175 |
176 |
def send_plaintext_connect_response(
    protocol: APIPlaintextFrameHelper, invalid_password: bool
) -> None:
    """Deliver a ConnectResponse to *protocol* with the given password verdict."""
    response: message.Message = ConnectResponse()
    response.invalid_password = invalid_password
    frame = generate_plaintext_packet(response)
    protocol.data_received(frame)
183 |
184 |
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
    """Deliver an empty PingResponse frame to *protocol*."""
    frame = generate_plaintext_packet(PingResponse())
    protocol.data_received(frame)
188 |
189 |
def send_ping_request(protocol: APIPlaintextFrameHelper) -> None:
    """Deliver an empty PingRequest frame to *protocol*."""
    frame = generate_plaintext_packet(PingRequest())
    protocol.data_received(frame)
193 |
194 |
def get_mock_protocol(conn: APIConnection):
    """Return a plaintext frame helper for *conn*, attached to a MagicMock transport."""
    frame_helper = APIPlaintextFrameHelper(
        connection=conn, client_info="mock", log_name="mock_device"
    )
    frame_helper.connection_made(MagicMock())
    return frame_helper
204 |
205 |
206 | def _create_mock_transport_protocol(
207 | transport: asyncio.Transport,
208 | connected: asyncio.Event,
209 | create_func: Callable[[], APIPlaintextFrameHelper],
210 | **kwargs,
211 | ) -> tuple[asyncio.Transport, APIPlaintextFrameHelper]:
212 | protocol: APIPlaintextFrameHelper = create_func()
213 | protocol.connection_made(transport)
214 | connected.set()
215 | return transport, protocol
216 |
217 |
def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes:
    """Validate the client's hello + handshake write and return the handshake message.

    Expected layout: NOISE_HELLO (3 bytes), then a frame header (0x01
    preamble and a 16-bit big-endian length that must equal 49), then a
    0x00 prefix byte. The bytes after the prefix are returned.
    """
    hello, frame_header = handshake_pkt[0:3], handshake_pkt[3:6]
    assert hello == NOISE_HELLO
    assert frame_header[0] == 1  # type
    frame_length = (frame_header[1] << 8) + frame_header[2]
    assert frame_length == 49
    assert handshake_pkt[6:7] == b"\x00"
    return handshake_pkt[7:]
230 |
231 |
232 | def _make_noise_hello_pkt(hello_pkt: bytes) -> bytes:
233 | """Make a noise hello packet."""
234 | preamble = 1
235 | hello_pkg_length = len(hello_pkt)
236 | hello_pkg_length_high = (hello_pkg_length >> 8) & 0xFF
237 | hello_pkg_length_low = hello_pkg_length & 0xFF
238 | hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
239 | return hello_header + hello_pkt
240 |
241 |
242 | def _make_noise_handshake_pkt(proto: NoiseConnection) -> bytes:
243 | handshake = proto.write_message(b"")
244 | handshake_pkt = b"\x00" + handshake
245 | preamble = 1
246 | handshake_pkg_length = len(handshake_pkt)
247 | handshake_pkg_length_high = (handshake_pkg_length >> 8) & 0xFF
248 | handshake_pkg_length_low = handshake_pkg_length & 0xFF
249 | handshake_header = bytes(
250 | (preamble, handshake_pkg_length_high, handshake_pkg_length_low)
251 | )
252 |
253 | return handshake_header + handshake_pkt
254 |
255 |
def _make_encrypted_packet(
    cipher: EncryptCipher, msg_type: int, payload: bytes
) -> bytes:
    """Encrypt *payload* as message *msg_type* and wrap it in a noise frame.

    The plaintext given to *cipher* is a 4-byte header (16-bit big-endian
    *msg_type*, then 16-bit big-endian payload length) followed by *payload*.
    The caller's *msg_type* is used as given; previously it was ignored and
    42 was always encoded.
    """
    msg_length = len(payload)
    msg_header = bytes(
        (
            (msg_type >> 8) & 0xFF,
            msg_type & 0xFF,
            (msg_length >> 8) & 0xFF,
            msg_length & 0xFF,
        )
    )
    encrypted_payload = cipher.encrypt(msg_header + payload)
    return _make_encrypted_packet_from_encrypted_payload(encrypted_payload)
268 |
269 |
270 | def _make_encrypted_packet_from_encrypted_payload(encrypted_payload: bytes) -> bytes:
271 | preamble = 1
272 | encrypted_pkg_length = len(encrypted_payload)
273 | encrypted_pkg_length_high = (encrypted_pkg_length >> 8) & 0xFF
274 | encrypted_pkg_length_low = encrypted_pkg_length & 0xFF
275 | encrypted_header = bytes(
276 | (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low)
277 | )
278 | return encrypted_header + encrypted_payload
279 |
280 |
def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection:
    """Create a device-side noise responder with its handshake already started.

    It uses Noise_NNpsk0_25519_ChaChaPoly_SHA256 on ESPHOME_NOISE_BACKEND,
    keyed with *psk_bytes*, with the "NoiseAPIInit" prologue.
    """
    responder = NoiseConnection.from_name(
        b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
    )
    responder.set_as_responder()
    responder.set_psks(psk_bytes)
    responder.set_prologue(b"NoiseAPIInit\x00\x00")
    responder.start_handshake()
    return responder
290 |
291 |
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
    """Make a mock connection that records packets instead of handling them.

    Returns the connection and the list to which every
    ``process_packet(type_, data)`` call is appended as a ``(type_, data)``
    tuple.
    """
    received: list[tuple[int, bytes]] = []

    class MockConnection(APIConnection):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            """Use the mock params; forward any extra args to APIConnection."""
            super().__init__(
                get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
            )

        def process_packet(self, type_: int, data: bytes):
            received.append((type_, data))

    return MockConnection(), received
308 |
309 |
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
    """APINoiseFrameHelper that is attached to a MagicMock transport on creation.

    Pass ``writer`` to receive the transport's ``writelines`` calls, which
    lets tests capture outgoing frames.
    """

    def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None:
        """Initialise the helper, then connect it to a mock transport."""
        super().__init__(*args, **kwargs)
        mock_transport = MagicMock()
        mock_transport.writelines = writer if writer else MagicMock()
        self.__transport = mock_transport
        self.connection_made(mock_transport)

    def connection_made(self, transport: Any) -> None:
        # Always attach the mock transport created in __init__, whatever is passed.
        return super().connection_made(self.__transport)

    def mock_write_frame(self, frame: bytes) -> None:
        """Write *frame* with a noise header (0x01 + 16-bit length) in one call.

        Transport errors are re-raised as SocketClosedAPIError.
        """
        frame_len = len(frame)
        frame_header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
        try:
            self._writelines([frame_header, frame])
        except (RuntimeError, ConnectionResetError, OSError) as err:
            raise SocketClosedAPIError(
                f"{self._log_name}: Error while writing data: {err}"
            ) from err
335 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Test fixtures."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | from collections.abc import Generator
7 | from contextlib import contextmanager
8 | from dataclasses import replace
9 | from functools import partial
10 | import reprlib
11 | import socket
12 | from unittest.mock import AsyncMock, MagicMock, create_autospec, patch
13 |
14 | import pytest
15 | import pytest_asyncio
16 |
17 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
18 | from aioesphomeapi.client import APIClient, ConnectionParams
19 | from aioesphomeapi.connection import APIConnection
20 | from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
21 |
22 | from .common import (
23 | _create_mock_transport_protocol,
24 | connect,
25 | connect_client,
26 | get_mock_async_zeroconf,
27 | get_mock_connection_params,
28 | send_plaintext_hello,
29 | )
30 |
# Canned host-resolution result returned by the ``resolve_host`` fixture: one
# IPv4/TCP address on port 6052.
# NOTE(review): "10.0.0.512" is not a valid IPv4 address (octet > 255). It is
# presumably a deliberate placeholder (the same string is the fake socket peer
# in ``aiohappyeyeballs_start_connection``) -- confirm before changing it.
_MOCK_RESOLVE_RESULT = [
    AddrInfo(
        family=socket.AF_INET,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
        sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
    )
]
39 |
40 |
class PatchableAPIConnection(APIConnection):
    """Patchable APIConnection for testing.

    A plain Python subclass. Presumably it exists because APIConnection may be
    Cython-compiled (see connection.pxd), so attributes could not otherwise be
    patched on it or its instances.
    """


class PatchableAPIClient(APIClient):
    """Patchable APIClient for testing.

    A plain Python subclass that tests can patch attributes on (presumably
    for the same reason as PatchableAPIConnection; see client_base.pxd).
    """
47 |
48 |
@pytest.fixture
def async_zeroconf():
    """Provide an AsyncZeroconf backed by a non-networking mock Zeroconf."""
    mocked_aiozc = get_mock_async_zeroconf()
    return mocked_aiozc
52 |
53 |
@pytest.fixture
def resolve_host() -> Generator[AsyncMock]:
    """Patch host resolution to always return _MOCK_RESOLVE_RESULT; yield the mock."""
    with patch("aioesphomeapi.host_resolver.async_resolve_host") as mocked_resolve:
        mocked_resolve.return_value = _MOCK_RESOLVE_RESULT
        yield mocked_resolve
59 |
60 |
@pytest.fixture
async def patchable_api_client() -> APIClient:
    """Return a PatchableAPIClient for 127.0.0.1:6052 with no password."""
    return PatchableAPIClient(address="127.0.0.1", port=6052, password=None)
69 |
70 |
@pytest.fixture
async def auth_client():
    """Yield a client whose ``_connection`` is a PatchableAPIConnection marked connected.

    The connection is patched onto the client for the fixture's lifetime only.
    """
    client = PatchableAPIClient(address="fake.address", port=6052, password=None)
    fake_connection = PatchableAPIConnection(
        params=client._params,
        on_stop=client._on_stop,
        debug_enabled=False,
        log_name=client.log_name,
    )
    fake_connection.is_connected = True
    with patch.object(client, "_connection", fake_connection):
        yield client
87 |
88 |
@pytest.fixture
def connection_params() -> ConnectionParams:
    """Return fresh mock ConnectionParams for a plaintext, password-less device.

    This fixture no longer requests pytest-asyncio's ``event_loop`` fixture.
    The body never used it, and that fixture is deprecated in recent
    pytest-asyncio releases. The other fixtures here obtain the loop via
    ``asyncio.get_running_loop()`` instead.
    """
    return get_mock_connection_params()
92 |
93 |
def mock_on_stop(expected_disconnect: bool) -> None:
    """No-op ``on_stop`` callback for the connections built by these fixtures."""
96 |
97 |
@pytest.fixture
async def conn(connection_params: ConnectionParams) -> APIConnection:
    """Connection built from the shared mock params (debug enabled, no log name)."""
    plain_connection = PatchableAPIConnection(
        connection_params, mock_on_stop, True, None
    )
    return plain_connection
101 |
102 |
@pytest.fixture
async def conn_with_password(connection_params: ConnectionParams) -> APIConnection:
    """Like ``conn``, but the params carry the password "password"."""
    params_with_password = replace(connection_params, password="password")
    return PatchableAPIConnection(params_with_password, mock_on_stop, True, None)
107 |
108 |
@pytest.fixture
async def noise_conn(connection_params: ConnectionParams) -> APIConnection:
    """Like ``conn``, but configured with a fixed noise PSK."""
    noise_params = replace(
        connection_params, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
    )
    return PatchableAPIConnection(noise_params, mock_on_stop, True, None)
115 |
116 |
@pytest.fixture
async def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnection:
    """Like ``conn``, but expecting the device to be named "test"."""
    named_params = replace(connection_params, expected_name="test")
    return PatchableAPIConnection(named_params, mock_on_stop, True, None)
121 |
122 |
@pytest.fixture()
async def aiohappyeyeballs_start_connection():
    """Patch ``aiohappyeyeballs.start_connection`` to return a fake TCP socket.

    The autospec'd socket reports SOCK_STREAM, fileno 1 and the peer
    ("10.0.0.512", 323). The patched function is yielded.
    """
    with patch(
        "aioesphomeapi.connection.aiohappyeyeballs.start_connection"
    ) as start_connection_mock:
        fake_socket = create_autospec(socket.socket, spec_set=True, instance=True)
        fake_socket.type = socket.SOCK_STREAM
        fake_socket.fileno.return_value = 1
        fake_socket.getpeername.return_value = ("10.0.0.512", 323)
        start_connection_mock.return_value = fake_socket
        yield start_connection_mock
132 |
133 |
@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start connecting ``conn`` without login and yield once the protocol exists.

    Yields ``(conn, transport, frame_helper, connect_task)``. The connect task
    is still pending, waiting for the device side. The connection is
    force-disconnected on teardown.
    """
    loop = asyncio.get_running_loop()
    fake_transport = MagicMock()
    protocol_ready = asyncio.Event()
    fake_create_connection = partial(
        _create_mock_transport_protocol, fake_transport, protocol_ready
    )

    with patch.object(loop, "create_connection", side_effect=fake_create_connection):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await protocol_ready.wait()
    yield conn, fake_transport, conn._frame_helper, connect_task
    conn.force_disconnect()
153 |
154 |
@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
    conn_with_expected_name: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start a no-login connect on a connection expecting the name "test".

    Yields ``(conn, transport, frame_helper, connect_task)`` once the protocol
    is attached. The connection is force-disconnected on teardown.
    """
    loop = asyncio.get_running_loop()
    fake_transport = MagicMock()
    protocol_ready = asyncio.Event()
    fake_create_connection = partial(
        _create_mock_transport_protocol, fake_transport, protocol_ready
    )

    with patch.object(loop, "create_connection", side_effect=fake_create_connection):
        connect_task = asyncio.create_task(
            connect(conn_with_expected_name, login=False)
        )
        await protocol_ready.wait()
    frame_helper = conn_with_expected_name._frame_helper
    yield conn_with_expected_name, fake_transport, frame_helper, connect_task
    conn_with_expected_name.force_disconnect()
181 |
182 |
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
    conn_with_password: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    """Start a connect with login on the password-protected connection.

    Yields ``(conn, transport, frame_helper, connect_task)`` once the protocol
    is attached. The connection is force-disconnected on teardown.
    """
    fake_transport = MagicMock()
    protocol_ready = asyncio.Event()
    loop = asyncio.get_running_loop()
    fake_create_connection = partial(
        _create_mock_transport_protocol, fake_transport, protocol_ready
    )

    with patch.object(loop, "create_connection", side_effect=fake_create_connection):
        connect_task = asyncio.create_task(connect(conn_with_password, login=True))
        await protocol_ready.wait()
    frame_helper = conn_with_password._frame_helper
    yield conn_with_password, fake_transport, frame_helper, connect_task
    conn_with_password.force_disconnect()
207 |
208 |
@pytest_asyncio.fixture(name="api_client")
async def api_client(
    resolve_host, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
    """Yield an ``APIClient`` connected without login over a mocked transport.

    Yields ``(client, connection, transport, frame_helper)``. The transport
    mock is reset after the connect task finishes, so tests see only the
    writes they trigger themselves. The connection is force-disconnected on
    teardown.
    """
    event_loop = asyncio.get_running_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    client = APIClient(
        address="mydevice.local",
        port=6052,
        password=None,
    )

    with (
        patch.object(
            event_loop,
            "create_connection",
            side_effect=partial(_create_mock_transport_protocol, transport, connected),
        ),
        # NOTE(review): presumably swapped in so tests can patch attributes on
        # the connection instance; confirm against PatchableAPIConnection.
        patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection),
    ):
        connect_task = asyncio.create_task(connect_client(client, login=False))
        await connected.wait()
        conn = client._connection
        protocol = conn._frame_helper
        # Feed the plaintext hello before awaiting the connect task.
        send_plaintext_hello(protocol)
        await connect_task
        # Forget handshake traffic so tests start from a clean mock.
        transport.reset_mock()
        yield client, conn, transport, protocol
        conn.force_disconnect()
240 |
241 |
def get_scheduled_timer_handles(
    loop: asyncio.AbstractEventLoop,
) -> list[asyncio.TimerHandle]:
    """Return the loop's pending ``TimerHandle`` list.

    This exposes the private ``_scheduled`` attribute of the event loop
    directly. It is the live list, not a copy.
    """
    return loop._scheduled  # type: ignore[attr-defined]
248 |
249 |
@contextmanager
def long_repr_strings() -> Generator[None]:
    """Temporarily set ``reprlib.aRepr`` ``maxstring``/``maxother`` to 300.

    The previous limits are restored on exit, including when the managed
    block raises.
    """
    repr_config = reprlib.aRepr
    saved_limits = (repr_config.maxstring, repr_config.maxother)
    repr_config.maxstring = repr_config.maxother = 300
    try:
        yield
    finally:
        repr_config.maxstring, repr_config.maxother = saved_limits
263 |
264 |
@pytest.fixture(autouse=True)
def verify_no_lingering_tasks(
    event_loop: asyncio.AbstractEventLoop,
) -> Generator[None]:
    """Fail the test if it leaves tasks or uncancelled timers on the loop.

    On teardown, every task created during the test that is still pending is
    cancelled and awaited, and every uncancelled timer handle is cancelled.
    This cleanup always runs so leftovers cannot leak into later tests. Only
    afterwards is the test failed, with one report that lists every offender.
    """
    tasks_before = asyncio.all_tasks(event_loop)
    yield

    problems: list[str] = []

    lingering_tasks = asyncio.all_tasks(event_loop) - tasks_before
    for task in lingering_tasks:
        problems.append(f"Task still running: {task!r}")
        task.cancel()
    if lingering_tasks:
        # Drive the loop so the cancellations actually complete.
        event_loop.run_until_complete(asyncio.wait(lingering_tasks))

    # Iterate over a copy: the loop owns this list and may mutate it.
    for handle in list(get_scheduled_timer_handles(event_loop)):
        if not handle.cancelled():
            with long_repr_strings():
                problems.append(f"Lingering timer after test {handle!r}")
            handle.cancel()

    # Fail only after cleanup; failing mid-loop would skip the cancellations.
    if problems:
        pytest.fail("\n".join(problems))
285 |
--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
4 |
5 |
def test_order_and_no_missing_numbers_in_message_type_to_proto():
    """Test that MESSAGE_TYPE_TO_PROTO has no missing numbers.

    The keys must be exactly 1, 2, 3, ... in iteration order. A gap or an
    out-of-order entry is reported together with the position it occupies.
    """
    # Guard against a vacuous pass if the mapping were ever empty.
    assert MESSAGE_TYPE_TO_PROTO
    for expected, msg_type in enumerate(MESSAGE_TYPE_TO_PROTO, start=1):
        assert msg_type == expected, (
            f"expected message type {expected} at this position, got {msg_type}"
        )
10 |
--------------------------------------------------------------------------------
/tests/test_log_runner.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from datetime import timedelta
5 | from functools import partial
6 | from unittest.mock import MagicMock, patch
7 |
8 | from google.protobuf import message
9 | import pytest
10 |
11 | from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
12 | from aioesphomeapi.api_pb2 import (
13 | DisconnectRequest,
14 | DisconnectResponse,
15 | SubscribeLogsResponse, # type: ignore
16 | )
17 | from aioesphomeapi.client import APIClient
18 | from aioesphomeapi.connection import APIConnection
19 | from aioesphomeapi.core import APIConnectionError
20 | from aioesphomeapi.log_runner import async_run
21 | from aioesphomeapi.reconnect_logic import EXPECTED_DISCONNECT_COOLDOWN
22 |
23 | from .common import (
24 | Estr,
25 | async_fire_time_changed,
26 | generate_plaintext_packet,
27 | get_mock_async_zeroconf,
28 | mock_data_received,
29 | send_plaintext_connect_response,
30 | send_plaintext_hello,
31 | utcnow,
32 | )
33 |
34 |
async def test_log_runner(
    conn: APIConnection,
    aiohappyeyeballs_start_connection,
):
    """Test the log runner logic.

    ``async_run`` connects and subscribes to logs, and each received
    ``SubscribeLogsResponse`` reaches the callback. The ``stop`` callable it
    returns then completes once a ``DisconnectResponse`` arrives.
    """
    event_loop = asyncio.get_running_loop()
    frame_helper: APIPlaintextFrameHelper | None = None
    mock_transport = MagicMock()
    transport_ready = asyncio.Event()
    subscribed = asyncio.Event()
    received = []

    # NOTE(review): presumably a subclass so patch.object can set attributes on
    # the instance below.
    class PatchableAPIClient(APIClient):
        pass

    mock_zeroconf = get_mock_async_zeroconf()
    client = PatchableAPIClient(
        address=Estr("127.0.0.1"),
        port=6052,
        password=None,
        noise_psk=None,
        expected_name=Estr("fake"),
        zeroconf_instance=mock_zeroconf.zeroconf,
    )
    real_subscribe_logs = client.subscribe_logs

    def _on_log(msg: SubscribeLogsResponse) -> None:
        received.append(msg)

    def _fake_create_connection(create_func, **kwargs):
        # Build the protocol ourselves and bind it to the mock transport.
        nonlocal frame_helper
        frame_helper = create_func()
        frame_helper.connection_made(mock_transport)
        transport_ready.set()
        return mock_transport, frame_helper

    def _subscribe_and_signal(*args, **kwargs):
        real_subscribe_logs(*args, **kwargs)
        subscribed.set()

    with (
        patch.object(
            event_loop, "create_connection", side_effect=_fake_create_connection
        ),
        patch.object(client, "subscribe_logs", _subscribe_and_signal),
    ):
        stop = await async_run(client, _on_log, aio_zeroconf_instance=mock_zeroconf)
        await transport_ready.wait()
        frame_helper = client._connection._frame_helper
        send_plaintext_hello(frame_helper)
        send_plaintext_connect_response(frame_helper, False)
        await subscribed.wait()

    log_line: message.Message = SubscribeLogsResponse()
    log_line.message = b"Hello world"
    mock_data_received(frame_helper, generate_plaintext_packet(log_line))
    assert len(received) == 1
    assert received[0].message == b"Hello world"

    stop_task = asyncio.create_task(stop())
    await asyncio.sleep(0)
    farewell = DisconnectResponse()
    mock_data_received(frame_helper, generate_plaintext_packet(farewell))
    await stop_task
100 |
101 |
async def test_log_runner_reconnects_on_disconnect(
    conn: APIConnection,
    caplog: pytest.LogCaptureFixture,
    aiohappyeyeballs_start_connection,
) -> None:
    """Test the log runner reconnects on disconnect.

    After logs are flowing, the device sends a ``DisconnectRequest``. The
    client drops its connection, and once ``EXPECTED_DISCONNECT_COOLDOWN``
    has elapsed, host resolution (the first reconnect step) is started
    again.
    """
    loop = asyncio.get_running_loop()
    protocol: APIPlaintextFrameHelper | None = None
    transport = MagicMock()
    connected = asyncio.Event()

    # NOTE(review): presumably a subclass so patch.object can set attributes on
    # the instance below.
    class PatchableAPIClient(APIClient):
        pass

    async_zeroconf = get_mock_async_zeroconf()

    cli = PatchableAPIClient(
        address=Estr("127.0.0.1"),
        port=6052,
        password=None,
        noise_psk=None,
        expected_name=Estr("fake"),
        zeroconf_instance=async_zeroconf.zeroconf,
    )
    messages = []

    def on_log(msg: SubscribeLogsResponse) -> None:
        messages.append(msg)

    def _create_mock_transport_protocol(create_func, **kwargs):
        # Build the protocol ourselves and bind it to the mock transport.
        nonlocal protocol
        protocol = create_func()
        protocol.connection_made(transport)
        connected.set()
        return transport, protocol

    subscribed = asyncio.Event()
    original_subscribe_logs = cli.subscribe_logs

    def _wait_subscribe_cli(*args, **kwargs):
        original_subscribe_logs(*args, **kwargs)
        subscribed.set()

    with (
        patch.object(
            loop, "create_connection", side_effect=_create_mock_transport_protocol
        ),
        patch.object(cli, "subscribe_logs", _wait_subscribe_cli),
    ):
        stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf)
        await connected.wait()
        protocol = cli._connection._frame_helper
        send_plaintext_hello(protocol)
        send_plaintext_connect_response(protocol, False)
        await subscribed.wait()

    log_response: message.Message = SubscribeLogsResponse()
    log_response.message = b"Hello world"
    mock_data_received(protocol, generate_plaintext_packet(log_response))
    assert len(messages) == 1
    assert messages[0].message == b"Hello world"

    with patch.object(cli, "start_resolve_host") as mock_start_resolve_host:
        # Distinct name: this is a different message type than the log above.
        disconnect_request: message.Message = DisconnectRequest()
        mock_data_received(protocol, generate_plaintext_packet(disconnect_request))

        await asyncio.sleep(0)
        assert cli._connection is None
        # Skip past the reconnect cooldown so the reconnect attempt is made.
        async_fire_time_changed(
            utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN)
        )
        await asyncio.sleep(0)

    assert "Disconnected from API" in caplog.text
    assert mock_start_resolve_host.called

    await stop()
179 |
180 |
async def test_log_runner_reconnects_on_subscribe_failure(
    conn: APIConnection,
    caplog: pytest.LogCaptureFixture,
    aiohappyeyeballs_start_connection,
) -> None:
    """Test the log runner reconnects on subscribe failure.

    The first ``subscribe_logs`` call raises ``APIConnectionError``, which
    leaves the client without a connection. After the reconnect cooldown
    the runner connects again, and ``stop()`` then finishes a clean
    disconnect on that new connection.
    """
    event_loop = asyncio.get_running_loop()
    frame_helper: APIPlaintextFrameHelper | None = None
    mock_transport = MagicMock()
    transport_ready = asyncio.Event()
    subscribe_attempted = asyncio.Event()
    received = []

    # NOTE(review): presumably a subclass so patch.object can set attributes on
    # the instance below.
    class PatchableAPIClient(APIClient):
        pass

    mock_zeroconf = get_mock_async_zeroconf()
    client = PatchableAPIClient(
        address=Estr("127.0.0.1"),
        port=6052,
        password=None,
        noise_psk=None,
        expected_name=Estr("fake"),
        zeroconf_instance=mock_zeroconf.zeroconf,
    )

    def _on_log(msg: SubscribeLogsResponse) -> None:
        received.append(msg)

    def _fake_create_connection(create_func, **kwargs):
        # Build the protocol ourselves and bind it to the mock transport.
        nonlocal frame_helper
        frame_helper = create_func()
        frame_helper.connection_made(mock_transport)
        transport_ready.set()
        return mock_transport, frame_helper

    def _signal_then_fail_subscribe(*args, **kwargs):
        subscribe_attempted.set()
        raise APIConnectionError("subscribed force to fail")

    # First connection: the subscribe step fails.
    with (
        patch.object(client, "disconnect", partial(client.disconnect, force=True)),
        patch.object(client, "subscribe_logs", _signal_then_fail_subscribe),
    ):
        with patch.object(
            event_loop, "create_connection", side_effect=_fake_create_connection
        ):
            stop = await async_run(
                client, _on_log, aio_zeroconf_instance=mock_zeroconf
            )
            await transport_ready.wait()
            frame_helper = client._connection._frame_helper
            send_plaintext_hello(frame_helper)
            send_plaintext_connect_response(frame_helper, False)

        await subscribe_attempted.wait()

    assert client._connection is None

    # Second connection: let the cooldown elapse so a reconnect is attempted.
    with (
        patch.object(
            event_loop, "create_connection", side_effect=_fake_create_connection
        ),
        patch.object(client, "subscribe_logs"),
    ):
        transport_ready.clear()
        await asyncio.sleep(0)
        async_fire_time_changed(
            utcnow() + timedelta(seconds=EXPECTED_DISCONNECT_COOLDOWN)
        )
        await asyncio.sleep(0)

    stop_task = asyncio.create_task(stop())
    await asyncio.sleep(0)

    send_plaintext_connect_response(frame_helper, False)
    send_plaintext_hello(frame_helper)

    farewell = DisconnectResponse()
    mock_data_received(frame_helper, generate_plaintext_packet(farewell))

    await stop_task
263 |
--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import math
3 | import sys
4 |
5 | import pytest
6 |
7 | from aioesphomeapi import util
8 |
9 |
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (0, 0),
        (float("inf"), float("inf")),
        (float("-inf"), float("-inf")),
        (0.1, 0.1),
        (-0.0, -0.0),
        (0.10000000149011612, 0.1),
        (1, 1),
        (-1, -1),
        (-0.10000000149011612, -0.1),
        (-152198557936981706463557226105667584, -152198600000000000000000000000000000),
        (-0.0030539485160261, -0.003053949),
        (0.5, 0.5),
        (0.0000000000000019, 0.0000000000000019),
    ],
)
def test_fix_float_single_double_conversion(value, expected):
    """Check the single/double fix-up maps each input to its expected value.

    For example, ``0.10000000149011612`` must come back as ``0.1``, while
    exact values and infinities pass through unchanged. The parameter is
    named ``value`` so it does not shadow the builtin ``input``.
    """
    assert util.fix_float_single_double_conversion(value) == expected
30 |
31 |
def test_fix_float_single_double_conversion_nan():
    """A NaN input must still be NaN after the single/double fix-up."""
    converted = util.fix_float_single_double_conversion(math.nan)
    assert math.isnan(converted)
34 |
35 |
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+")
async def test_create_eager_task_312() -> None:
    """On Python 3.12+, create_eager_task starts its coroutine immediately.

    The eager task has already recorded its event by the time control
    returns to the caller. A task from ``asyncio.create_task`` records its
    event only after the loop gets a turn.
    """
    events = []

    async def _record(label):
        events.append(label)

    eager = util.create_eager_task(_record("eager"))
    normal = asyncio.create_task(_record("normal"))

    assert events == ["eager"]

    await asyncio.sleep(0)
    assert events == ["eager", "normal"]
    for finished in (eager, normal):
        await finished
59 |
60 |
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12")
async def test_create_eager_task_pre_312() -> None:
    """Before Python 3.12, create_eager_task behaves like a normal task.

    Neither coroutine runs until the loop gets a turn. After that they run
    in creation order.
    """
    events = []

    async def _record(label):
        events.append(label)

    eager = util.create_eager_task(_record("eager"))
    normal = asyncio.create_task(_record("normal"))

    assert events == []

    await asyncio.sleep(0)
    assert events == ["eager", "normal"]
    for finished in (eager, normal):
        await finished
84 |
--------------------------------------------------------------------------------
/tests/test_zeroconf.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from unittest.mock import patch
4 |
5 | import pytest
6 | from zeroconf.asyncio import AsyncZeroconf
7 |
8 | from aioesphomeapi.zeroconf import ZeroconfManager
9 |
10 | from .common import get_mock_async_zeroconf
11 |
12 |
async def test_does_not_closed_passed_in_async_instance(async_zeroconf: AsyncZeroconf):
    """A caller-supplied AsyncZeroconf is left open by ``async_close``."""
    zc_manager = ZeroconfManager()
    zc_manager.set_instance(async_zeroconf)
    await zc_manager.async_close()
    async_zeroconf.async_close.assert_not_called()
19 |
20 |
async def test_does_not_closed_passed_in_sync_instance(async_zeroconf: AsyncZeroconf):
    """Passing the wrapped sync Zeroconf leaves the async wrapper open."""
    zc_manager = ZeroconfManager()
    zc_manager.set_instance(async_zeroconf.zeroconf)
    await zc_manager.async_close()
    async_zeroconf.async_close.assert_not_called()
27 |
28 |
async def test_closes_created_instance(async_zeroconf: AsyncZeroconf):
    """An AsyncZeroconf the manager created itself is closed by ``async_close``."""
    with patch(
        "aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf
    ):
        zc_manager = ZeroconfManager()
        created = zc_manager.get_async_zeroconf()
        assert created is async_zeroconf
        await zc_manager.async_close()
        async_zeroconf.async_close.assert_called_once()
36 |
37 |
async def test_runtime_error_multiple_instances(async_zeroconf: AsyncZeroconf):
    """A different instance is rejected; the held one may be set again.

    Setting an unrelated instance raises ``RuntimeError``. Re-setting the
    held instance, or its wrapped sync ``zeroconf``, does not raise. Since
    the instance was supplied by the caller, it is not closed.
    """
    zc_manager = ZeroconfManager(async_zeroconf)
    other_instance = get_mock_async_zeroconf()
    with pytest.raises(RuntimeError):
        zc_manager.set_instance(other_instance)
    for same_instance in (async_zeroconf, async_zeroconf.zeroconf, async_zeroconf):
        zc_manager.set_instance(same_instance)
    await zc_manager.async_close()
    async_zeroconf.async_close.assert_not_called()
49 |
--------------------------------------------------------------------------------