├── .coveragerc ├── .cruft.json ├── .github ├── renovate.json5 └── workflows │ ├── cruft.yaml │ ├── lint.yaml │ ├── pages.yaml │ ├── publish.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .ruff.toml ├── .yaml-lint.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── google_nest_sdm ├── __init__.py ├── admin_client.py ├── auth.py ├── camera_traits.py ├── device.py ├── device_manager.py ├── device_traits.py ├── diagnostics.py ├── doorbell_traits.py ├── event.py ├── event_media.py ├── exceptions.py ├── google_nest.py ├── google_nest_api.py ├── google_nest_subscriber.py ├── model.py ├── py.typed ├── registry.py ├── streaming_manager.py ├── structure.py ├── subscriber_client.py ├── thermostat_traits.py ├── traits.py ├── transcoder.py └── webrtc_util.py ├── mypy.ini ├── pyproject.toml ├── pytest.ini ├── requirements_dev.txt ├── script └── run-mypy.sh ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── test_admin_client.py ├── test_auth.py ├── test_camera_traits.py ├── test_device.py ├── test_device_manager.py ├── test_device_traits.py ├── test_doorbell_traits.py ├── test_event.py ├── test_event_media.py ├── test_google_nest_api.py ├── test_google_nest_subscriber.py ├── test_streaming_manager.py ├── test_structure.py ├── test_subscriber_client.py ├── test_thermostat_traits.py ├── test_transcoder.py └── test_webrtc_util.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=google_nest_sdm 3 | 4 | omit = 5 | google_nest_sdm/google_nest.py 6 | -------------------------------------------------------------------------------- /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/allenporter/cookiecutter-python", 3 | "commit": "da08492f3d4d2c7bd2d4600cda2b508bc75db3a2", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "full_name": "Allen Porter", 8 | "email": "allen.porter@gmail.com", 9 | "github_username": "allenporter", 10 | "project_name": "google_nest_sdm", 11 | "description": "Library for the Google Nest SDM API", 12 | "version": "3.0.4", 13 | "_template": "https://github.com/allenporter/cookiecutter-python" 14 | } 15 | }, 16 | "directory": null 17 | } 18 | -------------------------------------------------------------------------------- /.github/renovate.json5: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:recommended" 5 | ], 6 | "assignees": ["allenporter"], 7 | "packageRules": [ 8 | { 9 | "description": "Minor updates are automatic", 10 | "automerge": true, 11 | "automergeType": "branch", 12 | "matchUpdateTypes": ["minor", "patch"] 13 | } 14 | ], 15 | "pip_requirements": { 16 | "fileMatch": ["requirements_dev.txt"] 17 | }, 18 | "pre-commit": {"enabled": true} 19 | } 20 | -------------------------------------------------------------------------------- /.github/workflows/cruft.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Update repository with Cruft 3 | permissions: 4 | contents: write 5 | pull-requests: write 6 | actions: write 7 | on: 8 | schedule: 9 | - cron: "0 0 * * *" 10 | 11 | env: 12 | PYTHON_VERSION: 3.13 13 | 14 | jobs: 15 | update: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | fail-fast: true 19 | matrix: 20 | include: 21 | - add-paths: . 
22 | body: Use this to merge the changes to this repository. 23 | branch: cruft/update 24 | commit-message: "chore: accept new Cruft update" 25 | title: New updates detected with Cruft 26 | - add-paths: .cruft.json 27 | body: Use this to reject the changes in this repository. 28 | branch: cruft/reject 29 | commit-message: "chore: reject new Cruft update" 30 | title: Reject new updates detected with Cruft 31 | steps: 32 | - uses: actions/checkout@v4 33 | - name: Set up Python 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ env.PYTHON_VERSION }} 37 | 38 | - name: Install Cruft 39 | run: pip3 install cruft 40 | 41 | - name: Check if update is available 42 | continue-on-error: false 43 | id: check 44 | run: | 45 | CHANGES=0 46 | if [ -f .cruft.json ]; then 47 | if ! cruft check; then 48 | CHANGES=1 49 | fi 50 | else 51 | echo "No .cruft.json file" 52 | fi 53 | 54 | echo "has_changes=$CHANGES" >> "$GITHUB_OUTPUT" 55 | 56 | - name: Run update if available 57 | if: steps.check.outputs.has_changes == '1' 58 | run: | 59 | git config --global user.email "allen.porter@gmail.com" 60 | git config --global user.name "Allen Porter" 61 | 62 | cruft update --skip-apply-ask --refresh-private-variables 63 | git restore --staged . 64 | 65 | 66 | - name: Create pull request 67 | if: steps.check.outputs.has_changes == '1' 68 | uses: peter-evans/create-pull-request@v7 69 | with: 70 | token: ${{ secrets.GITHUB_TOKEN }} 71 | add-paths: ${{ matrix.add-paths }} 72 | commit-message: ${{ matrix.commit-message }} 73 | branch: ${{ matrix.branch }} 74 | title: ${{ matrix.title }} 75 | body: | 76 | This is an autogenerated PR. ${{ matrix.body }} 77 | 78 | [Cruft](https://cruft.github.io/cruft/) has detected updates from the Cookiecutter repository. 79 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Lint 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | - renovate/** 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | env: 14 | PYTHON_VERSION: 3.13 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-latest 19 | strategy: 20 | fail-fast: false 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: chartboost/ruff-action@v1.0.0 25 | - uses: codespell-project/actions-codespell@v2.1 26 | with: 27 | check_hidden: false 28 | - name: Run yamllint 29 | uses: ibiqlik/action-yamllint@v3 30 | with: 31 | file_or_dir: "./" 32 | config_file: "./.yaml-lint.yaml" 33 | strict: true 34 | 35 | - name: Install uv 36 | uses: astral-sh/setup-uv@v6 37 | with: 38 | python-version: ${{ env.PYTHON_VERSION }} 39 | enable-cache: true 40 | cache-dependency-glob: "requirements_dev.txt" 41 | activate-environment: true 42 | - name: Install dependencies 43 | run: | 44 | uv pip install -r requirements_dev.txt 45 | 46 | - name: Static typing with mypy 47 | run: | 48 | mypy --install-types --non-interactive --no-warn-unused-ignores . 
49 | -------------------------------------------------------------------------------- /.github/workflows/pages.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Deploy static content to Pages 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | 9 | workflow_dispatch: 10 | 11 | permissions: 12 | contents: read 13 | pages: write 14 | id-token: write 15 | actions: read 16 | 17 | concurrency: 18 | group: "pages" 19 | cancel-in-progress: true 20 | 21 | env: 22 | PYTHON_VERSION: 3.13 23 | 24 | jobs: 25 | deploy: 26 | environment: 27 | name: github-pages 28 | url: ${{ steps.deployment.outputs.page_url }} 29 | runs-on: ubuntu-latest 30 | strategy: 31 | fail-fast: false 32 | steps: 33 | - uses: actions/checkout@v4 34 | - name: Install uv 35 | uses: astral-sh/setup-uv@v6 36 | with: 37 | python-version: ${{ env.PYTHON_VERSION }} 38 | enable-cache: true 39 | cache-dependency-glob: "requirements_dev.txt" 40 | activate-environment: true 41 | - name: Install dependencies 42 | run: | 43 | uv pip install -r requirements_dev.txt 44 | - run: pdoc ./google_nest_sdm -o docs/ 45 | - name: Setup Pages 46 | uses: actions/configure-pages@v5 47 | - name: Upload artifact 48 | uses: actions/upload-pages-artifact@v3 49 | with: 50 | # Upload entire repository 51 | path: 'docs/' 52 | - name: Deploy to GitHub Pages 53 | id: deployment 54 | uses: actions/deploy-pages@v4 55 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Upload Python Package 3 | 4 | on: 5 | release: 6 | types: [created] 7 | 8 | env: 9 | PYTHON_VERSION: 3.13 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ env.PYTHON_VERSION }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build --user 24 | - name: Build a binary wheel and a source tarball 25 | run: python3 -m build 26 | - name: Store the distribution packages 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: python-package-distributions 30 | path: dist/ 31 | 32 | publish-to-pypi: 33 | name: >- 34 | Publish Python 🐍 distribution 📦 to PyPI 35 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 36 | needs: 37 | - build 38 | runs-on: ubuntu-latest 39 | environment: 40 | name: pypi 41 | url: https://pypi.org/p/google_nest_sdm 42 | permissions: 43 | id-token: write # IMPORTANT: mandatory for trusted publishing 44 | steps: 45 | - name: Download all the dists 46 | uses: actions/download-artifact@v4 47 | with: 48 | name: python-package-distributions 49 | path: dist/ 50 | - name: Publish distribution 📦 to PyPI 51 | uses: pypa/gh-action-pypi-publish@release/v1 52 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Test 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | - renovate/** 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.12", "3.13"] 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Install uv 24 | uses: astral-sh/setup-uv@v6 25 | with: 26 | 
python-version: ${{ matrix.python-version }} 27 | enable-cache: true 28 | cache-dependency-glob: "requirements_dev.txt" 29 | activate-environment: true 30 | - name: Install dependencies 31 | run: | 32 | uv pip install -r requirements_dev.txt 33 | - name: Test with pytest 34 | run: | 35 | pytest --cov=google_nest_sdm --cov-report=term-missing 36 | - uses: codecov/codecov-action@v5.4.3 37 | with: 38 | token: ${{ secrets.CODECOV_TOKEN }} 39 | env_vars: OS,PYTHON 40 | fail_ci_if_error: true 41 | verbose: true 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | docs/ 153 | .DS_Store 154 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v5.0.0 5 | hooks: 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-added-large-files 10 | - repo: https://github.com/psf/black 11 | rev: 25.1.0 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/charliermarsh/ruff-pre-commit 15 | rev: v0.11.12 16 | hooks: 17 | - id: ruff 18 | args: 19 | - --fix 20 | - --exit-non-zero-on-fix 21 | - repo: local 22 | hooks: 23 | - id: mypy 24 | name: mypy 25 | entry: script/run-mypy.sh 26 | language: script 27 | types: [python] 28 | require_serial: true 29 | - repo: https://github.com/codespell-project/codespell 30 | rev: v2.4.1 31 | hooks: 32 | - id: codespell 33 | - repo: https://github.com/adrienverge/yamllint.git 34 | rev: v1.37.1 35 | hooks: 36 | - id: yamllint 37 | exclude: '^tests/tool/testdata/.*\.yaml$' 38 | args: 39 | - --strict 40 | - -c 41 | - ".yaml-lint.yaml" 42 | - repo: https://github.com/asottile/setup-cfg-fmt 43 | rev: v2.8.0 44 | hooks: 45 | - id: setup-cfg-fmt 46 | -------------------------------------------------------------------------------- /.ruff.toml: -------------------------------------------------------------------------------- 1 | target-version = "py310" 2 | 3 | [lint] 4 | ignore = ["E501"] 5 | -------------------------------------------------------------------------------- /.yaml-lint.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | ignore: | 3 | venv 4 | tests/testdata 5 | extends: default 6 | rules: 7 | truthy: 8 | allowed-values: ['true', 'false', 'on', 'yes'] 9 | comments: 10 | min-spaces-from-content: 1 11 | line-length: disable 12 | braces: 13 | min-spaces-inside: 0 14 | max-spaces-inside: 1 15 | brackets: 16 | min-spaces-inside: 0 17 | max-spaces-inside: 0 18 | indentation: 19 | spaces: 2 20 | indent-sequences: consistent 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # python-google-nest-sdm 2 | 3 | This is a library for Google Nest [Device Access](https://developers.google.com/nest/device-access) 4 | using the [Smart Device Management API](https://developers.google.com/nest/device-access/api). 5 | 6 | # Usage 7 | 8 | This can be used with the sandbox which requires [Registration](https://developers.google.com/nest/device-access/registration), accepting terms 9 | and a fee. 10 | 11 | You'll want to follow the [Get Started](https://developers.google.com/nest/device-access/get-started) 12 | guides for setup, including steps in the Google Cloud console. Overall, this is 13 | fairly complicated with many steps that are easy to get wrong. It is likely 14 | worth it to make sure you can get the API working using their supplied curl 15 | commands with your account before attempting to use this library. 16 | 17 | # Structure 18 | 19 | This API was designed for use in Home Assistant following the advice in 20 | [Building a Python Library for an API](https://developers.home-assistant.io/docs/api_lib_index/). 21 | 22 | If you are integrating this from outside Home Assistant, you'll need to 23 | create your own OAuth integration and token refresh mechanism and tooling. 24 | 25 | # Fetching Data 26 | 27 | This is an example of using the command line tool to access the API: 28 | 29 | ``` 30 | PROJECT_ID="some-project-id" 31 | CLIENT_ID="some-client-id" 32 | CLIENT_SECRET="some-client-secret" 33 | # Initial call will ask you to authorize OAuth2 then cache the token 34 | google_nest --project_id="${PROJECT_ID}" --client_id="${CLIENT_ID}" --client_secret="${CLIENT_SECRET}" list_structures 35 | # Subsequent calls only need the project id 36 | google_nest --project_id="${PROJECT_ID}" get_device "some-device-id" 37 | google_nest --project_id="${PROJECT_ID}" set_mode COOL 38 | google_nest --project_id="${PROJECT_ID}" set_cool 25.0 39 | ``` 40 | 41 | # Subscriptions 42 | 43 | See [Device Access: Getting Started: Subscribe to Events](https://developers.google.com/nest/device-access/subscribe-to-events) 44 | for documentation on how to create a pull subscription. 45 | 46 | You can create the subscription to use with the tool with these steps: 47 | 48 | * Create the topic: 49 | * Visit the [Device Access Console](https://console.nest.google.com/device-access) 50 | * Select a project 51 | * Enable Pub/Sub and note the full `topic` based on the `project_id` 52 | * Create the subscriber: 53 | * Visit [Google Cloud Platform: Pub/Sub: Subscriptions](https://console.cloud.google.com/cloudpubsub/subscriptions) 54 | * Create a subscriber 55 | * Enter the `Topic Name` 56 | * Create a `Subscription Name`, e.g.
"project-id-python" which is your `subscriber_id` 57 | 58 | This is an example to run the command line tool to subscribe: 59 | ``` 60 | PROJECT_ID="some-project-id" 61 | SUBSCRIPTION_ID="projects/some-id/subscriptions/enterprise-some-project-id-python-google-nest" 62 | google_nest --project_id="${PROJECT_ID}" subscribe ${SUBSCRIPTION_ID} 63 | ``` 64 | 65 | # Development 66 | 67 | ``` 68 | $ python3 -m venv venv 69 | $ source venv/bin/activate 70 | $ pip3 install -e . 71 | $ pip3 install -r requirements.txt 72 | 73 | # Running tests 74 | $ pytest 75 | 76 | # Running tests w/ Code Coverage 77 | $ pytest --cov=google_nest_sdm tests/ --cov-report=term-missing 78 | 79 | # Formatting and linting 80 | $ pre-commit run --all-files 81 | ``` 82 | 83 | # Funding and Support 84 | 85 | If you are interested in donating money to this effort, instead send a 86 | donation to [Black Girls Code](https://donorbox.org/support-black-girls-code) 87 | which is a great organization growing the next generation of software engineers. 88 | -------------------------------------------------------------------------------- /google_nest_sdm/__init__.py: -------------------------------------------------------------------------------- 1 | """Library for using the Google Nest SDM API. 2 | 3 | See https://developers.google.com/nest/device-access/api for the documentation 4 | on how to use the API. 5 | 6 | The primary components in this library are: 7 | - `auth`: You need to implement `AbstractAuth` to provide credentials. 8 | - `google_nest_subscriber`: A wrapper around the pub/sub system for efficiently 9 | listening to changes in device state. 10 | - `device_manager`: Holds local state for devices, populated by the subscriber. 11 | - `device`: Holds device traits and current device state 12 | - `event_media`: For media related to camera or doorbell events. 13 | 14 | Example usage: 15 | ``` 16 | subscriber = GoogleNestSubscriber( 17 | auth_impl, # Your credential provider 18 | # Follow nest developer API docs to obtain these 19 | DEVICE_ACCESS_PROJECT_ID, 20 | SUBSCRIBER_ID, 21 | ) 22 | unsub = await subscriber.start_async() 23 | device_manager = await subscriber.async_get_device_manager() 24 | 25 | for device in device_manager.devices.values(): 26 | if device.temperature: 27 | temp = device.temperatureambient_temperature_celsius 28 | print("Device temperature: {temp:0.2f}") 29 | 30 | unsub() # Unsubscribe when done 31 | ``` 32 | """ 33 | 34 | __all__ = [ 35 | "google_nest_subscriber", 36 | "device_manager", 37 | "device", 38 | "camera_traits", 39 | "device_traits", 40 | "doorbell_traits", 41 | "thermostat_traits", 42 | "structure", 43 | "auth", 44 | "event_media", 45 | "event", 46 | "exceptions", 47 | "diagnostics", 48 | ] 49 | -------------------------------------------------------------------------------- /google_nest_sdm/admin_client.py: -------------------------------------------------------------------------------- 1 | """Admin Client library for the Google Nest SDM API. 2 | 3 | This manages administrative tasks for setting up pubsub topics and subscriptions. 4 | 5 | This library exists to provide an asyncio interface given that the current pubsub 6 | clients are synchronous. 
7 | """ 8 | 9 | import logging 10 | import re 11 | import asyncio 12 | from typing import Any 13 | from dataclasses import dataclass, field 14 | 15 | from .diagnostics import SUBSCRIBER_DIAGNOSTICS as DIAGNOSTICS 16 | from .auth import AbstractAuth 17 | from .exceptions import ( 18 | ApiException, 19 | NotFoundException, 20 | ApiForbiddenException, 21 | ConfigurationException, 22 | ) 23 | 24 | _LOGGER = logging.getLogger(__name__) 25 | 26 | __all__ = [ 27 | "AdminClient", 28 | "EligibleTopics", 29 | "EligibleSubscriptions", 30 | "validate_subscription_name", 31 | "validate_topic_name", 32 | "PUBSUB_API_HOST", 33 | ] 34 | 35 | PUBSUB_API_HOST = "https://pubsub.googleapis.com/v1" 36 | SDM_MANAGED_TOPIC_FORMAT = ( 37 | "projects/sdm-prod/topics/enterprise-{device_access_project_id}" 38 | ) 39 | 40 | # Used to catch invalid subscriber id 41 | EXPECTED_SUBSCRIBER_REGEXP = re.compile("^projects/[^/]+/subscriptions/[^/]+$") 42 | 43 | # Used to catch a topic misconfiguration 44 | EXPECTED_TOPIC_REGEXP = re.compile("^projects/[^/]+/topics/[^/]+$") 45 | 46 | # Topic prefix for the project 47 | EXPECTED_PROJECS_PREFIX = re.compile("^projects/[^/]+$") 48 | 49 | 50 | @dataclass 51 | class EligibleTopics: 52 | """Eligible topics for the project.""" 53 | 54 | topic_names: list[str] = field(default_factory=list) 55 | 56 | 57 | @dataclass 58 | class EligibleSubscriptions: 59 | """Eligible topics for the project.""" 60 | 61 | subscription_names: list[str] = field(default_factory=list) 62 | 63 | # Policy that gives Device Access Console permission to publish to a topic 64 | DEFAULT_TOPIC_IAM_POLICY = { 65 | "bindings": [ 66 | { 67 | "members": [ 68 | "group:sdm-publisher@googlegroups.com" 69 | ], 70 | "role": "roles/pubsub.publisher" 71 | } 72 | ] 73 | } 74 | 75 | 76 | def validate_subscription_name(subscription_name: str) -> None: 77 | """Validates that a subscription name is correct. 78 | 79 | Raises ConfigurationException on failure. 80 | """ 81 | if not EXPECTED_SUBSCRIBER_REGEXP.match(subscription_name): 82 | DIAGNOSTICS.increment("subscription_name_invalid") 83 | _LOGGER.debug("Subscription name did not match pattern: %s", subscription_name) 84 | raise ConfigurationException( 85 | "Subscription misconfigured. Expected subscriber_id to " 86 | f"match '{EXPECTED_SUBSCRIBER_REGEXP.pattern}' but was " 87 | f"'{subscription_name}'" 88 | ) 89 | 90 | 91 | def validate_topic_name(topic_name: str) -> None: 92 | """Validates that a topic name is correct. 93 | 94 | Raises ConfigurationException on failure. 95 | """ 96 | if not EXPECTED_TOPIC_REGEXP.match(topic_name): 97 | DIAGNOSTICS.increment("topic_name_invalid") 98 | _LOGGER.debug("Topic name did not match pattern: %s", topic_name) 99 | raise ConfigurationException( 100 | "Subscription misconfigured. Expected topic name to " 101 | f"match '{EXPECTED_TOPIC_REGEXP.pattern}' but was " 102 | f"'{topic_name}'." 103 | ) 104 | 105 | 106 | def validate_projects_prefix(project_path: str) -> None: 107 | """Validates that a topic or subscription prefix is correct. 108 | 109 | Raises ConfigurationException on failure. 110 | """ 111 | if not EXPECTED_PROJECS_PREFIX.match(project_path): 112 | DIAGNOSTICS.increment("topic_prefix_invalid") 113 | _LOGGER.debug("Topic prefix did not match pattern: %s", project_path) 114 | raise ConfigurationException( 115 | "Subscription misconfigured. Expected topic name to " 116 | f"match '{EXPECTED_PROJECS_PREFIX.pattern}' but was " 117 | f"'{project_path}'." 
118 | ) 119 | 120 | 121 | class AdminClient: 122 | """Admin client for the Google Nest SDM API.""" 123 | 124 | def __init__( 125 | self, 126 | auth: AbstractAuth, 127 | cloud_project_id: str, 128 | ) -> None: 129 | """Initialize the admin client. 130 | 131 | The auth instance must be configured with the correct host (PUBSUB_API_HOST). 132 | """ 133 | self._cloud_project_id = cloud_project_id 134 | self._auth = auth 135 | 136 | async def create_topic(self, topic_name: str) -> None: 137 | """Create a pubsub topic for the project.""" 138 | validate_topic_name(topic_name) 139 | await self._auth.put(topic_name) 140 | 141 | async def delete_topic(self, topic_name: str) -> None: 142 | """Delete a pubsub topic for the project.""" 143 | validate_topic_name(topic_name) 144 | await self._auth.delete(topic_name) 145 | 146 | async def list_topics(self, projects_prefix: str) -> list[str]: 147 | """List the pubsub topics for the project. 148 | 149 | The topic prefix should be in the format `projects/{console_project_id}`. 150 | """ 151 | validate_projects_prefix(projects_prefix) 152 | response = await self._auth.get_json(f"{projects_prefix}/topics") 153 | return [topic["name"] for topic in response.get("topics", ())] 154 | 155 | async def get_topic(self, topic_name: str) -> dict[str, Any]: 156 | """Get a pubsub topic for the project.""" 157 | validate_topic_name(topic_name) 158 | return await self._auth.get_json(topic_name) 159 | 160 | async def set_topic_iam_policy(self, topic_name: str, policy: dict[str, Any]) -> None: 161 | """Create a pubsub topic for the project.""" 162 | validate_topic_name(topic_name) 163 | path = f"{topic_name}:setIamPolicy" 164 | await self._auth.post( 165 | path, 166 | json={"policy":policy} 167 | ) 168 | 169 | 170 | async def create_subscription( 171 | self, topic_name: str, subscription_name: str 172 | ) -> None: 173 | """Create a pubsub subscription for the project.""" 174 | validate_topic_name(topic_name) 175 | validate_subscription_name(subscription_name) 176 | body = {"topic": topic_name} 177 | await self._auth.put(subscription_name, json=body) 178 | 179 | async def delete_subscription(self, subscription_name: str) -> None: 180 | """Delete a pubsub subscription for the project.""" 181 | validate_subscription_name(subscription_name) 182 | await self._auth.delete(subscription_name) 183 | 184 | async def list_subscriptions(self, projects_prefix: str) -> list[dict[str, Any]]: 185 | """List the pubsub subscriptions for the project. 186 | The projects_prefix should be in the format `projects/{console_project_id}`. 187 | """ 188 | validate_projects_prefix(projects_prefix) 189 | response = await self._auth.get_json(f"{projects_prefix}/subscriptions") 190 | return response.get("subscriptions", []) # type: ignore[no-any-return] 191 | 192 | async def list_eligible_topics( 193 | self, device_access_project_id: str 194 | ) -> EligibleTopics: 195 | """List the eligible topics for the project. 196 | 197 | This will try to find any topics already created for the project by either 198 | the device access console or by the user. 199 | """ 200 | 201 | sdm_topic_name = SDM_MANAGED_TOPIC_FORMAT.format( 202 | device_access_project_id=device_access_project_id 203 | ) 204 | 205 | async def get_sdm_topic() -> str | None: 206 | try: 207 | await self.get_topic(sdm_topic_name) 208 | except ApiForbiddenException: 209 | _LOGGER.debug( 210 | "SDM topic exists but we do not have permission to access it (expected)" 211 | ) 212 | # The SDM topic exists. 
It is normal that we do not have permission 213 | # to access it. 214 | return sdm_topic_name 215 | except NotFoundException: 216 | _LOGGER.debug( 217 | "SDM topic does not exist, proceeding to check cloud projects" 218 | ) 219 | return None 220 | except ApiException as err: 221 | _LOGGER.info( 222 | "Unexpected error retrieving an SDM created topic: %s", err 223 | ) 224 | raise ApiException("Error retrieving SDM created topic") from err 225 | _LOGGER.debug( 226 | "SDM topic exists and we have permission to access it (unexpected)" 227 | ) 228 | return sdm_topic_name 229 | 230 | async def get_cloud_topics() -> list[str]: 231 | try: 232 | return await self.list_topics(f"projects/{self._cloud_project_id}") 233 | except ApiException as err: 234 | _LOGGER.info("Unexpected error listing topics: %s", err) 235 | raise ApiException( 236 | "Error while listing existing cloud console topics" 237 | ) from err 238 | 239 | (sdm_topic_task, cloud_topics_task) = await asyncio.gather( 240 | get_sdm_topic(), get_cloud_topics() 241 | ) 242 | topics = [] 243 | if sdm_topic_task: 244 | topics.append(sdm_topic_task) 245 | topics.extend(cloud_topics_task) 246 | return EligibleTopics(topic_names=topics) 247 | 248 | async def list_eligible_subscriptions( 249 | self, expected_topic_name: str 250 | ) -> EligibleSubscriptions: 251 | """Return a set of eligible subscriptions for the project.""" 252 | subscriptions = await self.list_subscriptions( 253 | f"projects/{self._cloud_project_id}" 254 | ) 255 | return EligibleSubscriptions( 256 | subscription_names=[ 257 | sub["name"] 258 | for sub in subscriptions 259 | if sub["topic"] == expected_topic_name 260 | ] 261 | ) 262 | -------------------------------------------------------------------------------- /google_nest_sdm/auth.py: -------------------------------------------------------------------------------- 1 | """Authentication library, implemented by users of the API. 2 | 3 | This library is a simple `aiohttp` wrapper that handles authentication when talking 4 | to the API. Users are expected to provide their own implementation that provides 5 | credentials obtained using the standard Google authentication approaches 6 | described at https://developers.google.com/nest/device-access/api/authorization 7 | 8 | An implementation of `AbstractAuth` implements `async_get_access_token` 9 | to provide authentication credentials to the SDM library. The implementation is 10 | responsible for managing the lifecycle of the token (any persistence needed, 11 | or refresh to deal with expiration, etc).
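For illustration, a minimal sketch of such an implementation (here `token_provider` is a hypothetical async callable standing in for whatever OAuth refresh mechanism your application already has):
```
class MyAuth(AbstractAuth):
    def __init__(self, websession, host, token_provider):
        super().__init__(websession, host)
        # Hypothetical helper that always returns a valid (refreshed) access token
        self._token_provider = token_provider

    async def async_get_access_token(self) -> str:
        return await self._token_provider()
```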
12 | """ 13 | 14 | from __future__ import annotations 15 | 16 | import logging 17 | from abc import ABC, abstractmethod 18 | from dataclasses import dataclass, field 19 | from asyncio import TimeoutError 20 | from typing import Any 21 | from http import HTTPStatus 22 | 23 | import aiohttp 24 | from aiohttp.client_exceptions import ClientError 25 | from google.auth.credentials import Credentials 26 | from google.oauth2.credentials import Credentials as OAuthCredentials 27 | from mashumaro.mixins.json import DataClassJSONMixin 28 | 29 | from .exceptions import ( 30 | ApiException, 31 | AuthException, 32 | ApiForbiddenException, 33 | NotFoundException, 34 | ) 35 | 36 | _LOGGER = logging.getLogger(__name__) 37 | 38 | __all__ = ["AbstractAuth"] 39 | 40 | HTTP_UNAUTHORIZED = 401 41 | AUTHORIZATION_HEADER = "Authorization" 42 | ERROR = "error" 43 | STATUS = "status" 44 | MESSAGE = "message" 45 | 46 | 47 | @dataclass 48 | class Status(DataClassJSONMixin): 49 | """Status of the media item.""" 50 | 51 | code: int = field(default=HTTPStatus.OK) 52 | """The status code, which should be an enum value of google.rpc.Code""" 53 | 54 | message: str | None = None 55 | """A developer-facing error message, which should be in English""" 56 | 57 | details: list[dict[str, Any]] = field(default_factory=list) 58 | """A list of messages that carry the error details""" 59 | 60 | 61 | @dataclass 62 | class Error: 63 | """Error details from the API response.""" 64 | 65 | status: str | None = None 66 | code: int | None = None 67 | message: str | None = None 68 | details: list[dict[str, Any]] | None = field(default_factory=list) 69 | 70 | def __str__(self) -> str: 71 | """Return a string representation of the error details.""" 72 | error_message = "" 73 | if self.status: 74 | error_message += self.status 75 | if self.code: 76 | if error_message: 77 | error_message += f" ({self.code})" 78 | else: 79 | error_message += str(self.code) 80 | if self.message: 81 | if error_message: 82 | error_message += ": " 83 | error_message += self.message 84 | if self.details: 85 | error_message += f"\nError details: ({self.details})" 86 | return error_message 87 | 88 | 89 | @dataclass 90 | class ErrorResponse(DataClassJSONMixin): 91 | """A response message that contains an error message.""" 92 | 93 | error: Error | None = None 94 | 95 | 96 | class AbstractAuth(ABC): 97 | """Abstract class to make authenticated requests.""" 98 | 99 | def __init__(self, websession: aiohttp.ClientSession, host: str): 100 | """Initialize the AbstractAuth.""" 101 | self._websession = websession 102 | self._host = host 103 | 104 | @abstractmethod 105 | async def async_get_access_token(self) -> str: 106 | """Return a valid access token.""" 107 | 108 | async def async_get_creds(self) -> Credentials: 109 | """Return creds for subscriber API.""" 110 | token = await self.async_get_access_token() 111 | return OAuthCredentials(token=token) 112 | 113 | async def request( 114 | self, 115 | method: str, 116 | url: str, 117 | **kwargs: Any, 118 | ) -> aiohttp.ClientResponse: 119 | """Make a request.""" 120 | headers = kwargs.get("headers") 121 | 122 | if headers is None: 123 | headers = {} 124 | else: 125 | headers = dict(headers) 126 | del kwargs["headers"] 127 | if AUTHORIZATION_HEADER not in headers: 128 | try: 129 | access_token = await self.async_get_access_token() 130 | except TimeoutError as err: 131 | raise ApiException(f"Timeout requesting API token: {err}") from err 132 | except ClientError as err: 133 | raise AuthException(f"Access token failure: {err}") from 
err 134 | headers[AUTHORIZATION_HEADER] = f"Bearer {access_token}" 135 | if not (url.startswith("http://") or url.startswith("https://")): 136 | url = f"{self._host}/{url}" 137 | _LOGGER.debug("request[%s]=%s", method, url) 138 | if method == "post" and "json" in kwargs: 139 | _LOGGER.debug("request[post json]=%s", kwargs["json"]) 140 | try: 141 | return await self._request(method, url, headers=headers, **kwargs) 142 | except (ClientError, TimeoutError) as err: 143 | raise ApiException(f"Error connecting to API: {err}") from err 144 | 145 | async def _request( 146 | self, method: str, url: str, headers: dict[str, str], **kwargs: Any 147 | ) -> aiohttp.ClientResponse: 148 | return await self._websession.request(method, url, **kwargs, headers=headers) 149 | 150 | async def get(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse: 151 | """Make a get request.""" 152 | response = await self.request("get", url, **kwargs) 153 | return await AbstractAuth._raise_for_status(response) 154 | 155 | async def get_json(self, url: str, **kwargs: Any) -> dict[str, Any]: 156 | """Make a get request and return json response.""" 157 | resp = await self.get(url, **kwargs) 158 | try: 159 | result = await resp.json() 160 | except ClientError as err: 161 | raise ApiException("Server returned malformed response") from err 162 | if not isinstance(result, dict): 163 | raise ApiException("Server return malformed response: %s" % result) 164 | _LOGGER.debug("response=%s", result) 165 | return result 166 | 167 | async def post(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse: 168 | """Make a post request.""" 169 | response = await self.request("post", url, **kwargs) 170 | return await AbstractAuth._raise_for_status(response) 171 | 172 | async def post_json(self, url: str, **kwargs: Any) -> dict[str, Any]: 173 | """Make a post request and return a json response.""" 174 | resp = await self.post(url, **kwargs) 175 | try: 176 | result = await resp.json() 177 | except ClientError as err: 178 | raise ApiException("Server returned malformed response") from err 179 | if not isinstance(result, dict): 180 | raise ApiException("Server returned malformed response: %s" % result) 181 | _LOGGER.debug("response=%s", result) 182 | return result 183 | 184 | async def put(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse: 185 | """Make a put request.""" 186 | response = await self.request("put", url, **kwargs) 187 | return await AbstractAuth._raise_for_status(response) 188 | 189 | async def delete(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse: 190 | """Make a delete request.""" 191 | response = await self.request("delete", url, **kwargs) 192 | return await AbstractAuth._raise_for_status(response) 193 | 194 | @classmethod 195 | async def _raise_for_status( 196 | cls, resp: aiohttp.ClientResponse 197 | ) -> aiohttp.ClientResponse: 198 | """Raise exceptions on failure methods.""" 199 | error_detail = await cls._error_detail(resp) 200 | try: 201 | resp.raise_for_status() 202 | except aiohttp.ClientResponseError as err: 203 | error_message = f"{err.message} response from API ({resp.status})" 204 | if error_detail: 205 | error_message += f": {error_detail}" 206 | if err.status == HTTPStatus.FORBIDDEN: 207 | raise ApiForbiddenException(error_message) 208 | if err.status == HTTPStatus.UNAUTHORIZED: 209 | raise AuthException(error_message) 210 | if err.status == HTTPStatus.NOT_FOUND: 211 | raise NotFoundException(error_message) 212 | raise ApiException(error_message) from err 213 | except aiohttp.ClientError as err: 
214 | raise ApiException(f"Error from API: {err}") from err 215 | return resp 216 | 217 | @classmethod 218 | async def _error_detail(cls, resp: aiohttp.ClientResponse) -> Error | None: 219 | """Returns an error message string from the APi response.""" 220 | if resp.status < 400: 221 | return None 222 | try: 223 | result = await resp.text() 224 | except ClientError: 225 | return None 226 | try: 227 | error_response = ErrorResponse.from_json(result) 228 | except (LookupError, ValueError): 229 | return None 230 | return error_response.error 231 | -------------------------------------------------------------------------------- /google_nest_sdm/camera_traits.py: -------------------------------------------------------------------------------- 1 | """Traits belonging to camera devices.""" 2 | 3 | from __future__ import annotations 4 | 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass, field 7 | import datetime 8 | from enum import Enum 9 | import logging 10 | from typing import ClassVar 11 | import urllib.parse as urlparse 12 | 13 | from mashumaro import DataClassDictMixin, field_options 14 | from mashumaro.config import BaseConfig 15 | from mashumaro.types import SerializationStrategy 16 | 17 | from .event import ( 18 | CameraClipPreviewEvent, 19 | CameraMotionEvent, 20 | CameraPersonEvent, 21 | CameraSoundEvent, 22 | EventImageContentType, 23 | EventImageType, 24 | EventType, 25 | ) 26 | from .traits import CommandDataClass, TraitType 27 | from .webrtc_util import fix_mozilla_sdp_answer 28 | 29 | __all__ = [ 30 | "CameraImageTrait", 31 | "CameraLiveStreamTrait", 32 | "CameraEventImageTrait", 33 | "CameraMotionTrait", 34 | "CameraPersonTrait", 35 | "CameraSoundTrait", 36 | "CameraClipPreviewTrait", 37 | "Resolution", 38 | "Stream", 39 | "StreamUrls", 40 | "RtspStream", 41 | "WebRtcStream", 42 | "StreamingProtocol", 43 | "EventImage", 44 | ] 45 | 46 | _LOGGER = logging.getLogger(__name__) 47 | 48 | MAX_IMAGE_RESOLUTION = "maxImageResolution" 49 | MAX_VIDEO_RESOLUTION = "maxVideoResolution" 50 | WIDTH = "width" 51 | HEIGHT = "height" 52 | VIDEO_CODECS = "videoCodecs" 53 | AUDIO_CODECS = "audioCodecs" 54 | SUPPORTED_PROTOCOLS = "supportedProtocols" 55 | STREAM_URLS = "streamUrls" 56 | RESULTS = "results" 57 | RTSP_URL = "rtspUrl" 58 | STREAM_EXTENSION_TOKEN = "streamExtensionToken" 59 | STREAM_TOKEN = "streamToken" 60 | URL = "url" 61 | TOKEN = "token" 62 | ANSWER_SDP = "answerSdp" 63 | MEDIA_SESSION_ID = "mediaSessionId" 64 | 65 | EVENT_IMAGE_CLIP_PREVIEW = "clip_preview" 66 | 67 | 68 | @dataclass 69 | class Resolution: 70 | """Maximum Resolution of an image or stream.""" 71 | 72 | width: int | None = None 73 | height: int | None = None 74 | 75 | 76 | @dataclass 77 | class CameraImageTrait(DataClassDictMixin): 78 | """This trait belongs to any device that supports taking images.""" 79 | 80 | NAME: ClassVar[TraitType] = TraitType.CAMERA_IMAGE 81 | 82 | max_image_resolution: Resolution | None = field( 83 | metadata=field_options(alias="maxImageResolution"), default=None 84 | ) 85 | """Maximum resolution of the camera image.""" 86 | 87 | 88 | @dataclass 89 | class Stream(DataClassDictMixin, CommandDataClass, ABC): 90 | """Base class for streams.""" 91 | 92 | expires_at: datetime.datetime = field(metadata=field_options(alias="expiresAt")) 93 | """Time at which both streamExtensionToken and streamToken expire.""" 94 | 95 | @abstractmethod 96 | async def extend_stream(self) -> Stream: 97 | """Extend the lifetime of the stream.""" 98 | 99 | @abstractmethod 100 | async def 
stop_stream(self) -> None: 101 | """Invalidate the stream.""" 102 | 103 | 104 | @dataclass 105 | class StreamUrls: 106 | """Response object for stream urls""" 107 | 108 | rtsp_url: str = field(metadata=field_options(alias="rtspUrl")) 109 | """RTSP live stream URL.""" 110 | 111 | 112 | @dataclass 113 | class RtspStream(Stream): 114 | """Provides access an RTSP live stream URL.""" 115 | 116 | stream_urls: StreamUrls = field(metadata=field_options(alias="streamUrls")) 117 | """Stream urls to access the live stream.""" 118 | 119 | stream_token: str = field(metadata=field_options(alias="streamToken")) 120 | """Token to use to access an RTSP live stream.""" 121 | 122 | stream_extension_token: str = field( 123 | metadata=field_options(alias="streamExtensionToken") 124 | ) 125 | """Token to use to extend access to an RTSP live stream.""" 126 | 127 | @property 128 | def rtsp_stream_url(self) -> str: 129 | """RTSP live stream URL.""" 130 | return self.stream_urls.rtsp_url 131 | 132 | async def extend_stream(self) -> Stream | RtspStream: 133 | """Extend the lifetime of the stream.""" 134 | return await self.extend_rtsp_stream() 135 | 136 | async def extend_rtsp_stream(self) -> RtspStream: 137 | """Request a new RTSP live stream URL access token.""" 138 | data = { 139 | "command": "sdm.devices.commands.CameraLiveStream.ExtendRtspStream", 140 | "params": {"streamExtensionToken": self.stream_extension_token}, 141 | } 142 | response_data = await self.cmd.execute_json(data) 143 | results = response_data[RESULTS] 144 | # Update the stream url with the new token 145 | stream_token = results[STREAM_TOKEN] 146 | parsed = urlparse.urlparse(self.rtsp_stream_url) 147 | parsed = parsed._replace(query=f"auth={stream_token}") 148 | url = urlparse.urlunparse(parsed) 149 | results[STREAM_URLS] = {} 150 | results[STREAM_URLS][RTSP_URL] = url 151 | obj = RtspStream.from_dict(results) 152 | obj._cmd = self.cmd 153 | return obj 154 | 155 | async def stop_stream(self) -> None: 156 | """Invalidate the stream.""" 157 | return await self.stop_rtsp_stream() 158 | 159 | async def stop_rtsp_stream(self) -> None: 160 | """Invalidates a valid RTSP access token and stops the RTSP live stream.""" 161 | data = { 162 | "command": "sdm.devices.commands.CameraLiveStream.StopRtspStream", 163 | "params": {"streamExtensionToken": self.stream_extension_token}, 164 | } 165 | await self.cmd.execute(data) 166 | 167 | 168 | @dataclass 169 | class WebRtcStream(Stream): 170 | """Provides access an RTSP live stream URL.""" 171 | 172 | answer_sdp: str = field(metadata=field_options(alias="answerSdp")) 173 | """An SDP answer to use with the local device displaying the stream.""" 174 | 175 | media_session_id: str = field(metadata=field_options(alias="mediaSessionId")) 176 | """Media Session ID of the live stream.""" 177 | 178 | async def extend_stream(self) -> WebRtcStream: 179 | """Request a new RTSP live stream URL access token.""" 180 | data = { 181 | "command": "sdm.devices.commands.CameraLiveStream.ExtendWebRtcStream", 182 | "params": {MEDIA_SESSION_ID: self.media_session_id}, 183 | } 184 | response_data = await self.cmd.execute_json(data) 185 | # Preserve original answerSdp, and merge with response that contains 186 | # the other fields (expiresAt, and mediaSessionId. 
187 | results = response_data[RESULTS] 188 | results[ANSWER_SDP] = self.answer_sdp 189 | obj = WebRtcStream.from_dict(results) 190 | obj._cmd = self.cmd 191 | return obj 192 | 193 | async def stop_stream(self) -> None: 194 | """Invalidates a valid WebRTC access token and stops the WebRTC live stream.""" 195 | data = { 196 | "command": "sdm.devices.commands.CameraLiveStream.StopWebRtcStream", 197 | "params": {MEDIA_SESSION_ID: self.media_session_id}, 198 | } 199 | await self.cmd.execute(data) 200 | 201 | 202 | class StreamingProtocol(str, Enum): 203 | """Streaming protocols supported by the device.""" 204 | 205 | RTSP = "RTSP" 206 | WEB_RTC = "WEB_RTC" 207 | 208 | 209 | def _default_streaming_protocol() -> list[StreamingProtocol]: 210 | return [ 211 | StreamingProtocol.RTSP, 212 | ] 213 | 214 | 215 | class StreamingProtocolSerializationStrategy( 216 | SerializationStrategy, use_annotations=True 217 | ): 218 | """Parser for streaming protocols that ignores invalid values.""" 219 | 220 | def serialize(self, value: list[StreamingProtocol]) -> list[str]: 221 | return [str(x.name) for x in value] 222 | 223 | def deserialize(self, value: list[str]) -> list[StreamingProtocol]: 224 | return [ 225 | StreamingProtocol[x] for x in value if x in StreamingProtocol.__members__ 226 | ] or _default_streaming_protocol() 227 | 228 | 229 | @dataclass 230 | class CameraLiveStreamTrait(DataClassDictMixin, CommandDataClass): 231 | """This trait belongs to any device that supports live streaming.""" 232 | 233 | NAME: ClassVar[TraitType] = TraitType.CAMERA_LIVE_STREAM 234 | 235 | max_video_resolution: Resolution = field( 236 | metadata=field_options(alias="maxVideoResolution"), default_factory=Resolution 237 | ) 238 | """Maximum resolution of the video live stream.""" 239 | 240 | video_codecs: list[str] = field( 241 | metadata=field_options(alias="videoCodecs"), default_factory=list 242 | ) 243 | """Video codecs supported for the live stream.""" 244 | 245 | audio_codecs: list[str] = field( 246 | metadata=field_options(alias="audioCodecs"), default_factory=list 247 | ) 248 | """Audio codecs supported for the live stream.""" 249 | 250 | supported_protocols: list[StreamingProtocol] = field( 251 | metadata=field_options(alias="supportedProtocols"), 252 | default_factory=_default_streaming_protocol, 253 | ) 254 | """Streaming protocols supported for the live stream.""" 255 | 256 | async def generate_rtsp_stream(self) -> RtspStream: 257 | """Request a token to access an RTSP live stream URL.""" 258 | if StreamingProtocol.RTSP not in self.supported_protocols: 259 | raise ValueError("Device does not support RTSP stream") 260 | data = { 261 | "command": "sdm.devices.commands.CameraLiveStream.GenerateRtspStream", 262 | "params": {}, 263 | } 264 | response_data = await self.cmd.execute_json(data) 265 | results = response_data[RESULTS] 266 | obj = RtspStream.from_dict(results) 267 | obj._cmd = self.cmd 268 | return obj 269 | 270 | async def generate_web_rtc_stream(self, offer_sdp: str) -> WebRtcStream: 271 | """Request a token to access a WebRTC live stream URL.""" 272 | if StreamingProtocol.WEB_RTC not in self.supported_protocols: 273 | raise ValueError("Device does not support WEB_RTC stream") 274 | data = { 275 | "command": "sdm.devices.commands.CameraLiveStream.GenerateWebRtcStream", 276 | "params": {"offerSdp": offer_sdp}, 277 | } 278 | response_data = await self.cmd.execute_json(data) 279 | results = response_data[RESULTS] 280 | obj = WebRtcStream.from_dict(results) 281 | obj._cmd = self.cmd 282 | _LOGGER.debug("Received
answer_sdp: %s", obj.answer_sdp) 283 | obj.answer_sdp = fix_mozilla_sdp_answer(offer_sdp, obj.answer_sdp) 284 | _LOGGER.debug("Return answer_sdp: %s", obj.answer_sdp) 285 | return obj 286 | 287 | class Config(BaseConfig): 288 | serialization_strategy = { 289 | list[StreamingProtocol]: StreamingProtocolSerializationStrategy(), 290 | } 291 | serialize_by_alias = True 292 | 293 | 294 | @dataclass 295 | class EventImage(DataClassDictMixin, CommandDataClass): 296 | """Provides access to an image in response to an event. 297 | 298 | Use the ?width or ?height query parameter to customize the resolution 299 | of the downloaded image. Only one of these parameters needs to be specified. 300 | The other parameter is scaled automatically according to the camera's 301 | aspect ratio. 302 | 303 | The token should be added as an HTTP header: 304 | Authorization: Basic 305 | """ 306 | 307 | event_image_type: EventImageContentType 308 | """Return the type of event image.""" 309 | 310 | url: str | None = field(default=None) 311 | """URL to download the camera image from.""" 312 | 313 | token: str | None = field(default=None) 314 | """Token to use in the HTTP Authorization header when downloading.""" 315 | 316 | async def contents( 317 | self, 318 | width: int | None = None, 319 | height: int | None = None, 320 | ) -> bytes: 321 | """Download the image bytes.""" 322 | if width: 323 | fetch_url = f"{self.url}?width={width}" 324 | elif height: 325 | fetch_url = f"{self.url}?height={height}" 326 | else: 327 | assert self.url 328 | fetch_url = self.url 329 | return await self.cmd.fetch_image(fetch_url, basic_auth=self.token) 330 | 331 | 332 | @dataclass 333 | class CameraEventImageTrait(DataClassDictMixin, CommandDataClass): 334 | """This trait belongs to any device that generates images from events.""" 335 | 336 | NAME: ClassVar[TraitType] = TraitType.CAMERA_EVENT_IMAGE 337 | 338 | async def generate_image(self, event_id: str) -> EventImage: 339 | """Provide a URL to download a camera image.""" 340 | data = { 341 | "command": "sdm.devices.commands.CameraEventImage.GenerateImage", 342 | "params": { 343 | "eventId": event_id, 344 | }, 345 | } 346 | response_data = await self.cmd.execute_json(data) 347 | results = response_data[RESULTS] 348 | img = EventImage(**results, event_image_type=EventImageType.IMAGE) 349 | img._cmd = self.cmd 350 | return img 351 | 352 | 353 | @dataclass 354 | class CameraMotionTrait: 355 | """For any device that supports motion detection events.""" 356 | 357 | NAME: ClassVar[TraitType] = TraitType.CAMERA_MOTION 358 | EVENT_NAME: ClassVar[EventType] = CameraMotionEvent.NAME 359 | 360 | 361 | @dataclass 362 | class CameraPersonTrait: 363 | """For any device that supports person detection events.""" 364 | 365 | NAME: ClassVar[TraitType] = TraitType.CAMERA_PERSON 366 | EVENT_NAME: ClassVar[EventType] = CameraPersonEvent.NAME 367 | 368 | 369 | @dataclass 370 | class CameraSoundTrait: 371 | """For any device that supports sound detection events.""" 372 | 373 | NAME: ClassVar[TraitType] = TraitType.CAMERA_SOUND 374 | EVENT_NAME: ClassVar[EventType] = CameraSoundEvent.NAME 375 | 376 | 377 | @dataclass 378 | class CameraClipPreviewTrait(DataClassDictMixin, CommandDataClass): 379 | """For any device that supports a clip preview.""" 380 | 381 | NAME: ClassVar[TraitType] = TraitType.CAMERA_CLIP_PREVIEW 382 | EVENT_NAME: ClassVar[EventType] = CameraClipPreviewEvent.NAME 383 | 384 | async def generate_event_image(self, preview_url: str) -> EventImage | None: 385 | """Provide a URL to download a camera image
from the active event.""" 386 | img = EventImage(url=preview_url, event_image_type=EventImageType.CLIP_PREVIEW) 387 | img._cmd = self.cmd 388 | return img 389 | -------------------------------------------------------------------------------- /google_nest_sdm/device.py: -------------------------------------------------------------------------------- 1 | """A device from the Smart Device Management API.""" 2 | 3 | from __future__ import annotations 4 | 5 | import datetime 6 | import logging 7 | from typing import Any, Awaitable, Callable 8 | from dataclasses import dataclass, field, fields, asdict 9 | 10 | from mashumaro import field_options, DataClassDictMixin 11 | from mashumaro.config import BaseConfig 12 | from mashumaro.types import SerializationStrategy 13 | 14 | from . import camera_traits, device_traits, doorbell_traits, thermostat_traits 15 | from .auth import AbstractAuth 16 | from .doorbell_traits import DoorbellChimeTrait 17 | from .diagnostics import Diagnostics, redact_data 18 | from .event import EventMessage, EventProcessingError 19 | from .event_media import EventMediaManager 20 | from .traits import Command 21 | from .model import TraitDataClass, SDM_PREFIX, TRAITS 22 | 23 | _LOGGER = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass 27 | class ParentRelation(DataClassDictMixin): 28 | """Represents the parent structure/room of the current resource.""" 29 | 30 | parent: str 31 | display_name: str = field(metadata=field_options(alias="displayName")) 32 | 33 | class Config(BaseConfig): 34 | serialize_by_alias = True 35 | 36 | 37 | @dataclass 38 | class TraitTypes(TraitDataClass): 39 | """Data model for parsing traits in the Google Nest SDM API.""" 40 | 41 | # Device Traits 42 | connectivity: device_traits.ConnectivityTrait | None = field( 43 | metadata=field_options( 44 | alias="sdm.devices.traits.Connectivity", 45 | ), 46 | default=None, 47 | ) 48 | fan: device_traits.FanTrait | None = field( 49 | metadata=field_options( 50 | alias="sdm.devices.traits.Fan", 51 | ), 52 | default=None, 53 | ) 54 | info: device_traits.InfoTrait | None = field( 55 | metadata=field_options( 56 | alias="sdm.devices.traits.Info", 57 | ), 58 | default=None, 59 | ) 60 | humidity: device_traits.HumidityTrait | None = field( 61 | metadata=field_options( 62 | alias="sdm.devices.traits.Humidity", 63 | ), 64 | default=None, 65 | ) 66 | temperature: device_traits.TemperatureTrait | None = field( 67 | metadata=field_options( 68 | alias="sdm.devices.traits.Temperature", 69 | ), 70 | default=None, 71 | ) 72 | 73 | # Thermostat Traits 74 | thermostat_eco: thermostat_traits.ThermostatEcoTrait | None = field( 75 | metadata=field_options( 76 | alias="sdm.devices.traits.ThermostatEco", 77 | ), 78 | default=None, 79 | ) 80 | thermostat_hvac: thermostat_traits.ThermostatHvacTrait | None = field( 81 | metadata=field_options( 82 | alias="sdm.devices.traits.ThermostatHvac", 83 | ), 84 | default=None, 85 | ) 86 | thermostat_mode: thermostat_traits.ThermostatModeTrait | None = field( 87 | metadata=field_options( 88 | alias="sdm.devices.traits.ThermostatMode", 89 | ), 90 | default=None, 91 | ) 92 | thermostat_temperature_setpoint: ( 93 | thermostat_traits.ThermostatTemperatureSetpointTrait | None 94 | ) = field( 95 | metadata=field_options( 96 | alias="sdm.devices.traits.ThermostatTemperatureSetpoint", 97 | ), 98 | default=None, 99 | ) 100 | 101 | # # Camera Traits 102 | camera_image: camera_traits.CameraImageTrait | None = field( 103 | metadata=field_options( 104 | alias="sdm.devices.traits.CameraImage", 105 | ), 106 
| default=None, 107 | ) 108 | camera_live_stream: camera_traits.CameraLiveStreamTrait | None = field( 109 | metadata=field_options(alias="sdm.devices.traits.CameraLiveStream"), 110 | default=None, 111 | ) 112 | camera_event_image: camera_traits.CameraEventImageTrait | None = field( 113 | metadata=field_options( 114 | alias="sdm.devices.traits.CameraEventImage", 115 | ), 116 | default=None, 117 | ) 118 | camera_motion: camera_traits.CameraMotionTrait | None = field( 119 | metadata=field_options( 120 | alias="sdm.devices.traits.CameraMotion", 121 | ), 122 | default=None, 123 | ) 124 | camera_person: camera_traits.CameraPersonTrait | None = field( 125 | metadata=field_options( 126 | alias="sdm.devices.traits.CameraPerson", 127 | ), 128 | default=None, 129 | ) 130 | camera_sound: camera_traits.CameraSoundTrait | None = field( 131 | metadata=field_options( 132 | alias="sdm.devices.traits.CameraSound", 133 | ), 134 | default=None, 135 | ) 136 | camera_clip_preview: camera_traits.CameraClipPreviewTrait | None = field( 137 | metadata=field_options( 138 | alias="sdm.devices.traits.CameraClipPreview", 139 | ), 140 | default=None, 141 | ) 142 | 143 | # # Doorbell Traits 144 | doorbell_chime: doorbell_traits.DoorbellChimeTrait | None = field( 145 | metadata=field_options( 146 | alias="sdm.devices.traits.DoorbellChime", 147 | ), 148 | default=None, 149 | ) 150 | 151 | 152 | class ParentRelationsSerializationStrategy(SerializationStrategy, use_annotations=True): 153 | """Parser to ignore invalid parent relations.""" 154 | 155 | def serialize(self, value: list[ParentRelation]) -> list[dict[str, Any]]: 156 | return [x.to_dict() for x in value] 157 | 158 | def deserialize(self, value: list[dict[str, Any]]) -> list[ParentRelation]: 159 | return [ 160 | ParentRelation.from_dict(relation) 161 | for relation in value 162 | if "parent" in relation and "displayName" in relation 163 | ] 164 | 165 | 166 | def _name_required() -> str: 167 | """Raise an error if the name field is not provided. 168 | 169 | This is a workaround for the fact that dataclasses children can't have 170 | default fields out of order from the subclass. 171 | """ 172 | raise ValueError("Field 'name' is required") 173 | 174 | 175 | @dataclass 176 | class Device(TraitTypes): 177 | """Class that represents a device object in the Google Nest SDM API.""" 178 | 179 | name: str = field(default_factory=_name_required) 180 | """Resource name of the device such as 'enterprises/XYZ/devices/123'.""" 181 | 182 | type: str | None = None 183 | """Type of device for display purposes. 184 | 185 | The device type should not be used to deduce or infer functionality of 186 | the actual device it is assigned to. Instead, use the returned traits for 187 | the device. 
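For example (illustrative, not prescriptive): prefer checking `device.traits` or a typed trait field such as `device.camera_live_stream` over branching on this type string.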
188 | """ 189 | 190 | relations: list[ParentRelation] = field( 191 | metadata=field_options(alias="parentRelations"), default_factory=list 192 | ) 193 | """Represents the parent structure or room of the device.""" 194 | 195 | _auth: AbstractAuth = field(init=False, metadata={"serialize": "omit"}) 196 | _diagnostics: Diagnostics = field(init=False, metadata={"serialize": "omit"}) 197 | _event_media_manager: EventMediaManager = field( 198 | init=False, metadata={"serialize": "omit"} 199 | ) 200 | _callbacks: list[Callable[[EventMessage], Awaitable[None]]] = field( 201 | init=False, metadata={"serialize": "omit"}, default_factory=list 202 | ) 203 | _trait_event_ts: dict[str, datetime.datetime] = field( 204 | init=False, metadata={"serialize": "omit"}, default_factory=dict 205 | ) 206 | 207 | @staticmethod 208 | def MakeDevice(raw_data: dict[str, Any], auth: AbstractAuth) -> Device: 209 | """Create a device with the appropriate traits.""" 210 | 211 | # Hack for incorrect nest API response values 212 | if (type := raw_data.get("type")) and type == "sdm.devices.types.DOORBELL": 213 | if TRAITS not in raw_data: 214 | raw_data[TRAITS] = {} 215 | raw_data[TRAITS][DoorbellChimeTrait.NAME] = {} 216 | 217 | device: Device = Device.parse_trait_object(raw_data) 218 | device._auth = auth 219 | device._diagnostics = Diagnostics() 220 | cmd = Command(raw_data["name"], auth, device._diagnostics.subkey("command")) 221 | for trait in device.traits.values(): 222 | if hasattr(trait, "_cmd"): 223 | trait._cmd = cmd 224 | 225 | event_traits = { 226 | trait.EVENT_NAME 227 | for trait in device.traits.values() 228 | if hasattr(trait, "EVENT_NAME") 229 | } 230 | device._event_media_manager = EventMediaManager( 231 | device.name or "", 232 | device.traits, 233 | event_traits, 234 | diagnostics=device._diagnostics.subkey("event_media"), 235 | ) 236 | return device 237 | 238 | def add_update_listener(self, target: Callable[[], None]) -> Callable[[], None]: 239 | """Register a simple event listener notified on updates. 240 | 241 | This will not block on media being fetched. To wait for media, use 242 | the callback form the `EventMediaManager`. 243 | 244 | The return value is a callable that will unregister the callback. 245 | """ 246 | 247 | async def handle_event(event_message: EventMessage) -> None: 248 | target() 249 | 250 | return self.add_event_callback(handle_event) 251 | 252 | def add_event_callback( 253 | self, target: Callable[[EventMessage], Awaitable[None]] 254 | ) -> Callable[[], None]: 255 | """Register an event callback for updates to this device. 256 | 257 | This will not block on media being fetched. To wait for media, use 258 | the callback form the `EventMediaManager`. 259 | 260 | The return value is a callable that will unregister the callback. 261 | """ 262 | self._callbacks.append(target) 263 | 264 | def remove_callback() -> None: 265 | """Remove the event_callback.""" 266 | self._callbacks.remove(target) 267 | 268 | return remove_callback 269 | 270 | async def async_handle_event(self, event_message: EventMessage) -> None: 271 | """Process an event from the pubsub subscriber. 272 | 273 | This will invoke any directly registered callbacks (before fetching media) 274 | as well as any callbacks registered with the event media manager that 275 | fire post-media. 
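A minimal sketch of how a caller typically hooks into this flow (the `handle_message` name below is illustrative):

    async def handle_message(event_message: EventMessage) -> None:
        print(event_message.event_id)

    # Register the callback; events delivered by the subscriber are then
    # routed through this method. Call unsub() later to stop receiving them.
    unsub = device.add_event_callback(handle_message)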
276 | """ 277 | _LOGGER.debug( 278 | "Processing update %s @ %s", event_message.event_id, event_message.timestamp 279 | ) 280 | if not event_message.resource_update_name: 281 | raise EventProcessingError("Event was not resource update event") 282 | if self.name != event_message.resource_update_name: 283 | raise EventProcessingError( 284 | f"Mismatch {self.name} != {event_message.resource_update_name}" 285 | ) 286 | self._async_handle_traits(event_message) 287 | for callback in self._callbacks: 288 | await callback(event_message) 289 | await self._event_media_manager.async_handle_events(event_message) 290 | 291 | def _async_handle_traits(self, event_message: EventMessage) -> None: 292 | traits = event_message.resource_update_traits 293 | if not traits: 294 | return 295 | _LOGGER.debug("Trait update %s", traits) 296 | # Parse the traits using a separate object, then overwrite 297 | # each present field with an updated copy of the original trait with 298 | # the new fields merged in. 299 | parsed_traits = TraitTypes.parse_trait_object({TRAITS: traits}) 300 | for trait_field in fields(parsed_traits): 301 | if ( 302 | (alias := trait_field.metadata.get("alias")) is None 303 | or not alias.startswith(SDM_PREFIX) 304 | or not (new := getattr(parsed_traits, trait_field.name)) 305 | ): 306 | continue 307 | # Discard updates to traits that are newer than the update 308 | if ( 309 | self._trait_event_ts 310 | and (ts := self._trait_event_ts.get(trait_field.name)) 311 | and ts > event_message.timestamp 312 | ): 313 | _LOGGER.debug("Discarding stale update (%s)", event_message.timestamp) 314 | continue 315 | 316 | # Only merge updates into existing models, updating the existing 317 | # fields present in the update trait 318 | if not (existing := getattr(self, trait_field.name)): 319 | continue 320 | for k, v in asdict(new).items(): 321 | if v is not None: 322 | setattr(existing, k, v) 323 | self._trait_event_ts[trait_field.name] = event_message.timestamp 324 | 325 | @property 326 | def event_media_manager(self) -> EventMediaManager: 327 | return self._event_media_manager 328 | 329 | @property 330 | def parent_relations(self) -> dict: 331 | """Room or structure for the device.""" 332 | return {relation.parent: relation.display_name for relation in self.relations} 333 | 334 | def delete_relation(self, parent: str) -> None: 335 | """Remove a device relationship with the parent.""" 336 | self.relations = [ 337 | relation for relation in self.relations if relation.parent != parent 338 | ] 339 | 340 | def create_relation(self, relation: ParentRelation) -> None: 341 | """Add a new device relation.""" 342 | self.relations.append(relation) 343 | 344 | def get_diagnostics(self) -> dict[str, Any]: 345 | return { 346 | "data": redact_data(self.raw_data), 347 | **self._diagnostics.as_dict(), 348 | } 349 | 350 | class Config(TraitTypes.Config): 351 | serialization_strategy = { 352 | list[ParentRelation]: ParentRelationsSerializationStrategy(), 353 | } 354 | -------------------------------------------------------------------------------- /google_nest_sdm/device_manager.py: -------------------------------------------------------------------------------- 1 | """Device Manager keeps track of the current state of all devices.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Awaitable, Callable, Dict 6 | 7 | from .device import Device, ParentRelation 8 | from .event import EventMessage, RelationUpdate 9 | from .event_media import CachePolicy 10 | from .structure import Structure 11 | 12 | 13 | 
class DeviceManager: 14 | """DeviceManager holds current state of all devices.""" 15 | 16 | def __init__(self, cache_policy: CachePolicy | None = None) -> None: 17 | """Initialize DeviceManager.""" 18 | self._devices: Dict[str, Device] = {} 19 | self._structures: Dict[str, Structure] = {} 20 | self._cache_policy = cache_policy if cache_policy else CachePolicy() 21 | self._callback: Callable[[EventMessage], Awaitable[None]] | None = None 22 | 23 | @property 24 | def devices(self) -> Dict[str, Device]: 25 | """Return current state of devices.""" 26 | return self._devices 27 | 28 | @property 29 | def structures(self) -> Dict[str, Structure]: 30 | """Return current state of structures.""" 31 | return self._structures 32 | 33 | def add_device(self, device: Device) -> None: 34 | """Track the specified device.""" 35 | assert device.name 36 | self._devices[device.name] = device 37 | # Share a single cache policy across all devices 38 | device.event_media_manager.cache_policy = self._cache_policy 39 | if self._callback: 40 | device.event_media_manager.set_update_callback(self._callback) 41 | 42 | def add_structure(self, structure: Structure) -> None: 43 | """Track the specified device.""" 44 | assert structure.name 45 | self._structures[structure.name] = structure 46 | 47 | @property 48 | def cache_policy(self) -> CachePolicy: 49 | """Return cache policy shared by device EventMediaManager objects.""" 50 | return self._cache_policy 51 | 52 | def set_update_callback( 53 | self, target: Callable[[EventMessage], Awaitable[None]] 54 | ) -> None: 55 | """Register a callback invoked when new messages are received. 56 | 57 | If the event is associated with media, then the callback will only 58 | be invoked once the media has been fetched. 59 | """ 60 | self._callback = target 61 | for device in self._devices.values(): 62 | device.event_media_manager.set_update_callback(target) 63 | 64 | async def async_handle_event(self, event_message: EventMessage) -> None: 65 | """Handle a new message received.""" 66 | if event_message.relation_update: 67 | self._handle_device_relation(event_message.relation_update) 68 | if self._callback: 69 | await self._callback(event_message) 70 | return 71 | 72 | if event_message.resource_update_name: 73 | device_id = event_message.resource_update_name 74 | if device_id in self._devices: 75 | device = self._devices[device_id] 76 | await device.async_handle_event(event_message) 77 | 78 | def _structure_name(self, relation_subject: str) -> str: 79 | if relation_subject in self._structures: 80 | structure = self._structures[relation_subject] 81 | for trait in [structure.info, structure.room_info]: 82 | if trait and trait.custom_name: 83 | return trait.custom_name 84 | return "Unknown" 85 | 86 | def _handle_device_relation(self, relation: RelationUpdate) -> None: 87 | if relation.object not in self._devices: 88 | return 89 | 90 | device = self._devices[relation.object] 91 | if relation.type == "DELETED": 92 | # Delete device from room/structure 93 | device.delete_relation(relation.subject) 94 | 95 | if relation.type == "UPDATED" or relation.type == "CREATED": 96 | # Device moved to a room 97 | assert relation.subject 98 | device.create_relation( 99 | ParentRelation.from_dict( 100 | { 101 | "parent": relation.subject, 102 | "displayName": self._structure_name(relation.subject), 103 | } 104 | ) 105 | ) 106 | -------------------------------------------------------------------------------- /google_nest_sdm/device_traits.py: 
-------------------------------------------------------------------------------- 1 | """Library for traits about devices.""" 2 | 3 | import datetime 4 | from typing import Any, Dict, ClassVar 5 | from dataclasses import dataclass, field 6 | 7 | import aiohttp 8 | from mashumaro import field_options, DataClassDictMixin 9 | 10 | from .traits import CommandDataClass, TraitType 11 | 12 | 13 | @dataclass 14 | class ConnectivityTrait(DataClassDictMixin): 15 | """This trait belongs to any device that has connectivity information.""" 16 | 17 | NAME: ClassVar[TraitType] = TraitType.CONNECTIVITY 18 | 19 | status: str 20 | """Device connectivity status. 21 | 22 | Return: 23 | "OFFLINE", "ONLINE" 24 | """ 25 | 26 | 27 | @dataclass 28 | class FanTrait(DataClassDictMixin, CommandDataClass): 29 | """This trait belongs to any device that can control the fan.""" 30 | 31 | NAME: ClassVar[TraitType] = TraitType.FAN 32 | 33 | timer_mode: str | None = field( 34 | metadata=field_options(alias="timerMode"), default=None 35 | ) 36 | """Timer mode for the fan. 37 | 38 | Return: 39 | "ON", "OFF" 40 | """ 41 | 42 | timer_timeout: datetime.datetime | None = field( 43 | metadata=field_options(alias="timerTimeout"), default=None 44 | ) 45 | 46 | async def set_timer( 47 | self, timer_mode: str, duration: int | None = None 48 | ) -> aiohttp.ClientResponse: 49 | """Change the fan timer.""" 50 | data: Dict[str, Any] = { 51 | "command": "sdm.devices.commands.Fan.SetTimer", 52 | "params": { 53 | "timerMode": timer_mode, 54 | }, 55 | } 56 | if duration: 57 | data["params"]["duration"] = f"{duration}s" 58 | return await self.cmd.execute(data) 59 | 60 | 61 | @dataclass 62 | class InfoTrait(DataClassDictMixin): 63 | """This trait belongs to any device for device-related information.""" 64 | 65 | NAME: ClassVar[TraitType] = TraitType.INFO 66 | 67 | custom_name: str | None = field( 68 | metadata=field_options(alias="customName"), default=None 69 | ) 70 | """Name of the device.""" 71 | 72 | 73 | @dataclass 74 | class HumidityTrait(DataClassDictMixin): 75 | """This trait belongs to any device that has a sensor to measure humidity.""" 76 | 77 | NAME: ClassVar[TraitType] = TraitType.HUMIDITY 78 | 79 | ambient_humidity_percent: float = field( 80 | metadata=field_options(alias="ambientHumidityPercent") 81 | ) 82 | """Percent humidity, measured at the device.""" 83 | 84 | 85 | @dataclass 86 | class TemperatureTrait(DataClassDictMixin): 87 | """This trait belongs to any device that has a sensor to measure temperature.""" 88 | 89 | NAME: ClassVar[TraitType] = TraitType.TEMPERATURE 90 | 91 | ambient_temperature_celsius: float = field( 92 | metadata=field_options(alias="ambientTemperatureCelsius") 93 | ) 94 | """Ambient temperature in degrees Celsius, measured at the device.""" 95 | -------------------------------------------------------------------------------- /google_nest_sdm/diagnostics.py: -------------------------------------------------------------------------------- 1 | """Diagnostics for debugging.""" 2 | 3 | from __future__ import annotations 4 | 5 | import time 6 | from collections import Counter 7 | from collections.abc import Mapping 8 | from contextlib import contextmanager 9 | from typing import Any, Generator, TypeVar, cast 10 | 11 | __all__ = [ 12 | "get_diagnostics", 13 | ] 14 | 15 | 16 | class Diagnostics: 17 | """Information for the library.""" 18 | 19 | def __init__(self) -> None: 20 | """Initialize Diagnostics.""" 21 | self._counter: Counter = Counter() 22 | self._subkeys: dict[str, Diagnostics] = {} 23 | 24 | def increment(self, key: str,
count: int = 1) -> None: 25 | """Increment a counter for the specified key/event.""" 26 | self._counter.update(Counter({key: count})) 27 | 28 | def elapsed(self, key_prefix: str, elapsed_ms: int = 1) -> None: 29 | """Track a latency event for the specified key/event prefix.""" 30 | self.increment(f"{key_prefix}_count", 1) 31 | self.increment(f"{key_prefix}_sum", elapsed_ms) 32 | 33 | def as_dict(self) -> Mapping[str, Any]: 34 | """Return diagnostics as a debug dictionary.""" 35 | data: dict[str, Any] = {k: self._counter[k] for k in self._counter} 36 | for k, d in self._subkeys.items(): 37 | v = d.as_dict() 38 | if not v: 39 | continue 40 | data[k] = v 41 | return data 42 | 43 | def subkey(self, key: str) -> Diagnostics: 44 | """Return sub-Diagnositics object with the specified subkey.""" 45 | if key not in self._subkeys: 46 | self._subkeys[key] = Diagnostics() 47 | return self._subkeys[key] 48 | 49 | @contextmanager 50 | def timer(self, key_prefix: str) -> Generator[None, None, None]: 51 | """A context manager that records the timing of operations as a diagnostic.""" 52 | start = time.perf_counter() 53 | try: 54 | yield 55 | finally: 56 | end = time.perf_counter() 57 | ms = int((end - start) * 1000) 58 | self.elapsed(key_prefix, ms) 59 | 60 | def reset(self) -> None: 61 | """Clear all diagnostics, for testing.""" 62 | self._counter = Counter() 63 | for d in self._subkeys.values(): 64 | d.reset() 65 | 66 | 67 | SUBSCRIBER_DIAGNOSTICS = Diagnostics() 68 | EVENT_DIAGNOSTICS = Diagnostics() 69 | EVENT_MEDIA_DIAGNOSTICS = Diagnostics() 70 | STREAMING_MANAGER_DIAGNOSTICS = Diagnostics() 71 | 72 | MAP = { 73 | "subscriber": SUBSCRIBER_DIAGNOSTICS, 74 | "event": EVENT_DIAGNOSTICS, 75 | "event_media": EVENT_MEDIA_DIAGNOSTICS, 76 | "streaming_manager": STREAMING_MANAGER_DIAGNOSTICS, 77 | } 78 | 79 | 80 | def reset() -> None: 81 | """Clear all diagnostics, for testing.""" 82 | for diagnostics in MAP.values(): 83 | diagnostics.reset() 84 | 85 | 86 | def get_diagnostics() -> dict[str, Any]: 87 | """Produce diagnostics information for the library.""" 88 | return {k: v.as_dict() for (k, v) in MAP.items() if v.as_dict()} 89 | 90 | 91 | REDACT_KEYS = { 92 | "name", 93 | "custom_name", 94 | "displayName", 95 | "parent", 96 | "assignee", 97 | "subject", 98 | "object", 99 | "userId", 100 | "resourceGroup", 101 | "eventId", 102 | "eventSessionId", 103 | "eventThreadId", 104 | } 105 | REDACTED = "**REDACTED**" 106 | 107 | 108 | T = TypeVar("T") 109 | 110 | 111 | def redact_data(data: T) -> T | dict | list: 112 | """Redact sensitive data in a dict.""" 113 | if not isinstance(data, (Mapping, list)): 114 | return data 115 | 116 | if isinstance(data, list): 117 | return cast(T, [redact_data(item) for item in data]) 118 | 119 | redacted = {**data} 120 | 121 | for key, value in redacted.items(): 122 | if key in REDACT_KEYS: 123 | redacted[key] = REDACTED 124 | elif isinstance(value, dict): 125 | redacted[key] = redact_data(value) 126 | elif isinstance(value, list): 127 | redacted[key] = [redact_data(item) for item in value] 128 | 129 | return redacted 130 | -------------------------------------------------------------------------------- /google_nest_sdm/doorbell_traits.py: -------------------------------------------------------------------------------- 1 | """Traits belonging to doorbell devices.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass 6 | import logging 7 | from typing import ClassVar 8 | 9 | from .event import DoorbellChimeEvent, EventType 10 | from .traits import 
TraitType 11 | 12 | _LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | @dataclass 16 | class DoorbellChimeTrait: 17 | """For any device that supports a doorbell chime and related press events.""" 18 | 19 | NAME: ClassVar[TraitType] = TraitType.DOORBELL_CHIME 20 | EVENT_NAME: ClassVar[EventType] = DoorbellChimeEvent.NAME 21 | -------------------------------------------------------------------------------- /google_nest_sdm/exceptions.py: -------------------------------------------------------------------------------- 1 | """Library for exceptions using the Google Nest SDM API and subscriber.""" 2 | 3 | 4 | class GoogleNestException(Exception): 5 | """Base class for all client exceptions.""" 6 | 7 | 8 | class SubscriberException(GoogleNestException): 9 | """Raised during problems subscribing to events and updates.""" 10 | 11 | 12 | class ApiException(GoogleNestException): 13 | """Raised during problems talking to the API.""" 14 | 15 | 16 | class AuthException(ApiException): 17 | """Raised due to auth problems talking to API or subscriber.""" 18 | 19 | 20 | class NotFoundException(ApiException): 21 | """Raised when the API returns an error that a resource was not found.""" 22 | 23 | 24 | class ApiForbiddenException(ApiException): 25 | """Raised when the user is not authorized to perform a specific function.""" 26 | 27 | 28 | class ConfigurationException(GoogleNestException): 29 | """Raised due to misconfiguration problems.""" 30 | 31 | 32 | class DecodeException(GoogleNestException): 33 | """Raised when failing to decode a token.""" 34 | 35 | 36 | class TranscodeException(GoogleNestException): 37 | """Raised when failing to transcode media.""" 38 | -------------------------------------------------------------------------------- /google_nest_sdm/google_nest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | """Command line tool for the Google Nest Smart Device Management API. 4 | 5 | You must configure your device as described: 6 | https://developers.google.com/nest/device-access/get-started 7 | 8 | which will give you a project_id, client_id, and client_secret. This tool 9 | will do a one time setup to get an access token, then from there on will 10 | cache the token on local disk. 
11 | 12 | Once authenticated, you can run commands like: 13 | 14 | $ google_nest --project_id= list 15 | $ google_nest --project_id= get 16 | """ 17 | 18 | import argparse 19 | import asyncio 20 | import errno 21 | import json 22 | import logging 23 | import os 24 | import pickle 25 | from typing import cast 26 | 27 | import yaml 28 | from aiohttp import ClientSession 29 | from google.auth.credentials import Credentials 30 | from google.auth.transport.requests import Request 31 | from google_auth_oauthlib.flow import InstalledAppFlow 32 | 33 | from .auth import AbstractAuth 34 | from .camera_traits import CameraLiveStreamTrait 35 | from .device import Device 36 | from .event import EventMessage 37 | from .google_nest_api import GoogleNestAPI 38 | from .google_nest_subscriber import ( 39 | API_URL, 40 | OAUTH2_AUTHORIZE_FORMAT, 41 | OAUTH2_TOKEN, 42 | SDM_SCOPES, 43 | GoogleNestSubscriber, 44 | ) 45 | from .structure import Structure 46 | from .thermostat_traits import ( 47 | ThermostatEcoTrait, 48 | ThermostatModeTrait, 49 | ThermostatTemperatureSetpointTrait, 50 | ) 51 | 52 | # Define command line arguments 53 | parser = argparse.ArgumentParser( 54 | description="Command line tool for Google Nest SDM API" 55 | ) 56 | parser.add_argument("--project_id", required=True, help="Device Access program id") 57 | parser.add_argument("--client_id", help="OAuth credentials client_id") 58 | parser.add_argument("--client_secret", help="OAuth credentials client_secret") 59 | parser.add_argument( 60 | "--token_cache", 61 | help="File storage for long lived creds", 62 | default="~/.config/google_nest/token_cache", 63 | ) 64 | parser.add_argument( 65 | "-v", "--verbose", help="Increase output verbosity", action="store_true" 66 | ) 67 | parser.add_argument( 68 | "--output_type", 69 | type=str, 70 | choices=["json", "yaml"], 71 | help="Change the output type from json or yaml (default).", 72 | default="yaml", 73 | ) 74 | 75 | cmd_parser = parser.add_subparsers(dest="command", required=True) 76 | list_structures_parser = cmd_parser.add_parser("list_structures") 77 | list_devices_parser = cmd_parser.add_parser("list_devices") 78 | get_structure_parser = cmd_parser.add_parser("get_structure") 79 | get_structure_parser.add_argument("structure_id") 80 | get_device_parser = cmd_parser.add_parser("get_device") 81 | get_device_parser.add_argument("device_id") 82 | set_mode_parser = cmd_parser.add_parser( 83 | "set_mode", description="Change the thermostat mode." 84 | ) 85 | set_mode_parser.add_argument("device_id") 86 | set_mode_parser.add_argument( 87 | "mode", 88 | help="The mode to change the thermostat to.", 89 | choices=["MANUAL_ECO", "HEAT", "COOL", "HEATCOOL", "OFF"], 90 | ) 91 | set_heat_parser = cmd_parser.add_parser( 92 | "set_heat", description="Sets the target temperature when in HEAT mode." 93 | ) 94 | set_heat_parser.add_argument("device_id") 95 | set_heat_parser.add_argument("heat", type=float) 96 | set_cool_parser = cmd_parser.add_parser( 97 | "set_cool", help="Sets the target temperature when in COOL mode." 98 | ) 99 | set_cool_parser.add_argument("device_id") 100 | set_cool_parser.add_argument( 101 | "cool", 102 | type=float, 103 | help="The target temperature to set when the thermostat is in COOL mode.", 104 | ) 105 | set_range_parser = cmd_parser.add_parser( 106 | "set_range", help="Sets the min/max temperature when in HEATCOOL mode." 
107 | ) 108 | set_range_parser.add_argument("device_id") 109 | set_range_parser.add_argument( 110 | "heat", type=float, help="The minimum target temperature to set." 111 | ) 112 | set_range_parser.add_argument( 113 | "cool", type=float, help="The maximum target temperature to set." 114 | ) 115 | generate_rtsp_stream_parser = cmd_parser.add_parser("generate_rtsp_stream") 116 | generate_rtsp_stream_parser.add_argument("device_id") 117 | generate_web_rtc_stream_parser = cmd_parser.add_parser("generate_web_rtc_stream") 118 | generate_web_rtc_stream_parser.add_argument("device_id") 119 | generate_web_rtc_stream_parser.add_argument("offer_file") 120 | subscribe_parser = cmd_parser.add_parser("subscribe") 121 | subscribe_parser.add_argument("subscription_id") 122 | subscribe_parser.add_argument("device_id", nargs="?") 123 | 124 | 125 | class Auth(AbstractAuth): 126 | """Implementation of AbstractAuth that uses the token cache.""" 127 | 128 | def __init__( 129 | self, 130 | websession: ClientSession, 131 | user_creds: Credentials, 132 | api_url: str, 133 | ): 134 | """Initialize Google Nest Device Access auth.""" 135 | super().__init__(websession, api_url) 136 | self._user_creds = user_creds 137 | 138 | async def async_get_access_token(self) -> str: 139 | """Return a valid access token.""" 140 | return cast(str, self._user_creds.token) 141 | 142 | async def async_get_creds(self) -> Credentials: 143 | """Return valid OAuth creds.""" 144 | return self._user_creds 145 | 146 | 147 | def CreateCreds(args: argparse.Namespace) -> Credentials: 148 | """Run an interactive flow to get OAuth creds.""" 149 | creds = None 150 | token_cache = os.path.expanduser(args.token_cache) 151 | if os.path.exists(token_cache): 152 | with open(token_cache, "rb") as token: 153 | creds = pickle.load(token) 154 | 155 | # If there are no (valid) credentials available, let the user log in. 
156 | if not creds or not creds.valid: 157 | if creds and creds.expired and creds.refresh_token: 158 | creds.refresh(Request()) 159 | else: 160 | if not args.client_id or not args.client_secret: 161 | raise ValueError("Required flag --client_id or --client_secret missing") 162 | client_config = { 163 | "installed": { 164 | "client_id": args.client_id, 165 | "client_secret": args.client_secret, 166 | "auth_uri": OAUTH2_AUTHORIZE_FORMAT.format( 167 | project_id=args.project_id 168 | ), 169 | "token_uri": OAUTH2_TOKEN, 170 | }, 171 | } 172 | app_flow = InstalledAppFlow.from_client_config( 173 | client_config, scopes=SDM_SCOPES 174 | ) 175 | creds = app_flow.run_local_server() 176 | # Save the credentials for the next run 177 | if not os.path.exists(os.path.dirname(token_cache)): 178 | try: 179 | os.makedirs(os.path.dirname(token_cache)) 180 | except OSError as exc: # Guard against race condition 181 | if exc.errno != errno.EEXIST: 182 | raise 183 | with open(token_cache, "wb") as token: 184 | pickle.dump(creds, token) 185 | return creds 186 | 187 | 188 | def PrintStructure(structure: Structure, output_type: str) -> None: 189 | """Print the structure.""" 190 | if output_type == "json": 191 | print(json.dumps(structure.raw_data)) 192 | else: 193 | print(yaml.dump(structure.raw_data)) 194 | 195 | 196 | def PrintDevice(device: Device, output_type: str) -> None: 197 | """Print the device.""" 198 | if output_type == "json": 199 | print(json.dumps(device.raw_data)) 200 | else: 201 | print(yaml.dump(device.raw_data)) 202 | 203 | 204 | class SubscribeCallback: 205 | """Print the event message.""" 206 | 207 | def __init__(self, output_type: str | None = None) -> None: 208 | """Initialize SubscribeCallback.""" 209 | self._output_type = output_type 210 | 211 | async def async_handle_event(self, event_message: EventMessage) -> None: 212 | """Handle an EventMessage.""" 213 | if self._output_type == "json": 214 | print(json.dumps(event_message.raw_data)) 215 | else: 216 | print(yaml.dump(event_message.raw_data)) 217 | 218 | 219 | class DeviceWatcherCallback: 220 | """Print the event message.""" 221 | 222 | def __init__(self, device: Device, output_type: str) -> None: 223 | """Initialize DeviceWatcherCallback.""" 224 | self._device = device 225 | self._output_type = output_type 226 | 227 | async def async_handle_event(self, event_message: EventMessage) -> None: 228 | """Handle an EventMessage.""" 229 | print(f"event_id: {event_message.event_id}") 230 | print("Current device state:") 231 | PrintDevice(self._device, self._output_type) 232 | print("") 233 | 234 | 235 | async def RunTool(args: argparse.Namespace, user_creds: Credentials) -> None: 236 | """Run the command.""" 237 | async with ClientSession() as client: 238 | auth = Auth(client, user_creds, API_URL) 239 | api = GoogleNestAPI(auth, args.project_id) 240 | 241 | if args.command == "list_structures": 242 | structures: list[Structure] = await api.async_get_structures() 243 | for s in structures: 244 | PrintStructure(s, args.output_type) 245 | return 246 | 247 | if args.command == "get_structure": 248 | structure: Structure | None = await api.async_get_structure( 249 | args.structure_id 250 | ) 251 | assert structure 252 | PrintStructure(structure, args.output_type) 253 | return 254 | 255 | if args.command == "list_devices": 256 | devices = await api.async_get_devices() 257 | for d in devices: 258 | PrintDevice(d, args.output_type) 259 | return 260 | 261 | if args.command == "subscribe": 262 | logging.info("Subscription: %s", args.subscription_id) 263 | 
subscriber = GoogleNestSubscriber( 264 | auth, args.project_id, args.subscription_id 265 | ) 266 | if args.device_id: 267 | device_manager = await subscriber.async_get_device_manager() 268 | dev = device_manager.devices[args.device_id] 269 | dev_callback = DeviceWatcherCallback(dev, args.output_type) 270 | dev.add_event_callback(dev_callback.async_handle_event) 271 | else: 272 | sub_callback = SubscribeCallback(args.output_type) 273 | subscriber.set_update_callback(sub_callback.async_handle_event) 274 | unsub = await subscriber.start_async() 275 | try: 276 | while True: 277 | await asyncio.sleep(10) 278 | except KeyboardInterrupt: 279 | unsub() 280 | 281 | # All other commands require a device_id 282 | device: Device | None = await api.async_get_device(args.device_id) 283 | assert device 284 | 285 | if args.command == "get_device": 286 | PrintDevice(device, args.output_type) 287 | 288 | if args.command == "set_mode": 289 | mode = args.mode 290 | trait = device.traits[ThermostatModeTrait.NAME] 291 | if mode == "MANUAL_ECO": 292 | trait = device.traits[ThermostatEcoTrait.NAME] 293 | resp = await trait.set_mode(mode) 294 | print(await resp.text()) 295 | 296 | if args.command == "set_heat": 297 | trait = device.traits[ThermostatTemperatureSetpointTrait.NAME] 298 | resp = await trait.set_heat(args.heat) 299 | print(await resp.text()) 300 | 301 | if args.command == "set_cool": 302 | trait = device.traits[ThermostatTemperatureSetpointTrait.NAME] 303 | resp = await trait.set_cool(args.cool) 304 | print(await resp.text()) 305 | 306 | if args.command == "set_range": 307 | trait = device.traits[ThermostatTemperatureSetpointTrait.NAME] 308 | resp = await trait.set_range(args.heat, args.cool) 309 | print(await resp.text()) 310 | 311 | if args.command == "generate_rtsp_stream": 312 | trait = device.traits[CameraLiveStreamTrait.NAME] 313 | stream = await trait.generate_rtsp_stream() 314 | print(f"URL: {stream.rtsp_stream_url}") 315 | print(f"Stream Token: {stream.stream_token}") 316 | print(f"Expires At: {stream.expires_at}") 317 | 318 | if args.command == "generate_web_rtc_stream": 319 | trait = device.traits[CameraLiveStreamTrait.NAME] 320 | offer_sdp = None 321 | if args.offer_file: 322 | f = open(args.offer_file, "r") 323 | offer_sdp = f.read() 324 | stream = await trait.generate_web_rtc_stream(offer_sdp) 325 | print(f"Answer SDP: {stream.answer_sdp}") 326 | print(f"Media Session Id: {stream.media_session_id}") 327 | print(f"Expires At: {stream.expires_at}") 328 | 329 | 330 | def main() -> None: 331 | """Nest command line tool.""" 332 | args: argparse.Namespace = parser.parse_args() 333 | if args.verbose: 334 | logging.basicConfig(level=logging.DEBUG) 335 | user_creds = CreateCreds(args) 336 | loop = asyncio.get_event_loop() 337 | loop.run_until_complete(RunTool(args, user_creds)) 338 | loop.close() 339 | 340 | 341 | if __name__ == "__main__": 342 | main() 343 | -------------------------------------------------------------------------------- /google_nest_sdm/google_nest_api.py: -------------------------------------------------------------------------------- 1 | """Library to access the Smart Device Management API.""" 2 | 3 | from .auth import AbstractAuth 4 | from .device import Device 5 | from .structure import Structure 6 | 7 | __all__ = ["GoogleNestAPI"] 8 | 9 | STRUCTURES = "structures" 10 | DEVICES = "devices" 11 | NAME = "name" 12 | 13 | 14 | class GoogleNestAPI: 15 | """Client library to communicate with the Google Nest SDM API.""" 16 | 17 | def __init__(self, auth: AbstractAuth, project_id: 
str): 18 | """Initialize the API and store the auth so we can make requests.""" 19 | self._auth = auth 20 | self._project_id = project_id 21 | 22 | @property 23 | def _structures_url(self) -> str: 24 | return f"enterprises/{self._project_id}/structures" 25 | 26 | async def async_get_structures(self) -> list[Structure]: 27 | """Return the structures.""" 28 | response_data = await self._auth.get_json(self._structures_url) 29 | if STRUCTURES not in response_data: 30 | return [] 31 | structures = response_data[STRUCTURES] 32 | return [ 33 | Structure.MakeStructure(structure_data) for structure_data in structures 34 | ] 35 | 36 | async def async_get_structure(self, structure_id: str) -> Structure | None: 37 | """Return a structure device.""" 38 | data = await self._auth.get_json(f"{self._structures_url}/{structure_id}") 39 | if NAME not in data: 40 | return None 41 | return Structure.MakeStructure(data) 42 | 43 | @property 44 | def _devices_url(self) -> str: 45 | return f"enterprises/{self._project_id}/devices" 46 | 47 | async def async_get_devices(self) -> list[Device]: 48 | """Return the devices.""" 49 | response_data = await self._auth.get_json(self._devices_url) 50 | if DEVICES not in response_data: 51 | return [] 52 | devices = response_data[DEVICES] 53 | return [Device.MakeDevice(device_data, self._auth) for device_data in devices] 54 | 55 | async def async_get_device(self, device_id: str) -> Device | None: 56 | """Return a specific device.""" 57 | data = await self._auth.get_json(f"{self._devices_url}/{device_id}") 58 | if NAME not in data: 59 | return None 60 | return Device.MakeDevice(data, self._auth) 61 | -------------------------------------------------------------------------------- /google_nest_sdm/google_nest_subscriber.py: -------------------------------------------------------------------------------- 1 | """Subscriber for the Smart Device Management event based API.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import enum 7 | import logging 8 | import re 9 | import time 10 | from typing import Awaitable, Callable 11 | 12 | 13 | from .auth import AbstractAuth 14 | from .device_manager import DeviceManager 15 | from .diagnostics import SUBSCRIBER_DIAGNOSTICS as DIAGNOSTICS 16 | from .event import EventMessage 17 | from .event_media import CachePolicy 18 | from .exceptions import ( 19 | ConfigurationException, 20 | ApiException, 21 | ) 22 | from .google_nest_api import GoogleNestAPI 23 | from .streaming_manager import StreamingManager, Message 24 | 25 | __all__ = [ 26 | "GoogleNestSubscriber", 27 | "ApiEnv", 28 | ] 29 | 30 | 31 | _LOGGER = logging.getLogger(__name__) 32 | 33 | # Used to catch invalid subscriber id 34 | EXPECTED_SUBSCRIBER_REGEXP = re.compile("projects/.*/subscriptions/.*") 35 | 36 | MESSAGE_ACK_TIMEOUT_SECONDS = 30.0 37 | 38 | # Note: Users of non-prod instances will have to manually configure a topic 39 | TOPIC_FORMAT = "projects/sdm-prod/topics/enterprise-{project_id}" 40 | 41 | OAUTH2_AUTHORIZE_FORMAT = ( 42 | "https://nestservices.google.com/partnerconnections/{project_id}/auth" 43 | ) 44 | OAUTH2_TOKEN = "https://www.googleapis.com/oauth2/v4/token" 45 | SDM_SCOPES = [ 46 | "https://www.googleapis.com/auth/sdm.service", 47 | "https://www.googleapis.com/auth/pubsub", 48 | ] 49 | API_URL = "https://smartdevicemanagement.googleapis.com/v1" 50 | 51 | 52 | class ApiEnv(enum.Enum): 53 | PROD = (OAUTH2_AUTHORIZE_FORMAT, API_URL) 54 | PREPROD = ( 55 | 
"https://sdmresourcepicker-preprod.sandbox.google.com/partnerconnections/{project_id}/auth", 56 | "https://preprod-smartdevicemanagement.googleapis.com/v1", 57 | ) 58 | 59 | def __init__(self, authorize_url: str, api_url: str) -> None: 60 | """Init ApiEnv.""" 61 | self._authorize_url = authorize_url 62 | self._api_url = api_url 63 | 64 | @property 65 | def authorize_url_format(self) -> str: 66 | """OAuth Authorize url format string.""" 67 | return self._authorize_url 68 | 69 | @property 70 | def api_url(self) -> str: 71 | """API url.""" 72 | return self._api_url 73 | 74 | 75 | def get_api_env(env: str | None) -> ApiEnv: 76 | """Create an ApiEnv from a string.""" 77 | if env is None or env == "prod": 78 | return ApiEnv.PROD 79 | if env == "preprod": 80 | return ApiEnv.PREPROD 81 | raise ValueError("Invalid ApiEnv: %s" % env) 82 | 83 | 84 | def _validate_subscription_name(subscription_name: str) -> None: 85 | """Validates that a subscription name is correct. 86 | 87 | Raises ConfigurationException on failure. 88 | """ 89 | if not EXPECTED_SUBSCRIBER_REGEXP.match(subscription_name): 90 | DIAGNOSTICS.increment("subscription_name_invalid") 91 | _LOGGER.debug("Subscription name did not match pattern: %s", subscription_name) 92 | raise ConfigurationException( 93 | "Subscription misconfigured. Expected subscriber_id to " 94 | f"match '{EXPECTED_SUBSCRIBER_REGEXP.pattern}' but was " 95 | f"'{subscription_name}'" 96 | ) 97 | 98 | 99 | class GoogleNestSubscriber: 100 | """Subscribe to events from the Google Nest feed.""" 101 | 102 | def __init__( 103 | self, 104 | auth: AbstractAuth, 105 | project_id: str, 106 | subscription_name: str, 107 | ) -> None: 108 | """Initialize the subscriber for the specified topic.""" 109 | self._auth = auth 110 | self._subscription_name = subscription_name 111 | self._project_id = project_id 112 | self._api = GoogleNestAPI(auth, project_id) 113 | self._device_manager_task: asyncio.Task[DeviceManager] | None = None 114 | self._callback: Callable[[EventMessage], Awaitable[None]] | None = None 115 | self._cache_policy = CachePolicy() 116 | 117 | @property 118 | def subscription_name(self) -> str: 119 | """Return the configured subscriber name.""" 120 | return self._subscription_name 121 | 122 | @property 123 | def project_id(self) -> str: 124 | """Return the configured SDM project_id.""" 125 | return self._project_id 126 | 127 | def set_update_callback( 128 | self, target: Callable[[EventMessage], Awaitable[None]] 129 | ) -> None: 130 | """Register a callback invoked when new messages are received. 131 | 132 | If the event is associated with media, then the callback will only 133 | be invoked once the media has been fetched. 134 | """ 135 | self._callback = target 136 | if self._device_manager_task and self._device_manager_task.done(): 137 | self._device_manager_task.result().set_update_callback(target) 138 | 139 | async def start_async(self) -> Callable[[], None]: 140 | """Start the subscription. 141 | 142 | Returns a callable used to stop/cancel the subscription. Received 143 | messages are passed to the callback provided to `set_update_callback`. 
144 | """ 145 | _validate_subscription_name(self._subscription_name) 146 | _LOGGER.debug("Starting subscription %s", self._subscription_name) 147 | DIAGNOSTICS.increment("start") 148 | 149 | stream = StreamingManager( 150 | auth=self._auth, 151 | subscription_name=self._subscription_name, 152 | callback=self._async_message_callback_with_timeout, 153 | ) 154 | await stream.start() 155 | return stream.stop 156 | 157 | @property 158 | def cache_policy(self) -> CachePolicy: 159 | """Return cache policy shared by device EventMediaManager objects.""" 160 | return self._cache_policy 161 | 162 | async def async_get_device_manager(self) -> DeviceManager: 163 | """Return the DeviceManger with the current state of devices.""" 164 | if not self._device_manager_task: 165 | self._device_manager_task = asyncio.create_task( 166 | self._async_create_device_manager() 167 | ) 168 | return await self._device_manager_task 169 | 170 | async def _async_create_device_manager(self) -> DeviceManager: 171 | """Create a DeviceManager, populated with initial state.""" 172 | device_manager = DeviceManager(self._cache_policy) 173 | structures = await self._api.async_get_structures() 174 | for structure in structures: 175 | device_manager.add_structure(structure) 176 | # Subscriber starts after a device fetch 177 | devices = await self._api.async_get_devices() 178 | for device in devices: 179 | device_manager.add_device(device) 180 | if self._callback: 181 | device_manager.set_update_callback(self._callback) 182 | return device_manager 183 | 184 | async def _async_message_callback_with_timeout(self, message: Message) -> None: 185 | """Handle a received message.""" 186 | try: 187 | async with asyncio.timeout(MESSAGE_ACK_TIMEOUT_SECONDS): 188 | await self._async_message_callback(message) 189 | except TimeoutError as err: 190 | DIAGNOSTICS.increment("message_ack_timeout") 191 | raise TimeoutError("Message ack timeout processing message") from err 192 | 193 | async def _async_message_callback(self, message: Message) -> None: 194 | """Handle a received message.""" 195 | event = EventMessage.create_event(message.payload, self._auth) 196 | recv = time.time() 197 | latency_ms = int((recv - event.timestamp.timestamp()) * 1000) 198 | DIAGNOSTICS.elapsed("message_received", latency_ms) 199 | # Only accept device events once the Device Manager has been loaded. 200 | # We are ok with missing messages on startup since the device manager 201 | # will do a live read. This checks for an exception to avoid throwing 202 | # inside the pubsub callback and further wedging the pubsub client library. 
203 | if ( 204 | self._device_manager_task 205 | and self._device_manager_task.done() 206 | and not self._device_manager_task.exception() 207 | ): 208 | device_manager = self._device_manager_task.result() 209 | if _is_invalid_thermostat_trait_update(event): 210 | _LOGGER.debug( 211 | "Ignoring event with invalid update traits; Refreshing devices: %s", 212 | event.resource_update_traits, 213 | ) 214 | await _hack_refresh_devices(self._api, device_manager) 215 | else: 216 | await device_manager.async_handle_event(event) 217 | 218 | process_latency_ms = int((time.time() - recv) * 1000) 219 | DIAGNOSTICS.elapsed("message_processed", process_latency_ms) 220 | 221 | 222 | def _is_invalid_thermostat_trait_update(event: EventMessage) -> bool: 223 | """Return true if this is an invalid thermostat trait update.""" 224 | if ( 225 | event.resource_update_traits is not None 226 | and ( 227 | thermostat_mode := event.resource_update_traits.get( 228 | "sdm.devices.traits.ThermostatMode" 229 | ) 230 | ) 231 | and (available_modes := thermostat_mode.get("availableModes")) is not None 232 | and available_modes == ["OFF"] 233 | ): 234 | return True 235 | return False 236 | 237 | 238 | async def _hack_refresh_devices( 239 | api: GoogleNestAPI, device_manager: DeviceManager 240 | ) -> None: 241 | """Update the device manager with refreshed devices from the API.""" 242 | DIAGNOSTICS.increment("invalid-thermostat-update") 243 | try: 244 | devices = await api.async_get_devices() 245 | except ApiException: 246 | DIAGNOSTICS.increment("invalid-thermostat-update-refresh-failure") 247 | _LOGGER.debug("Failed to refresh devices after invalid message") 248 | else: 249 | DIAGNOSTICS.increment("invalid-thermostat-update-refresh-success") 250 | for device in devices: 251 | device_manager.add_device(device) 252 | -------------------------------------------------------------------------------- /google_nest_sdm/model.py: -------------------------------------------------------------------------------- 1 | """Base model for all nest trait based classes.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, fields 6 | from mashumaro import DataClassDictMixin 7 | from mashumaro.config import BaseConfig 8 | from typing import Any, Mapping, Self 9 | 10 | 11 | TRAITS = "traits" 12 | SDM_PREFIX = "sdm." 13 | 14 | 15 | @dataclass 16 | class TraitDataClass(DataClassDictMixin): 17 | """Base model for API objects that are trait based. 18 | 19 | This is meant to be subclasses by the model definitions. 
20 | """ 21 | 22 | @classmethod 23 | def parse_trait_object(cls, raw_data: Mapping[str, Any]) -> Self: 24 | """Parse a new dataclass""" 25 | return cls.from_dict( 26 | { 27 | **raw_data, 28 | **raw_data.get(TRAITS, {}), 29 | } 30 | ) 31 | 32 | @property 33 | def traits(self) -> dict[str, Any]: 34 | """Return a trait mixin on None.""" 35 | return { 36 | alias: value 37 | for field in fields(self) 38 | if (alias := field.metadata.get("alias")) is not None 39 | and (value := getattr(self, field.name)) is not None 40 | and alias.startswith(SDM_PREFIX) 41 | } 42 | 43 | @property 44 | def raw_data(self) -> dict[str, Any]: 45 | """Return raw data for the object.""" 46 | result: dict[str, Any] = {} 47 | for k, v in self.to_dict(by_alias=True, omit_none=True).items(): 48 | if k.startswith(SDM_PREFIX): 49 | if "traits" not in result: 50 | result["traits"] = {} 51 | result["traits"][k] = v 52 | else: 53 | result[k] = v 54 | return result 55 | 56 | class Config(BaseConfig): 57 | code_generation_options = [ 58 | "TO_DICT_ADD_BY_ALIAS_FLAG", 59 | "TO_DICT_ADD_OMIT_NONE_FLAG", 60 | ] 61 | serialize_by_alias = True 62 | -------------------------------------------------------------------------------- /google_nest_sdm/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenporter/python-google-nest-sdm/0ed9c93b1a77cabca578b6ff73c2e313e1bd43ef/google_nest_sdm/py.typed -------------------------------------------------------------------------------- /google_nest_sdm/registry.py: -------------------------------------------------------------------------------- 1 | """Decorator for creating a registry of objects.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Callable, TypeVar, Any 6 | 7 | CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name 8 | 9 | 10 | class Registry(dict[str, Any]): 11 | """Registry of items.""" 12 | 13 | def register(self, name: str | None = None) -> Callable[[CALLABLE_T], CALLABLE_T]: 14 | """Return decorator to register item with a specific name.""" 15 | 16 | def decorator(func: CALLABLE_T) -> CALLABLE_T: 17 | """Register decorated function.""" 18 | nonlocal name 19 | if name is None: 20 | name = func.NAME # type: ignore 21 | self[name] = func 22 | return func 23 | 24 | return decorator 25 | -------------------------------------------------------------------------------- /google_nest_sdm/streaming_manager.py: -------------------------------------------------------------------------------- 1 | """Subscriber for the Smart Device Management event based API.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import datetime 7 | import logging 8 | import json 9 | from typing import Awaitable, Callable, AsyncIterable, Any, TYPE_CHECKING 10 | 11 | from google import pubsub_v1 12 | 13 | from .auth import AbstractAuth 14 | from .diagnostics import STREAMING_MANAGER_DIAGNOSTICS as DIAGNOSTICS 15 | from .exceptions import ( 16 | GoogleNestException, 17 | ) 18 | from .subscriber_client import SubscriberClient 19 | 20 | _LOGGER = logging.getLogger(__name__) 21 | 22 | MESSAGE_ACK_TIMEOUT_SECONDS = 30.0 23 | 24 | NEW_SUBSCRIBER_TIMEOUT_SECONDS = 30.0 25 | 26 | MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10) 27 | MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=10) 28 | BACKOFF_MULTIPLIER = 1.5 29 | 30 | 31 | class Message: 32 | """A message from the Pub/Sub stream.""" 33 | 34 | def __init__(self, message: pubsub_v1.types.PubsubMessage) -> None: 35 | 
"""Initialize the message.""" 36 | self._message = message 37 | self._payload: dict[str, Any] | None = None 38 | 39 | @property 40 | def payload(self) -> dict[str, Any]: 41 | """Get the payload of the message.""" 42 | if self._payload is None: 43 | self._payload = json.loads(bytes.decode(self._message.data)) or {} 44 | return self._payload 45 | 46 | @classmethod 47 | def from_data(cls, data: dict[str, Any]) -> Message: 48 | """Create a message from an object for testing.""" 49 | return cls(encode_pubsub_message(data)) 50 | 51 | 52 | def encode_pubsub_message(data: dict[str, Any]) -> pubsub_v1.types.PubsubMessage: 53 | """Encode a message for Pub/Sub.""" 54 | return pubsub_v1.types.PubsubMessage(data=bytes(json.dumps(data), "utf-8")) 55 | 56 | 57 | class StreamingManager: 58 | """Client for the Google Nest subscriber.""" 59 | 60 | def __init__( 61 | self, 62 | auth: AbstractAuth, 63 | subscription_name: str, 64 | callback: Callable[[Message], Awaitable[None]], 65 | ) -> None: 66 | """Initialize the client.""" 67 | self._subscription_name = subscription_name 68 | self._callback = callback 69 | self._background_task: asyncio.Task | None = None 70 | self._auth = auth 71 | self._subscriber_client: SubscriberClient | None = None 72 | self._stream: AsyncIterable[pubsub_v1.types.StreamingPullResponse] | None = None 73 | self._ack_ids: list[str] = [] 74 | self._healthy = False 75 | self._backoff = MIN_BACKOFF_INTERVAL 76 | 77 | async def start(self) -> None: 78 | """Start the subscription background task and wait for initial startup.""" 79 | DIAGNOSTICS.increment("start") 80 | self._stream = await self._connect() 81 | self._healthy = True 82 | loop = asyncio.get_event_loop() 83 | self._background_task = loop.create_task(self._run_task()) 84 | 85 | @property 86 | def healthy(self) -> bool: 87 | """Return True if the subscription is healthy.""" 88 | return self._healthy 89 | 90 | def stop(self) -> None: 91 | _LOGGER.debug("Stopping subscription %s", self._subscription_name) 92 | DIAGNOSTICS.increment("stop") 93 | if self._background_task: 94 | self._background_task.cancel() 95 | self._healthy = False 96 | self._subscriber_client = None 97 | 98 | async def _run_task(self) -> None: 99 | """""" 100 | try: 101 | await self._run() 102 | except asyncio.CancelledError: 103 | _LOGGER.debug("Subscription loop cancelled") 104 | except Exception as err: 105 | _LOGGER.info("Uncaught error in subscription loop: %s", err) 106 | DIAGNOSTICS.increment("uncaught_exception") 107 | self._healthy = False 108 | 109 | async def _run(self) -> None: 110 | """Run the subscription loop.""" 111 | DIAGNOSTICS.increment("run") 112 | while True: 113 | if TYPE_CHECKING: 114 | assert self._stream is not None 115 | self._healthy = True 116 | _LOGGER.debug("Event stream connection established") 117 | try: 118 | async for response in self._stream: 119 | _LOGGER.debug( 120 | "Received %s messages", len(response.received_messages) 121 | ) 122 | # Reset backoff anytime we receive messages 123 | self._backoff = MIN_BACKOFF_INTERVAL 124 | for received_message in response.received_messages: 125 | if await self._process_message(received_message.message): 126 | self._ack_ids.append(received_message.ack_id) 127 | except GoogleNestException as err: 128 | _LOGGER.debug("Disconnected from event stream: %s", err) 129 | DIAGNOSTICS.increment("exception") 130 | self._healthy = False 131 | self._subscriber_client = None 132 | 133 | while True: 134 | _LOGGER.debug( 135 | "Reconnecting stream in %s seconds", self._backoff.total_seconds() 136 | ) 137 | 
await asyncio.sleep(self._backoff.total_seconds()) 138 | try: 139 | self._stream = await self._connect() 140 | break 141 | except GoogleNestException as err: 142 | _LOGGER.debug("Error connecting to event stream: %s", err) 143 | self._backoff = min( 144 | self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL 145 | ) 146 | DIAGNOSTICS.increment("backoff") 147 | 148 | async def _connect(self) -> AsyncIterable[pubsub_v1.types.StreamingPullResponse]: 149 | """Connect to the streaming pull.""" 150 | _LOGGER.debug("Connecting with streaming pull") 151 | DIAGNOSTICS.increment("connect") 152 | self._subscriber_client = SubscriberClient(self._auth, self._subscription_name) 153 | return await self._subscriber_client.streaming_pull(self.pending_ack_ids) 154 | 155 | def pending_ack_ids(self) -> list[str]: 156 | """Generate the ack IDs for the next streaming pull request and clear.""" 157 | ack_ids = [*self._ack_ids] 158 | self._ack_ids = [] 159 | return ack_ids 160 | 161 | async def _process_message(self, message: pubsub_v1.types.PubsubMessage) -> bool: 162 | """Process an incoming message from the stream.""" 163 | DIAGNOSTICS.increment("process_message") 164 | try: 165 | async with asyncio.timeout(MESSAGE_ACK_TIMEOUT_SECONDS): 166 | await self._callback(Message(message)) 167 | return True 168 | except TimeoutError as err: 169 | DIAGNOSTICS.increment("process_message_timeout") 170 | _LOGGER.info("Unexpected timeout while processing message: %s", err) 171 | return False 172 | except Exception as err: 173 | DIAGNOSTICS.increment("process_message_exception") 174 | _LOGGER.info("Uncaught error while processing message: %s", err) 175 | return False 176 | -------------------------------------------------------------------------------- /google_nest_sdm/structure.py: -------------------------------------------------------------------------------- 1 | """Traits for structures / rooms.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Any, Mapping 7 | from mashumaro import field_options 8 | 9 | from .model import TraitDataClass 10 | 11 | 12 | @dataclass 13 | class InfoTrait: 14 | """This trait belongs to any structure for structure-related information.""" 15 | 16 | custom_name: str | None = field( 17 | metadata=field_options(alias="customName"), default=None 18 | ) 19 | """Name of the structure.""" 20 | 21 | 22 | @dataclass 23 | class RoomInfoTrait: 24 | """This trait belongs to any structure for room-related information.""" 25 | 26 | custom_name: str = field(metadata=field_options(alias="customName")) 27 | """Name of the structure.""" 28 | 29 | 30 | @dataclass 31 | class Structure(TraitDataClass): 32 | """Class that represents a structure object in the Google Nest SDM API.""" 33 | 34 | name: str 35 | """Resource name of the structure e.g. 
'enterprises/XYZ/structures/123'.""" 36 | 37 | info: InfoTrait | None = field( 38 | metadata=field_options(alias="sdm.structures.traits.Info"), default=None 39 | ) 40 | room_info: RoomInfoTrait | None = field( 41 | metadata=field_options(alias="sdm.structures.traits.RoomInfo"), default=None 42 | ) 43 | 44 | @classmethod 45 | def MakeStructure(cls, raw_data: Mapping[str, Any]) -> Structure: 46 | """Create a structure with the appropriate traits.""" 47 | return cls.parse_trait_object(raw_data) 48 | -------------------------------------------------------------------------------- /google_nest_sdm/subscriber_client.py: -------------------------------------------------------------------------------- 1 | """Pub/sub subscriber client library.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import logging 7 | from typing import Awaitable, Callable, AsyncIterable, Any, TypeVar 8 | from collections.abc import AsyncGenerator 9 | 10 | from aiohttp.client_exceptions import ClientError 11 | from google.api_core.exceptions import GoogleAPIError, NotFound, Unauthenticated 12 | from google.auth.exceptions import RefreshError, GoogleAuthError, TransportError 13 | from google.auth.transport.requests import Request 14 | from google import pubsub_v1 15 | from google.oauth2.credentials import Credentials 16 | 17 | from .auth import AbstractAuth 18 | from .diagnostics import SUBSCRIBER_DIAGNOSTICS as DIAGNOSTICS 19 | from .exceptions import ( 20 | AuthException, 21 | ConfigurationException, 22 | SubscriberException, 23 | ) 24 | 25 | _LOGGER = logging.getLogger(__name__) 26 | 27 | _T = TypeVar("_T") 28 | 29 | RPC_TIMEOUT_SECONDS = 30.0 30 | STREAMING_PULL_TIMEOUT_SECONDS = 55.0 31 | STREAM_ACK_TIMEOUT_SECONDS = 180 32 | STREAM_ACK_FREQUENCY_SECONDS = 90 33 | 34 | 35 | def refresh_creds(creds: Credentials) -> Credentials: 36 | """Refresh credentials. 37 | 38 | This is not part of the subscriber API, exposed only to facilitate testing. 
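
    Illustrative usage sketch: creds = refresh_creds(creds). A RefreshError is
    re-raised as AuthException, and transport or other auth errors are
    re-raised as SubscriberException, as handled below.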
39 | """ 40 | try: 41 | creds.refresh(Request()) 42 | except RefreshError as err: 43 | raise AuthException(f"Authentication refresh failure: {err}") from err 44 | except TransportError as err: 45 | raise SubscriberException( 46 | f"Connectivity error during authentication refresh: {err}" 47 | ) from err 48 | except GoogleAuthError as err: 49 | raise SubscriberException( 50 | f"Error during authentication refresh: {err}" 51 | ) from err 52 | return creds 53 | 54 | 55 | def exception_handler[_T: Any]( 56 | func_name: str, 57 | ) -> Callable[..., Callable[..., Awaitable[_T]]]: 58 | """Wrap a function with exception handling.""" 59 | 60 | def wrapped(func: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]: 61 | async def wrapped_func(*args: Any, **kwargs: Any) -> _T: 62 | try: 63 | return await func(*args, **kwargs) 64 | except NotFound as err: 65 | _LOGGER.debug("NotFound error in %s: %s", func_name, err) 66 | DIAGNOSTICS.increment(f"{func_name}.not_found_error") 67 | raise ConfigurationException( 68 | f"NotFound error calling {func_name}: {err}" 69 | ) from err 70 | except Unauthenticated as err: 71 | _LOGGER.debug( 72 | "Failed to authenticate subscriber in %s: %s", func_name, err 73 | ) 74 | DIAGNOSTICS.increment(f"{func_name}.unauthenticated") 75 | raise AuthException( 76 | f"Failed to authenticate {func_name}: {err}" 77 | ) from err 78 | except GoogleAPIError as err: 79 | _LOGGER.debug("API error in %s: %s", func_name, err) 80 | DIAGNOSTICS.increment(f"{func_name}.api_error") 81 | raise SubscriberException( 82 | f"API error when calling {func_name}: {err}" 83 | ) from err 84 | except Exception as err: 85 | _LOGGER.debug("Uncaught error in %s: %s", func_name, err) 86 | DIAGNOSTICS.increment(f"{func_name}.api_error") 87 | raise SubscriberException( 88 | f"Unexpected error when calling {func_name}: {err}" 89 | ) from err 90 | 91 | return wrapped_func 92 | 93 | return wrapped 94 | 95 | 96 | async def pull_request_generator( 97 | subscription_name: str, 98 | ack_ids_generator: Callable[[], list[str]], 99 | ) -> AsyncGenerator[pubsub_v1.StreamingPullRequest, list[str]]: 100 | yield pubsub_v1.StreamingPullRequest( 101 | subscription=subscription_name, 102 | stream_ack_deadline_seconds=STREAM_ACK_TIMEOUT_SECONDS, 103 | ) 104 | while True: 105 | ids = ack_ids_generator() 106 | _LOGGER.debug("Sending streaming pull request (acking %s messages)", len(ids)) 107 | yield pubsub_v1.StreamingPullRequest( 108 | stream_ack_deadline_seconds=STREAM_ACK_TIMEOUT_SECONDS, 109 | ack_ids=ids, 110 | ) 111 | await asyncio.sleep(STREAM_ACK_FREQUENCY_SECONDS) 112 | 113 | 114 | async def aiter_exception_handler(iterable: AsyncIterable[_T]) -> AsyncIterable[_T]: 115 | """Wrap an async iterable with pub/sub exception handling.""" 116 | _LOGGER.debug("Starting streaming iterator") 117 | 118 | try: 119 | async for item in iterable: 120 | yield item 121 | except NotFound as err: 122 | _LOGGER.debug("NotFound error in streaming pull: %s", err) 123 | DIAGNOSTICS.increment("streaming_iterator.not_found_error") 124 | raise ConfigurationException( 125 | f"NotFound error calling streaming iterator: {err}" 126 | ) from err 127 | except Unauthenticated as err: 128 | _LOGGER.debug("Failed to authenticate subscriber in streaming pull: %s", err) 129 | DIAGNOSTICS.increment("streaming_iterator.unauthenticated") 130 | raise AuthException( 131 | f"Failed to authenticate in streaming iterator: {err}" 132 | ) from err 133 | except GoogleAPIError as err: 134 | _LOGGER.debug("API error in streaming pull: %s", err) 135 | 
DIAGNOSTICS.increment("streaming_iterator.api_error") 136 | raise SubscriberException(f"API error when streaming iterator: {err}") from err 137 | except Exception as err: 138 | _LOGGER.debug("Uncaught error in streaming pull: %s", err) 139 | DIAGNOSTICS.increment("streaming_iterator.api_error") 140 | raise SubscriberException( 141 | f"Unexpected error when streaming iterator: {err}" 142 | ) from err 143 | 144 | 145 | class SubscriberClient: 146 | """Pub/sub subscriber client library.""" 147 | 148 | def __init__( 149 | self, 150 | auth: AbstractAuth, 151 | subscription_name: str, 152 | ) -> None: 153 | """Initialize the SubscriberClient.""" 154 | self._auth = auth 155 | self._subscription_name = subscription_name 156 | self._client: pubsub_v1.SubscriberAsyncClient | None = None 157 | self._creds: Credentials | None = None 158 | 159 | async def _async_get_client(self) -> pubsub_v1.SubscriberAsyncClient: 160 | """Create the Pub/Sub client library.""" 161 | if self._client is None or self._creds is None or self._creds.expired: 162 | try: 163 | creds = await self._auth.async_get_creds() 164 | except ClientError as err: 165 | DIAGNOSTICS.increment("create_subscription.creds_error") 166 | raise AuthException(f"Access token failure: {err}") from err 167 | _LOGGER.debug("Credentials refreshed, new expiry %s", creds.expiry) 168 | self._creds = creds 169 | self._client = pubsub_v1.SubscriberAsyncClient(credentials=self._creds) 170 | return self._client 171 | 172 | @exception_handler("streaming_pull") 173 | async def streaming_pull( 174 | self, 175 | ack_ids_generator: Callable[[], list[str]], 176 | ) -> AsyncIterable[pubsub_v1.types.StreamingPullResponse]: 177 | """Start the streaming pull.""" 178 | client = await self._async_get_client() 179 | req_gen = pull_request_generator(self._subscription_name, ack_ids_generator) 180 | _LOGGER.debug("Sending streaming pull request for %s", self._subscription_name) 181 | try: 182 | async with asyncio.timeout(STREAMING_PULL_TIMEOUT_SECONDS): 183 | stream: AsyncIterable[pubsub_v1.types.StreamingPullResponse] = ( 184 | await client.streaming_pull(requests=req_gen) 185 | ) 186 | except asyncio.TimeoutError as err: 187 | _LOGGER.debug("Timeout in streaming_pull %s", err) 188 | DIAGNOSTICS.increment("streaming_pull.timeout") 189 | raise SubscriberException("Timeout in streaming_pull") from err 190 | _LOGGER.debug("Streaming pull started") 191 | return aiter_exception_handler(stream) 192 | 193 | @exception_handler("acknowledge") 194 | async def ack_messages(self, ack_ids: list[str]) -> None: 195 | """Acknowledge messages.""" 196 | if not ack_ids: 197 | return 198 | client = await self._async_get_client() 199 | _LOGGER.debug("Acking %s messages", len(ack_ids)) 200 | try: 201 | async with asyncio.timeout(RPC_TIMEOUT_SECONDS): 202 | await client.acknowledge( 203 | subscription=self._subscription_name, 204 | ack_ids=ack_ids, 205 | ) 206 | except asyncio.TimeoutError as err: 207 | _LOGGER.debug("Timeout in acknowledge: %s", err) 208 | DIAGNOSTICS.increment("acknowledge.timeout") 209 | raise SubscriberException("Timeout in acknowledge") from err 210 | -------------------------------------------------------------------------------- /google_nest_sdm/thermostat_traits.py: -------------------------------------------------------------------------------- 1 | """Traits for thermostats.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Final, ClassVar 7 | 8 | import aiohttp 9 | from mashumaro import field_options, 
DataClassDictMixin
10 | 
11 | from .traits import CommandDataClass, TraitType
12 | 
13 | __all__ = [
14 |     "ThermostatEcoTrait",
15 |     "ThermostatHvacTrait",
16 |     "ThermostatModeTrait",
17 |     "ThermostatTemperatureSetpointTrait",
18 | ]
19 | 
20 | STATUS: Final = "status"
21 | AVAILABLE_MODES: Final = "availableModes"
22 | MODE: Final = "mode"
23 | 
24 | 
25 | @dataclass
26 | class ThermostatEcoTrait(DataClassDictMixin, CommandDataClass):
27 |     """This trait belongs to any device that supports Eco mode."""
28 | 
29 |     NAME: ClassVar[TraitType] = TraitType.THERMOSTAT_ECO
30 | 
31 |     available_modes: list[str] = field(
32 |         metadata=field_options(alias="availableModes"), default_factory=list
33 |     )
34 |     """List of supported Eco modes."""
35 | 
36 |     mode: str = field(metadata=field_options(alias="mode"), default="OFF")
37 |     """Eco mode of the thermostat."""
38 | 
39 |     heat_celsius: float | None = field(
40 |         metadata=field_options(alias="heatCelsius"), default=None
41 |     )
42 |     """Lowest temperature where thermostat begins heating."""
43 | 
44 |     cool_celsius: float | None = field(
45 |         metadata=field_options(alias="coolCelsius"), default=None
46 |     )
47 |     """Highest cooling temperature where thermostat begins cooling."""
48 | 
49 |     async def set_mode(self, mode: str) -> aiohttp.ClientResponse:
50 |         """Change the thermostat Eco mode."""
51 |         data = {
52 |             "command": "sdm.devices.commands.ThermostatEco.SetMode",
53 |             "params": {"mode": mode},
54 |         }
55 |         return await self.cmd.execute(data)
56 | 
57 | 
58 | @dataclass
59 | class ThermostatHvacTrait:
60 |     """This trait belongs to devices that can report HVAC details."""
61 | 
62 |     NAME: ClassVar[TraitType] = TraitType.THERMOSTAT_HVAC
63 | 
64 |     status: str
65 |     """HVAC status of the thermostat."""
66 | 
67 | 
68 | @dataclass
69 | class ThermostatModeTrait(DataClassDictMixin, CommandDataClass):
70 |     """This trait belongs to devices that support different thermostat modes."""
71 | 
72 |     NAME: ClassVar[TraitType] = TraitType.THERMOSTAT_MODE
73 | 
74 |     available_modes: list[str] = field(metadata=field_options(alias="availableModes"))
75 |     """List of supported thermostat modes."""
76 | 
77 |     mode: str = field(metadata=field_options(alias="mode"))
78 |     """Mode of the thermostat."""
79 | 
80 |     async def set_mode(self, mode: str) -> aiohttp.ClientResponse:
81 |         """Change the thermostat mode."""
82 |         data = {
83 |             "command": "sdm.devices.commands.ThermostatMode.SetMode",
84 |             "params": {"mode": mode},
85 |         }
86 |         return await self.cmd.execute(data)
87 | 
88 | 
89 | @dataclass
90 | class ThermostatTemperatureSetpointTrait(DataClassDictMixin, CommandDataClass):
91 |     """This trait belongs to devices that support setting target temperature."""
92 | 
93 |     NAME: ClassVar[TraitType] = TraitType.THERMOSTAT_TEMPERATURE_SETPOINT
94 | 
95 |     heat_celsius: float | None = field(
96 |         metadata=field_options(alias="heatCelsius"), default=None
97 |     )
98 |     """Lowest temperature where thermostat begins heating."""
99 | 
100 |     cool_celsius: float | None = field(
101 |         metadata=field_options(alias="coolCelsius"), default=None
102 |     )
103 |     """Highest cooling temperature where thermostat begins cooling."""
104 | 
105 |     async def set_heat(self, heat: float) -> aiohttp.ClientResponse:
106 |         """Set the target heating temperature."""
107 |         data = {
108 |             "command": "sdm.devices.commands.ThermostatTemperatureSetpoint.SetHeat",
109 |             "params": {"heatCelsius": heat},
110 |         }
111 |         return await self.cmd.execute(data)
112 | 
113 |     async def set_cool(self, cool: float) -> aiohttp.ClientResponse:
114 |         """Set the target cooling temperature."""
115 |         data = {
116 |             "command": "sdm.devices.commands.ThermostatTemperatureSetpoint.SetCool",
117 |             "params": {"coolCelsius": cool},
118 |         }
119 |         return await self.cmd.execute(data)
120 | 
121 |     async def set_range(self, heat: float, cool: float) -> aiohttp.ClientResponse:
122 |         """Set the target temperature range."""
123 |         data = {
124 |             "command": "sdm.devices.commands.ThermostatTemperatureSetpoint.SetRange",
125 |             "params": {
126 |                 "heatCelsius": heat,
127 |                 "coolCelsius": cool,
128 |             },
129 |         }
130 |         return await self.cmd.execute(data)
131 | 
-------------------------------------------------------------------------------- /google_nest_sdm/traits.py: --------------------------------------------------------------------------------
1 | """Base library for all traits."""
2 | 
3 | from __future__ import annotations
4 | 
5 | from abc import ABC
6 | from enum import StrEnum
7 | from typing import Any, Mapping
8 | 
9 | import aiohttp
10 | from mashumaro.types import SerializableType
11 | 
12 | from .auth import AbstractAuth
13 | from .diagnostics import Diagnostics
14 | 
15 | DEVICE_TRAITS = "traits"
16 | TRAITS = "traits"
17 | 
18 | 
19 | class TraitType(StrEnum):
20 |     """Traits for SDM devices."""
21 | 
22 |     CAMERA_IMAGE = "sdm.devices.traits.CameraImage"
23 |     CAMERA_LIVE_STREAM = "sdm.devices.traits.CameraLiveStream"
24 |     CAMERA_EVENT_IMAGE = "sdm.devices.traits.CameraEventImage"
25 |     CAMERA_MOTION = "sdm.devices.traits.CameraMotion"
26 |     CAMERA_PERSON = "sdm.devices.traits.CameraPerson"
27 |     CAMERA_SOUND = "sdm.devices.traits.CameraSound"
28 |     CAMERA_CLIP_PREVIEW = "sdm.devices.traits.CameraClipPreview"
29 |     CONNECTIVITY = "sdm.devices.traits.Connectivity"
30 |     FAN = "sdm.devices.traits.Fan"
31 |     INFO = "sdm.devices.traits.Info"
32 |     HUMIDITY = "sdm.devices.traits.Humidity"
33 |     TEMPERATURE = "sdm.devices.traits.Temperature"
34 |     DOORBELL_CHIME = "sdm.devices.traits.DoorbellChime"
35 |     THERMOSTAT_ECO = "sdm.devices.traits.ThermostatEco"
36 |     THERMOSTAT_HVAC = "sdm.devices.traits.ThermostatHvac"
37 |     THERMOSTAT_MODE = "sdm.devices.traits.ThermostatMode"
38 |     THERMOSTAT_TEMPERATURE_SETPOINT = "sdm.devices.traits.ThermostatTemperatureSetpoint"
39 | 
40 | 
41 | class Command(SerializableType):
42 |     """Base class for executing commands."""
43 | 
44 |     def __init__(self, device_id: str, auth: AbstractAuth, diagnostics: Diagnostics):
45 |         """Initialize Command."""
46 |         self._device_id = device_id
47 |         self._auth = auth
48 |         self._diagnostics = diagnostics
49 | 
50 |     async def execute(self, data: Mapping[str, Any]) -> aiohttp.ClientResponse:
51 |         """Run the command."""
52 |         assert self._auth
53 |         cmd = data.get("command", "execute")
54 |         with self._diagnostics.timer(cmd):
55 |             return await self._auth.post(f"{self._device_id}:executeCommand", json=data)
56 | 
57 |     async def execute_json(self, data: Mapping[str, Any]) -> dict[str, Any]:
58 |         """Run the command and return a json result."""
59 |         assert self._auth
60 |         cmd = data.get("command", "execute")
61 |         with self._diagnostics.timer(cmd):
62 |             return await self._auth.post_json(
63 |                 f"{self._device_id}:executeCommand", json=data
64 |             )
65 | 
66 |     async def fetch_image(self, url: str, basic_auth: str | None = None) -> bytes:
67 |         """Fetch an image at the specified url."""
68 |         headers: dict[str, Any] = {}
69 |         if basic_auth:
70 |             headers = {"Authorization": f"Basic {basic_auth}"}
71 |         with self._diagnostics.timer("fetch_image"):
72 |             resp = await self._auth.get(url, headers=headers)
73 |             return await resp.read()
74 | 
75 | 
76 | 
class CommandDataClass(ABC): 77 | """Base model that supports commands.""" 78 | 79 | def __post_init__(self) -> None: 80 | self._cmd: Command | None = None 81 | 82 | @property 83 | def cmd(self) -> Command: 84 | """Helper for executing commands, used internally by the trait""" 85 | if not self._cmd: 86 | raise ValueError("Device trait in invalid state") 87 | return self._cmd 88 | -------------------------------------------------------------------------------- /google_nest_sdm/transcoder.py: -------------------------------------------------------------------------------- 1 | """Library for transcoding mp4 clips.""" 2 | 3 | import asyncio.subprocess 4 | import logging 5 | import os 6 | 7 | from .exceptions import TranscodeException 8 | 9 | _LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | class Transcoder: 13 | """A worker that processes mp4 clips.""" 14 | 15 | def __init__(self, ffmpeg_binary: str, path_prefix: str) -> None: 16 | """Initialize transcoder.""" 17 | self._ffmpeg_binary = ffmpeg_binary 18 | self._path_prefix = path_prefix 19 | 20 | async def transcode_clip(self, input_file: str, output_file: str) -> None: 21 | """Create a image preview for a thumbnail clip.""" 22 | full_input_file = f"{self._path_prefix}/{input_file}" 23 | full_output_file = f"{self._path_prefix}/{output_file}" 24 | if not os.path.exists(full_input_file): 25 | raise TranscodeException(f"Input file does not exist: {full_input_file}") 26 | if os.path.exists(full_output_file): 27 | raise TranscodeException(f"Output file already exists: {full_output_file}") 28 | cmd = " ".join( 29 | [ 30 | self._ffmpeg_binary, 31 | "-y", 32 | "-i", 33 | full_input_file, 34 | "-vf setpts=2.0*PTS", 35 | "-vf scale=320:-1,setsar=1:1", 36 | "-r 4", 37 | "-loop 0", 38 | full_output_file, 39 | ] 40 | ) 41 | proc = await asyncio.create_subprocess_shell(cmd) 42 | stdout, stderr = await proc.communicate() 43 | if proc.returncode != 0: 44 | if stdout: 45 | _LOGGER.debug(stdout) 46 | if stderr: 47 | _LOGGER.debug(stderr) 48 | raise TranscodeException( 49 | f"Transcode command failure: {cmd} code: {proc.returncode}" 50 | ) -------------------------------------------------------------------------------- /google_nest_sdm/webrtc_util.py: -------------------------------------------------------------------------------- 1 | """Library with functions for manipulating WebRTC requests/responses.""" 2 | 3 | from enum import StrEnum 4 | 5 | 6 | class SDPDirection(StrEnum): 7 | """SDP direction constants.""" 8 | 9 | SENDRECV = "sendrecv" 10 | SENDONLY = "sendonly" 11 | RECVONLY = "recvonly" 12 | INACTIVE = "inactive" 13 | 14 | 15 | class SDPMediaKind(StrEnum): 16 | """SDP media kind constants.""" 17 | 18 | AUDIO = "audio" 19 | VIDEO = "video" 20 | APPLICATION = "application" 21 | 22 | 23 | def _get_media_direction(sdp: str, kind: SDPMediaKind) -> SDPDirection | None: 24 | """Retrieves the direction of media tracks from the SDP based on the kind (audio/video).""" 25 | 26 | # Track if we are in the desired media section 27 | in_media_section = False 28 | 29 | for line in sdp.split("\r\n"): 30 | # Check if the line is a media description line 31 | if line.startswith("m="): 32 | in_media_section = line.startswith(f"m={kind}") 33 | # If we're in the desired media section, check for direction 34 | if in_media_section and line.startswith("a="): 35 | for direction in SDPDirection: 36 | if line.startswith(f"a={direction}"): 37 | return direction 38 | return None 39 | 40 | 41 | def _update_direction_in_answer( 42 | answer_sdp: str, 43 | kind: SDPMediaKind, 44 | 
old_direction: SDPDirection,
45 |     new_direction: SDPDirection,
46 | ) -> str:
47 |     """Updates the direction of a specific media track in the SDP answer if it matches a certain direction."""
48 | 
49 |     # Update the SDP
50 |     updated_sdp_lines = []
51 |     in_media_section = False
52 |     for line in answer_sdp.split("\r\n"):
53 |         if line.startswith("m="):
54 |             in_media_section = line.startswith(f"m={kind}")
55 |         if in_media_section and line.startswith("a="):
56 |             # Update the direction line if it matches the kind
57 |             if line.startswith(f"a={old_direction}"):
58 |                 updated_sdp_lines.append(
59 |                     line.replace(f"a={old_direction}", f"a={new_direction}")
60 |                 )
61 |                 continue
62 |         updated_sdp_lines.append(line)
63 |     return "\r\n".join(updated_sdp_lines)
64 | 
65 | 
66 | def _add_foundation_to_candidates(sdp: str) -> str:
67 |     """Adds a foundation value to all ICE candidates in the SDP if it does not already exist."""
68 | 
69 |     updated_sdp_lines = []
70 |     index = 1
71 |     for line in sdp.split("\r\n"):
72 |         if line.startswith("a=candidate: "):
73 |             updated_sdp_lines.append(
74 |                 line.replace("a=candidate: ", f"a=candidate:{index} ")
75 |             )
76 |             index += 1
77 |             continue
78 |         updated_sdp_lines.append(line)
79 |     return "\r\n".join(updated_sdp_lines)
80 | 
81 | 
82 | def fix_mozilla_sdp_answer(offer_sdp: str, answer_sdp: str) -> str:
83 |     """Fix the answer SDP which is rejected by Firefox.
84 | 
85 |     1. If the offer SDP is recvonly, the direction of the answer SDP must not be sendrecv.
86 |     2. The ICE candidates in the answer SDP must contain a "foundation" field.
87 |     """
88 |     if "mozilla" in offer_sdp:
89 |         if (
90 |             _get_media_direction(sdp=offer_sdp, kind=SDPMediaKind.VIDEO)
91 |             == SDPDirection.RECVONLY
92 |         ):
93 |             answer_sdp = _update_direction_in_answer(
94 |                 answer_sdp=answer_sdp,
95 |                 kind=SDPMediaKind.VIDEO,
96 |                 old_direction=SDPDirection.SENDRECV,
97 |                 new_direction=SDPDirection.SENDONLY,
98 |             )
99 |         if (
100 |             _get_media_direction(sdp=offer_sdp, kind=SDPMediaKind.AUDIO)
101 |             == SDPDirection.RECVONLY
102 |         ):
103 |             answer_sdp = _update_direction_in_answer(
104 |                 answer_sdp=answer_sdp,
105 |                 kind=SDPMediaKind.AUDIO,
106 |                 old_direction=SDPDirection.SENDRECV,
107 |                 new_direction=SDPDirection.SENDONLY,
108 |             )
109 |         return _add_foundation_to_candidates(answer_sdp)
110 |     return answer_sdp
111 | 
-------------------------------------------------------------------------------- /mypy.ini: --------------------------------------------------------------------------------
1 | [mypy]
2 | follow_imports = silent
3 | ignore_missing_imports = True
4 | exclude = (venv|build)
5 | check_untyped_defs = True
6 | disallow_incomplete_defs = True
7 | disallow_subclassing_any = True
8 | disallow_untyped_calls = True
9 | disallow_untyped_decorators = True
10 | disallow_untyped_defs = True
11 | no_implicit_optional = True
12 | warn_return_any = True
13 | warn_unreachable = True
14 | warn_redundant_casts = True
15 | warn_no_return = True
16 | show_error_codes = True
17 | # For temporarily dealing with mypy updates (#151 and #152)
18 | warn_unused_ignores = False
19 | 
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [tool.mypy]
2 | exclude = [
3 |     "setup.py",
4 |     "venv/",
5 | ]
6 | platform = "linux"
7 | show_error_codes = true
8 | follow_imports = "normal"
9 | local_partial_types = true
10 | strict_equality = true
11 | no_implicit_optional = true
12 | warn_incomplete_stub = true
13 | warn_redundant_casts =
true 14 | warn_unused_configs = true 15 | warn_unused_ignores = true 16 | disable_error_code = [ 17 | "import-untyped", 18 | ] 19 | extra_checks = false 20 | check_untyped_defs = true 21 | disallow_incomplete_defs = true 22 | disallow_subclassing_any = true 23 | disallow_untyped_calls = true 24 | disallow_untyped_decorators = true 25 | disallow_untyped_defs = true 26 | warn_return_any = true 27 | warn_unreachable = true 28 | 29 | ignore_missing_imports = true 30 | warn_no_return = true 31 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode = auto 3 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | black==25.1.0 3 | coverage==7.8.2 4 | mypy==1.16.0 5 | pdoc==15.0.4 6 | pip==25.1.1 7 | pre-commit==4.2.0 8 | pytest-cov==6.1.1 9 | pytest==8.4.0 10 | ruff==0.11.13 11 | 12 | 13 | aiohttp==3.12.9 14 | async-timeout==5.0.1 15 | google-api-core==2.25.0 16 | google-auth==2.29.0 17 | google-cloud-pubsub==2.29.1 18 | googleapis-common-protos==1.70.0 19 | grpcio==1.72.1 20 | grpcio-status==1.71.0 21 | protobuf==5.29.5 22 | pytest-aiohttp==1.1.0 23 | pytest-asyncio==1.0.0 24 | pytest-mock==3.14.1 25 | PyYAML==6.0.2 26 | types-futures==3.3.8 27 | types-protobuf==5.29.1.20250403 28 | types-PyYAML==6.0.12.20250516 29 | typing-extensions==4.14.0 30 | typing-inspect==0.9.0 31 | urllib3==2.4.0 32 | mashumaro==3.16 33 | -------------------------------------------------------------------------------- /script/run-mypy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -o errexit 4 | 5 | # other common virtualenvs 6 | my_path=$(git rev-parse --show-toplevel) 7 | 8 | for venv in venv .venv .; do 9 | if [ -f "${my_path}/${venv}/bin/activate" ]; then 10 | . "${my_path}/${venv}/bin/activate" 11 | break 12 | fi 13 | done 14 | 15 | mypy ${my_path} 16 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = google_nest_sdm 3 | version = 7.1.5 4 | description = Library for the Google Nest SDM API 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | url = https://github.com/allenporter/python-google-nest-sdm 8 | author = Allen Porter 9 | author_email = allen.porter@gmail.com 10 | license = Apache-2.0 11 | license_file = LICENSE 12 | classifiers = 13 | License :: OSI Approved :: Apache Software License 14 | 15 | [options] 16 | packages = find: 17 | python_requires = >=3.11 18 | install_requires = 19 | aiohttp>=3.7.3 20 | google-auth>=1.22.0 21 | google-auth-oauthlib>=0.4.1 22 | google-cloud-pubsub>=2.1.0 23 | requests-oauthlib>=1.3.0 24 | PyYAML>=6.0 25 | mashumaro>=3.12 26 | include_package_data = True 27 | package_dir = 28 | = . 29 | 30 | [options.packages.find] 31 | where = . 
32 | exclude = 33 | tests 34 | tests.* 35 | 36 | [options.package_data] 37 | google_nest_sdm = py.typed 38 | 39 | [options.entry_points] 40 | console_scripts = 41 | google_nest=google_nest_sdm.google_nest:main 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Library for packaging the project.""" 2 | 3 | from setuptools import setup 4 | 5 | setup() 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenporter/python-google-nest-sdm/0ed9c93b1a77cabca578b6ff73c2e313e1bd43ef/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures and libraries shared by tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import uuid 6 | from http import HTTPStatus 7 | from abc import ABC, abstractmethod 8 | from typing import ( 9 | Any, 10 | Awaitable, 11 | Callable, 12 | Dict, 13 | Generator, 14 | Optional, 15 | cast, 16 | ) 17 | import logging 18 | 19 | import aiohttp 20 | import pytest 21 | from aiohttp.test_utils import TestClient, TestServer 22 | 23 | from google_nest_sdm import diagnostics, google_nest_api 24 | from google_nest_sdm.auth import AbstractAuth 25 | from google_nest_sdm.device import Device 26 | from google_nest_sdm.event import EventMessage 27 | 28 | FAKE_TOKEN = "some-token" 29 | PROJECT_ID = "project-id1" 30 | 31 | _LOGGER = logging.getLogger(__name__) 32 | 33 | 34 | def pytest_configure(config: pytest.Config) -> None: 35 | """Register marker for tests that log exceptions.""" 36 | logging.basicConfig( 37 | level=logging.INFO, 38 | format="%(asctime)s.%(msecs)03d %(levelname)-8s %(name)s:%(filename)s:%(lineno)s %(message)s", # noqa: E501 39 | datefmt="%Y-%m-%d %H:%M:%S", 40 | ) 41 | if config.getoption("verbose") > 0: 42 | logging.getLogger().setLevel(logging.DEBUG) 43 | 44 | 45 | @pytest.fixture(name="app") 46 | def mock_app() -> Generator[aiohttp.web.Application, None, None]: 47 | yield aiohttp.web.Application() 48 | 49 | 50 | @pytest.fixture(name="server") 51 | def mock_server( 52 | app: aiohttp.web.Application, 53 | aiohttp_server: Callable[[aiohttp.web.Application], Awaitable[TestServer]], 54 | ) -> Callable[[], Awaitable[TestServer]]: 55 | async def _make_server() -> TestServer: 56 | server = await aiohttp_server(app) 57 | server.skip_url_asserts = True 58 | assert isinstance(server, TestServer) 59 | return server 60 | 61 | return _make_server 62 | 63 | 64 | @pytest.fixture(name="client") 65 | def mock_client( 66 | server: Callable[[], Awaitable[TestServer]], 67 | aiohttp_client: Callable[[TestServer], Awaitable[TestClient]], 68 | ) -> Callable[[], Awaitable[TestClient]]: 69 | # Cache the value so that it can be mutated by a test 70 | cached_client: Optional[TestClient] = None 71 | 72 | async def _make_client() -> TestClient: 73 | nonlocal cached_client 74 | if not cached_client: 75 | cached_client = await aiohttp_client(await server()) 76 | assert isinstance(cached_client, TestClient) 77 | return cached_client 78 | 79 | return _make_client 80 | 81 | 82 | @pytest.fixture(name="api_client") 83 | def mock_api_client( 84 | project_id: str, 85 | auth_client: Callable[[], Awaitable[AbstractAuth]], 86 | ) -> Callable[[], 
Awaitable[google_nest_api.GoogleNestAPI]]: 87 | async def make_api() -> google_nest_api.GoogleNestAPI: 88 | auth = await auth_client() 89 | return google_nest_api.GoogleNestAPI(auth, project_id) 90 | 91 | return make_api 92 | 93 | 94 | class FakeAuth(AbstractAuth): 95 | def __init__(self, test_client: TestClient, path_prefix: str = "") -> None: 96 | super().__init__(cast(aiohttp.ClientSession, test_client), path_prefix) 97 | 98 | async def async_get_access_token(self) -> str: 99 | return FAKE_TOKEN 100 | 101 | 102 | @pytest.fixture(name="auth_client") 103 | def mock_auth_client( 104 | app: aiohttp.web.Application, client: Any 105 | ) -> Callable[[str], Awaitable[AbstractAuth]]: 106 | async def _make_auth(path_prefix: str = "") -> AbstractAuth: 107 | return FakeAuth(await client(), path_prefix) 108 | 109 | return _make_auth 110 | 111 | 112 | class RefreshingAuth(AbstractAuth): 113 | def __init__(self, test_client: TestClient) -> None: 114 | super().__init__(cast(aiohttp.ClientSession, test_client), "") 115 | 116 | async def async_get_access_token(self) -> str: 117 | resp = await self._websession.request("get", "/refresh-auth") 118 | resp.raise_for_status() 119 | json = await resp.json() 120 | assert isinstance(json["token"], str) 121 | return json["token"] 122 | 123 | 124 | @pytest.fixture 125 | async def refreshing_auth_client( 126 | app: aiohttp.web.Application, client: Any 127 | ) -> Callable[[], Awaitable[AbstractAuth]]: 128 | async def _make_auth() -> AbstractAuth: 129 | return RefreshingAuth(await client()) 130 | 131 | return _make_auth 132 | 133 | 134 | @pytest.fixture 135 | def event_message( 136 | app: aiohttp.web.Application, auth_client: Callable[[], Awaitable[AbstractAuth]] 137 | ) -> Callable[[Dict[str, Any]], Awaitable[EventMessage]]: 138 | async def _make_event(raw_data: Dict[str, Any]) -> EventMessage: 139 | return EventMessage.create_event(raw_data, await auth_client()) 140 | 141 | return _make_event 142 | 143 | 144 | @pytest.fixture 145 | def fake_event_message() -> Callable[[Dict[str, Any]], EventMessage]: 146 | def _make_event(raw_data: Dict[str, Any]) -> EventMessage: 147 | return EventMessage.create_event(raw_data, cast(AbstractAuth, None)) 148 | 149 | return _make_event 150 | 151 | 152 | @pytest.fixture 153 | def fake_device() -> Callable[[Dict[str, Any]], Device]: 154 | def _make_device(raw_data: Dict[str, Any]) -> Device: 155 | return Device.MakeDevice(raw_data, cast(AbstractAuth, None)) 156 | 157 | return _make_device 158 | 159 | 160 | class Recorder: 161 | request: Optional[Dict[str, Any]] = None 162 | 163 | 164 | class JsonHandler(ABC): 165 | """Request handler that replays mocks.""" 166 | 167 | def __init__(self, recorder: Recorder) -> None: 168 | """Initialize Handler.""" 169 | self.token: str = FAKE_TOKEN 170 | self.recorder = recorder 171 | 172 | @abstractmethod 173 | def get_response(self) -> dict[str, Any]: 174 | """Implemented by subclasses to return a response.""" 175 | 176 | async def handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: 177 | _LOGGER.debug("Request: %s", request) 178 | assert request.headers["Authorization"] == "Bearer %s" % self.token 179 | s = await request.text() 180 | self.recorder.request = await request.json() if s else {} 181 | return aiohttp.web.json_response(self.get_response()) 182 | 183 | 184 | class ReplyHandler(JsonHandler): 185 | def __init__( 186 | self, recorder: Recorder, responses: list[dict[str, Any]] | None = None 187 | ) -> None: 188 | """Initialize ReplyHandler.""" 189 | super().__init__(recorder) 
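        # Responses are consumed in FIFO order: each handled request pops the
        # next queued response in get_response() below.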
190 | self.responses = responses or [] 191 | 192 | def get_response(self) -> dict[str, Any]: 193 | """Return an API response.""" 194 | return self.responses.pop(0) 195 | 196 | 197 | def reply_handler( 198 | app: aiohttp.web.Application, 199 | path: str, 200 | recorder: Recorder, 201 | responses: list[dict[str, Any]], 202 | ) -> ReplyHandler: 203 | """Create a new reply handler.""" 204 | handler = ReplyHandler(recorder, responses) 205 | app.router.add_get(path, handler.handler) 206 | return handler 207 | 208 | 209 | class DeviceHandler(JsonHandler): 210 | """Handle requests to fetch devices.""" 211 | 212 | def __init__( 213 | self, app: aiohttp.web.Application, project_id: str, recorder: Recorder 214 | ) -> None: 215 | """Initialize DeviceHandler.""" 216 | super().__init__(recorder) 217 | self.app = app 218 | self.project_id = project_id 219 | self.devices: list[dict[str, Any]] = [] 220 | app.router.add_get(f"/enterprises/{project_id}/devices", self.handler) 221 | 222 | def add_device( 223 | self, 224 | device_type: str = "sdm.devices.types.device-type1", 225 | traits: dict[str, Any] = {}, 226 | parentRelations: list[dict[str, Any]] = [], 227 | ) -> str: 228 | """Add a fake device reply.""" 229 | uid = uuid.uuid4().hex 230 | device_id = f"enterprises/{self.project_id}/devices/device-id-{uid}" 231 | device = { 232 | "name": device_id, 233 | "type": device_type, 234 | "traits": traits, 235 | "parentRelations": parentRelations, 236 | } 237 | # Setup device lookup reply 238 | reply_handler(self.app, f"/{device_id}", self.recorder, [device]) 239 | # Setup device list reply 240 | self.devices.append(device) 241 | return device_id 242 | 243 | def get_response(self) -> dict[str, Any]: 244 | """Return devices API response.""" 245 | return {"devices": self.devices} 246 | 247 | 248 | class StructureHandler(JsonHandler): 249 | """Handle requests to fetch structures.""" 250 | 251 | def __init__( 252 | self, app: aiohttp.web.Application, project_id: str, recorder: Recorder 253 | ) -> None: 254 | """Initialize StructureHandler.""" 255 | super().__init__(recorder) 256 | self.app = app 257 | self.project_id = project_id 258 | self.structures: list[dict[str, Any]] = [] 259 | app.router.add_get(f"/enterprises/{project_id}/structures", self.handler) 260 | 261 | def add_structure(self, traits: dict[str, Any] = {}) -> str: 262 | """Add a structure to the response.""" 263 | uid = uuid.uuid4().hex 264 | structure_id = f"enterprises/{self.project_id}/structures/structure-id-{uid}" 265 | structure = { 266 | "name": structure_id, 267 | "traits": traits, 268 | } 269 | # Setup structure lookup reply 270 | reply_handler(self.app, f"/{structure_id}", self.recorder, [structure]) 271 | # Setup structure list reply 272 | self.structures.append(structure) 273 | return structure_id 274 | 275 | def get_response(self) -> dict[str, Any]: 276 | """Return structure API response.""" 277 | return {"structures": self.structures} 278 | 279 | 280 | def NewHandler( 281 | r: Recorder, 282 | responses: list[dict[str, Any]], 283 | token: str = FAKE_TOKEN, 284 | status: HTTPStatus = HTTPStatus.OK, 285 | ) -> Callable[[aiohttp.web.Request], Awaitable[aiohttp.web.Response]]: 286 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 287 | assert request.headers["Authorization"] == "Bearer %s" % token 288 | s = await request.text() 289 | r.request = await request.json() if s else {} 290 | return aiohttp.web.json_response(responses.pop(0), status=status) 291 | 292 | return handler 293 | 294 | 295 | def NewImageHandler( 296 | 
response: list[bytes], token: str = FAKE_TOKEN 297 | ) -> Callable[[aiohttp.web.Request], Awaitable[aiohttp.web.Response]]: 298 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 299 | assert request.headers["Authorization"] == "Basic %s" % token 300 | return aiohttp.web.Response(body=response.pop(0)) 301 | 302 | return handler 303 | 304 | 305 | @pytest.fixture 306 | def project_id() -> str: 307 | return "project-id1" 308 | 309 | 310 | @pytest.fixture 311 | def recorder() -> Recorder: 312 | return Recorder() 313 | 314 | 315 | @pytest.fixture(name="device_handler") 316 | def mock_device_handler( 317 | app: aiohttp.web.Application, project_id: str, recorder: Recorder 318 | ) -> DeviceHandler: 319 | return DeviceHandler(app, project_id, recorder) 320 | 321 | 322 | @pytest.fixture(name="structure_handler") 323 | def mock_structure_handler( 324 | app: aiohttp.web.Application, project_id: str, recorder: Recorder 325 | ) -> StructureHandler: 326 | return StructureHandler(app, project_id, recorder) 327 | 328 | 329 | @pytest.fixture(autouse=True) 330 | def reset_diagnostics() -> Generator[None, None, None]: 331 | yield 332 | diagnostics.reset() 333 | 334 | 335 | def assert_diagnostics(actual: dict[str, Any], expected: dict[str, Any]) -> None: 336 | """Helper method for stripping timing based daignostics.""" 337 | 338 | def scrub_dict(data: dict[str, Any]) -> dict[str, Any]: 339 | drop_keys = [] 340 | for k1, v1 in data.items(): 341 | if k1.endswith("_sum"): 342 | drop_keys.append(k1) 343 | for k in drop_keys: 344 | del data[k] 345 | return data 346 | 347 | actual = scrub_dict(actual) 348 | 349 | for k1, v1 in actual.items(): 350 | if isinstance(v1, dict): 351 | actual[k1] = scrub_dict(v1) 352 | 353 | assert actual == expected 354 | 355 | 356 | class EventCallback: 357 | """A callback that can be used in tests for assertions.""" 358 | 359 | def __init__(self) -> None: 360 | """Initialize EventCallback.""" 361 | self.invoked: bool = False 362 | self.messages: list[EventMessage] = [] 363 | 364 | async def async_handle_event(self, event_message: EventMessage) -> None: 365 | self.invoked = True 366 | self.messages.append(event_message) 367 | -------------------------------------------------------------------------------- /tests/test_auth.py: -------------------------------------------------------------------------------- 1 | """Tests for the request client library.""" 2 | 3 | from typing import Awaitable, Callable 4 | 5 | import aiohttp 6 | import pytest 7 | from aiohttp.test_utils import TestClient, TestServer 8 | from yarl import URL 9 | 10 | from google_nest_sdm.auth import AbstractAuth 11 | from google_nest_sdm.exceptions import ApiException 12 | 13 | 14 | async def test_request( 15 | app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]] 16 | ) -> None: 17 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 18 | assert request.path == "/path-prefix/some-path" 19 | assert request.headers["header-1"] == "value-1" 20 | assert request.headers["Authorization"] == "Bearer some-token" 21 | assert request.query == {"client_id": "some-client-id"} 22 | return aiohttp.web.json_response( 23 | { 24 | "some-key": "some-value", 25 | } 26 | ) 27 | 28 | app.router.add_get("/path-prefix/some-path", handler) 29 | 30 | auth = await auth_client("/path-prefix") 31 | resp = await auth.request( 32 | "get", 33 | "some-path", 34 | headers={"header-1": "value-1"}, 35 | params={"client_id": "some-client-id"}, 36 | ) 37 | resp.raise_for_status() 38 | 
data = await resp.json()
39 |     assert data == {"some-key": "some-value"}
40 | 
41 | 
42 | async def test_auth_header(
43 |     app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]]
44 | ) -> None:
45 |     """Test that a request with an Authorization header is preserved."""
46 | 
47 |     async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response:
48 |         assert request.path == "/path-prefix/some-path"
49 |         assert request.headers["header-1"] == "value-1"
50 |         assert request.headers["Authorization"] == "Basic other-token"
51 |         assert request.query["client_id"] == "some-client-id"
52 |         return aiohttp.web.json_response(
53 |             {
54 |                 "some-key": "some-value",
55 |             }
56 |         )
57 | 
58 |     app.router.add_get("/path-prefix/some-path", handler)
59 | 
60 |     auth = await auth_client("/path-prefix")
61 |     resp = await auth.request(
62 |         "get",
63 |         "some-path",
64 |         headers={"header-1": "value-1", "Authorization": "Basic other-token"},
65 |         params={"client_id": "some-client-id"},
66 |     )
67 |     resp.raise_for_status()
68 |     data = await resp.json()
69 |     assert data == {"some-key": "some-value"}
70 | 
71 | 
72 | async def test_full_url(
73 |     app: aiohttp.web.Application,
74 |     client: Callable[[], Awaitable[TestClient]],
75 |     server: Callable[[], Awaitable[TestServer]],
76 |     auth_client: Callable[[str], Awaitable[AbstractAuth]],
77 | ) -> None:
78 |     """Test that a request with a full URL is handled correctly."""
79 | 
80 |     async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response:
81 |         assert request.path == "/path-prefix/some-path"
82 |         assert request.headers["header-1"] == "value-1"
83 |         assert request.headers["Authorization"] == "Bearer some-token"
84 |         assert request.query["client_id"] == "some-client-id"
85 |         return aiohttp.web.json_response(
86 |             {
87 |                 "some-key": "some-value",
88 |             }
89 |         )
90 | 
91 |     app.router.add_get("/path-prefix/some-path", handler)
92 | 
93 |     test_server = await server()
94 | 
95 |     def client_make_url(url: str) -> URL:
96 |         assert url == "https://example/path-prefix/some-path"
97 |         return test_server.make_url("/path-prefix/some-path")
98 | 
99 |     test_client = await client()
100 |     test_client.make_url = client_make_url  # type: ignore
101 | 
102 |     auth = await auth_client("/path-prefix")
103 |     resp = await auth.request(
104 |         "get",
105 |         "https://example/path-prefix/some-path",
106 |         headers={"header-1": "value-1"},
107 |         params={"client_id": "some-client-id"},
108 |     )
109 |     resp.raise_for_status()
110 |     data = await resp.json()
111 |     assert data == {"some-key": "some-value"}
112 | 
113 | 
114 | async def test_get_json_response(
115 |     app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]]
116 | ) -> None:
117 |     async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response:
118 |         assert request.query["client_id"] == "some-client-id"
119 |         return aiohttp.web.json_response(
120 |             {
121 |                 "some-key": "some-value",
122 |             }
123 |         )
124 | 
125 |     app.router.add_get("/path-prefix/some-path", handler)
126 | 
127 |     auth = await auth_client("/path-prefix")
128 |     data = await auth.get_json("some-path", params={"client_id": "some-client-id"})
129 |     assert data == {"some-key": "some-value"}
130 | 
131 | 
132 | async def test_get_json_response_unexpected(
133 |     app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]]
134 | ) -> None:
135 |     async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response:
136 |         return aiohttp.web.json_response(["value1", "value2"])
137 | 
138 | 
app.router.add_get("/path-prefix/some-path", handler) 139 | 140 | auth = await auth_client("/path-prefix") 141 | with pytest.raises(ApiException): 142 | await auth.get_json("some-path") 143 | 144 | 145 | async def test_get_json_response_unexpected_text( 146 | app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]] 147 | ) -> None: 148 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 149 | return aiohttp.web.Response(text="body") 150 | 151 | app.router.add_get("/path-prefix/some-path", handler) 152 | 153 | auth = await auth_client("/path-prefix") 154 | with pytest.raises(ApiException): 155 | await auth.get_json("some-path") 156 | 157 | 158 | async def test_post_json_response( 159 | app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]] 160 | ) -> None: 161 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 162 | body = await request.json() 163 | assert body == {"client_id": "some-client-id"} 164 | return aiohttp.web.json_response( 165 | { 166 | "some-key": "some-value", 167 | } 168 | ) 169 | 170 | app.router.add_post("/path-prefix/some-path", handler) 171 | 172 | auth = await auth_client("/path-prefix") 173 | data = await auth.post_json("some-path", json={"client_id": "some-client-id"}) 174 | assert data == {"some-key": "some-value"} 175 | 176 | 177 | async def test_post_json_response_unexpected( 178 | app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]] 179 | ) -> None: 180 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 181 | return aiohttp.web.json_response(["value1", "value2"]) 182 | 183 | app.router.add_post("/path-prefix/some-path", handler) 184 | 185 | auth = await auth_client("/path-prefix") 186 | with pytest.raises(ApiException): 187 | await auth.post_json("some-path") 188 | 189 | 190 | async def test_post_json_response_unexpected_text( 191 | app: aiohttp.web.Application, auth_client: Callable[[str], Awaitable[AbstractAuth]] 192 | ) -> None: 193 | async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: 194 | return aiohttp.web.Response(text="body") 195 | 196 | app.router.add_post("/path-prefix/some-path", handler) 197 | 198 | auth = await auth_client("/path-prefix") 199 | with pytest.raises(ApiException): 200 | await auth.post_json("some-path") 201 | -------------------------------------------------------------------------------- /tests/test_device.py: -------------------------------------------------------------------------------- 1 | """Tests for device properties.""" 2 | 3 | from typing import Any, Callable, Dict 4 | 5 | import pytest 6 | 7 | from google_nest_sdm.device import Device 8 | 9 | from .conftest import assert_diagnostics 10 | 11 | 12 | def test_device_id(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 13 | device = fake_device( 14 | { 15 | "name": "my/device/name", 16 | "type": "sdm.devices.types.SomeDeviceType", 17 | } 18 | ) 19 | assert "my/device/name" == device.name 20 | assert "sdm.devices.types.SomeDeviceType" == device.type 21 | 22 | 23 | def test_no_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 24 | device = fake_device( 25 | { 26 | "name": "my/device/name", 27 | } 28 | ) 29 | assert "my/device/name" == device.name 30 | assert "sdm.devices.traits.Info" not in device.traits 31 | 32 | 33 | def test_empty_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 34 | device = fake_device( 35 | { 36 | "name": "my/device/name", 37 | "traits": {}, 38 | 
} 39 | ) 40 | assert "my/device/name" == device.name 41 | assert "sdm.devices.traits.Info" not in device.traits 42 | 43 | 44 | def test_no_name(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 45 | with pytest.raises(ValueError, match="'name' is required"): 46 | fake_device( 47 | { 48 | "traits": {}, 49 | } 50 | ) 51 | 52 | 53 | def test_no_parent_relations(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 54 | device = fake_device( 55 | { 56 | "name": "my/device/name", 57 | } 58 | ) 59 | assert "my/device/name" == device.name 60 | assert {} == device.parent_relations 61 | 62 | 63 | def test_empty_parent_relations( 64 | fake_device: Callable[[Dict[str, Any]], Device] 65 | ) -> None: 66 | device = fake_device( 67 | { 68 | "name": "my/device/name", 69 | "parentRelations": [], 70 | } 71 | ) 72 | assert "my/device/name" == device.name 73 | assert {} == device.parent_relations 74 | 75 | 76 | def test_invalid_parent_relations( 77 | fake_device: Callable[[Dict[str, Any]], Device] 78 | ) -> None: 79 | """Invalid parentRelations should be ignored.""" 80 | device = fake_device( 81 | { 82 | "name": "my/device/name", 83 | "parentRelations": [{}], 84 | } 85 | ) 86 | assert "my/device/name" == device.name 87 | assert {} == device.parent_relations 88 | 89 | 90 | def test_parent_relation(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 91 | device = fake_device( 92 | { 93 | "name": "my/device/name", 94 | "parentRelations": [ 95 | { 96 | "parent": "my/structure/or/room", 97 | "displayName": "Some Name", 98 | }, 99 | ], 100 | } 101 | ) 102 | assert "my/device/name" == device.name 103 | assert {"my/structure/or/room": "Some Name"} == device.parent_relations 104 | 105 | assert_diagnostics( 106 | device.get_diagnostics(), 107 | { 108 | "data": { 109 | "name": "**REDACTED**", 110 | "parentRelations": [ 111 | { 112 | "parent": "**REDACTED**", 113 | "displayName": "**REDACTED**", 114 | } 115 | ], 116 | }, 117 | }, 118 | ) 119 | 120 | 121 | def test_multiple_parent_relations( 122 | fake_device: Callable[[Dict[str, Any]], Device] 123 | ) -> None: 124 | device = fake_device( 125 | { 126 | "name": "my/device/name", 127 | "parentRelations": [ 128 | { 129 | "parent": "my/structure/or/room1", 130 | "displayName": "Some Name1", 131 | }, 132 | { 133 | "parent": "my/structure/or/room2", 134 | "displayName": "Some Name2", 135 | }, 136 | ], 137 | } 138 | ) 139 | assert "my/device/name" == device.name 140 | assert { 141 | "my/structure/or/room1": "Some Name1", 142 | "my/structure/or/room2": "Some Name2", 143 | } == device.parent_relations 144 | -------------------------------------------------------------------------------- /tests/test_device_traits.py: -------------------------------------------------------------------------------- 1 | """Tests for device traits.""" 2 | 3 | import datetime 4 | from typing import Any, Callable, Dict 5 | 6 | import pytest 7 | 8 | from google_nest_sdm.device import Device 9 | 10 | from .conftest import assert_diagnostics 11 | 12 | 13 | def test_info_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 14 | device = fake_device( 15 | { 16 | "name": "my/device/name", 17 | "traits": { 18 | "sdm.devices.traits.Info": { 19 | "customName": "Device Name", 20 | }, 21 | }, 22 | } 23 | ) 24 | assert "my/device/name" == device.name 25 | assert "sdm.devices.traits.Info" in device.traits 26 | trait = device.traits["sdm.devices.traits.Info"] 27 | assert "Device Name" == trait.custom_name 28 | 29 | 30 | assert_diagnostics( 31 | device.get_diagnostics(), 32 | { 33 | 
"data": { 34 | "name": "**REDACTED**", 35 | "parentRelations": [], 36 | "traits": { 37 | "sdm.devices.traits.Info": { 38 | "custom_name": "**REDACTED**", 39 | } 40 | }, 41 | }, 42 | }, 43 | ) 44 | 45 | 46 | 47 | def test_connectivity_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 48 | device = fake_device( 49 | { 50 | "name": "my/device/name", 51 | "traits": { 52 | "sdm.devices.traits.Connectivity": { 53 | "status": "OFFLINE", 54 | }, 55 | }, 56 | } 57 | ) 58 | assert "sdm.devices.traits.Connectivity" in device.traits 59 | trait = device.traits["sdm.devices.traits.Connectivity"] 60 | assert "OFFLINE" == trait.status 61 | 62 | 63 | def test_fan_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 64 | device = fake_device( 65 | { 66 | "name": "my/device/name", 67 | "traits": { 68 | "sdm.devices.traits.Fan": { 69 | "timerMode": "ON", 70 | "timerTimeout": "2019-05-10T03:22:54Z", 71 | }, 72 | }, 73 | } 74 | ) 75 | assert "sdm.devices.traits.Fan" in device.traits 76 | trait = device.traits["sdm.devices.traits.Fan"] 77 | assert "ON" == trait.timer_mode 78 | assert ( 79 | datetime.datetime(2019, 5, 10, 3, 22, 54, tzinfo=datetime.timezone.utc) 80 | == trait.timer_timeout 81 | ) 82 | 83 | 84 | def test_fan_traits_empty(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 85 | device = fake_device( 86 | { 87 | "name": "my/device/name", 88 | "traits": { 89 | "sdm.devices.traits.Fan": {}, 90 | }, 91 | } 92 | ) 93 | assert "sdm.devices.traits.Fan" in device.traits 94 | trait = device.traits["sdm.devices.traits.Fan"] 95 | assert trait.timer_mode is None 96 | assert trait.timer_timeout is None 97 | 98 | 99 | def test_humidity_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 100 | device = fake_device( 101 | { 102 | "name": "my/device/name", 103 | "traits": { 104 | "sdm.devices.traits.Humidity": { 105 | "ambientHumidityPercent": 25.3, 106 | }, 107 | }, 108 | } 109 | ) 110 | assert "sdm.devices.traits.Humidity" in device.traits 111 | trait = device.traits["sdm.devices.traits.Humidity"] 112 | assert 25.3 == trait.ambient_humidity_percent 113 | 114 | 115 | def test_humidity_int_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 116 | device = fake_device( 117 | { 118 | "name": "my/device/name", 119 | "traits": { 120 | "sdm.devices.traits.Humidity": { 121 | "ambientHumidityPercent": 25, 122 | }, 123 | }, 124 | } 125 | ) 126 | assert "sdm.devices.traits.Humidity" in device.traits 127 | trait = device.traits["sdm.devices.traits.Humidity"] 128 | assert 25 == trait.ambient_humidity_percent 129 | 130 | 131 | def test_temperature_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 132 | device = fake_device( 133 | { 134 | "name": "my/device/name", 135 | "traits": { 136 | "sdm.devices.traits.Temperature": { 137 | "ambientTemperatureCelsius": 31.1, 138 | }, 139 | }, 140 | } 141 | ) 142 | assert "sdm.devices.traits.Temperature" in device.traits 143 | trait = device.traits["sdm.devices.traits.Temperature"] 144 | assert 31.1 == trait.ambient_temperature_celsius 145 | 146 | 147 | def test_multiple_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 148 | device = fake_device( 149 | { 150 | "name": "my/device/name", 151 | "type": "sdm.devices.types.SomeDeviceType", 152 | "traits": { 153 | "sdm.devices.traits.Info": { 154 | "customName": "Device Name", 155 | }, 156 | "sdm.devices.traits.Connectivity": { 157 | "status": "OFFLINE", 158 | }, 159 | }, 160 | } 161 | ) 162 | assert "my/device/name" == device.name 163 | assert 
"sdm.devices.types.SomeDeviceType" == device.type 164 | assert "sdm.devices.traits.Info" in device.traits 165 | trait = device.traits["sdm.devices.traits.Info"] 166 | assert "Device Name" == trait.custom_name 167 | assert "sdm.devices.traits.Connectivity" in device.traits 168 | trait = device.traits["sdm.devices.traits.Connectivity"] 169 | assert "OFFLINE" == trait.status 170 | 171 | 172 | def test_info_traits_type_error( 173 | fake_device: Callable[[Dict[str, Any]], Device] 174 | ) -> None: 175 | device = fake_device( 176 | { 177 | "name": "my/device/name", 178 | "traits": { 179 | "sdm.devices.traits.Info": { 180 | "customName": 12345, 181 | }, 182 | }, 183 | } 184 | ) 185 | assert "my/device/name" == device.name 186 | assert "sdm.devices.traits.Info" in device.traits 187 | trait = device.traits["sdm.devices.traits.Info"] 188 | assert trait.custom_name == "12345" 189 | 190 | 191 | def test_info_traits_missing_optional_field( 192 | fake_device: Callable[[Dict[str, Any]], Device] 193 | ) -> None: 194 | device = fake_device( 195 | { 196 | "name": "my/device/name", 197 | "traits": { 198 | "sdm.devices.traits.Info": {}, 199 | }, 200 | } 201 | ) 202 | assert "my/device/name" == device.name 203 | assert "sdm.devices.traits.Info" in device.traits 204 | trait = device.traits["sdm.devices.traits.Info"] 205 | assert trait.custom_name is None 206 | 207 | 208 | def test_connectivity_traits_missing_required_field( 209 | fake_device: Callable[[Dict[str, Any]], Device] 210 | ) -> None: 211 | with pytest.raises(ValueError): 212 | fake_device( 213 | { 214 | "name": "my/device/name", 215 | "traits": { 216 | "sdm.devices.traits.Connectivity": {}, 217 | }, 218 | } 219 | ) 220 | -------------------------------------------------------------------------------- /tests/test_doorbell_traits.py: -------------------------------------------------------------------------------- 1 | """Tests for doorbell traits.""" 2 | 3 | from typing import Any, Callable, Dict 4 | 5 | from google_nest_sdm.device import Device 6 | 7 | 8 | def test_doorbell_chime(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 9 | device = fake_device( 10 | { 11 | "name": "my/device/name", 12 | "traits": { 13 | "sdm.devices.traits.DoorbellChime": {}, 14 | }, 15 | } 16 | ) 17 | assert device.traits.keys() == {"sdm.devices.traits.DoorbellChime"} 18 | 19 | 20 | def test_doorbell_chime_trait_hack( 21 | fake_device: Callable[[Dict[str, Any]], Device] 22 | ) -> None: 23 | """Adds the DoorbellChime trait even when missing from the API to fix an API bug.""" 24 | device = fake_device( 25 | { 26 | "name": "my/device/name", 27 | "type": "sdm.devices.types.DOORBELL", 28 | "traits": {}, 29 | } 30 | ) 31 | assert device.type == "sdm.devices.types.DOORBELL" 32 | assert device.traits.keys() == {"sdm.devices.traits.DoorbellChime"} 33 | 34 | 35 | def test_doorbell_chime_trait_hack_empty_traits( 36 | fake_device: Callable[[Dict[str, Any]], Device] 37 | ) -> None: 38 | """Adds the DoorbellChime trait even when missing from the API to fix an API bug.""" 39 | device = fake_device( 40 | { 41 | "name": "my/device/name", 42 | "type": "sdm.devices.types.DOORBELL", 43 | } 44 | ) 45 | assert device.type == "sdm.devices.types.DOORBELL" 46 | assert device.traits.keys() == {"sdm.devices.traits.DoorbellChime"} 47 | 48 | 49 | def test_doorbell_chime_trait_hack_not_applied( 50 | fake_device: Callable[[Dict[str, Any]], Device] 51 | ) -> None: 52 | """The doorbell chime trait hack is not applied for other types.""" 53 | device = fake_device( 54 | { 55 | "name": "my/device/name", 
56 | "type": "sdm.devices.types.CAMERA", 57 | "traits": {}, 58 | } 59 | ) 60 | assert device.type == "sdm.devices.types.CAMERA" 61 | assert device.traits.keys() == set() 62 | -------------------------------------------------------------------------------- /tests/test_structure.py: -------------------------------------------------------------------------------- 1 | from google_nest_sdm.structure import Structure 2 | 3 | 4 | def test_no_traits() -> None: 5 | raw = { 6 | "name": "my/structure/name", 7 | } 8 | structure = Structure.MakeStructure(raw) 9 | assert "my/structure/name" == structure.name 10 | assert "sdm.structures.traits.Info" not in structure.traits 11 | 12 | 13 | def test_empty_traits() -> None: 14 | raw = { 15 | "name": "my/structure/name", 16 | "traits": {}, 17 | } 18 | structure = Structure.MakeStructure(raw) 19 | assert "my/structure/name" == structure.name 20 | assert "sdm.structures.traits.Info" not in structure.traits 21 | 22 | 23 | def test_info_traits() -> None: 24 | raw = { 25 | "name": "my/structure/name", 26 | "traits": { 27 | "sdm.structures.traits.Info": { 28 | "customName": "Structure Name", 29 | }, 30 | }, 31 | } 32 | structure = Structure.MakeStructure(raw) 33 | assert "my/structure/name" == structure.name 34 | assert "sdm.structures.traits.Info" in structure.traits 35 | trait = structure.traits["sdm.structures.traits.Info"] 36 | assert "Structure Name" == trait.custom_name 37 | 38 | 39 | def test_room_info_traits() -> None: 40 | raw = { 41 | "name": "my/structure/name", 42 | "traits": { 43 | "sdm.structures.traits.RoomInfo": { 44 | "customName": "Structure Name", 45 | }, 46 | }, 47 | } 48 | structure = Structure.MakeStructure(raw) 49 | assert "my/structure/name" == structure.name 50 | assert "sdm.structures.traits.RoomInfo" in structure.traits 51 | trait = structure.traits["sdm.structures.traits.RoomInfo"] 52 | assert "Structure Name" == trait.custom_name 53 | -------------------------------------------------------------------------------- /tests/test_subscriber_client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | from unittest.mock import Mock, AsyncMock, patch 5 | 6 | import pytest 7 | from google.auth.exceptions import RefreshError, GoogleAuthError, TransportError 8 | from google.api_core.exceptions import GoogleAPIError, NotFound, Unauthenticated 9 | 10 | from google_nest_sdm.exceptions import ( 11 | AuthException, 12 | SubscriberException, 13 | ConfigurationException, 14 | ) 15 | from google_nest_sdm.subscriber_client import ( 16 | refresh_creds, 17 | SubscriberClient, 18 | pull_request_generator, 19 | ) 20 | 21 | 22 | async def test_refresh_creds() -> None: 23 | """Test low level refresh errors.""" 24 | mock_refresh = Mock() 25 | mock_creds = Mock() 26 | mock_creds.refresh = mock_refresh 27 | refresh_creds(mock_creds) 28 | assert mock_refresh.call_count == 1 29 | 30 | 31 | @pytest.mark.parametrize( 32 | ("raised", "expected"), 33 | [ 34 | (RefreshError(), AuthException), 35 | (TransportError(), SubscriberException), 36 | (GoogleAuthError(), SubscriberException), 37 | ], 38 | ) 39 | async def test_refresh_creds_error(raised: Exception, expected: Any) -> None: 40 | """Test low level refresh errors.""" 41 | mock_refresh = Mock() 42 | mock_refresh.side_effect = raised 43 | mock_creds = Mock() 44 | mock_creds.refresh = mock_refresh 45 | with pytest.raises(expected): 46 | refresh_creds(mock_creds) 47 | 48 | 49 | async def test_ack_no_messages() -> 
None: 50 | """Test ack with no messages to ack is a no-op.""" 51 | 52 | client = SubscriberClient(auth=AsyncMock(), subscription_name="test") 53 | await client.ack_messages([]) 54 | 55 | 56 | async def test_ack_messages() -> None: 57 | """Test ack messages.""" 58 | 59 | client = SubscriberClient(auth=AsyncMock(), subscription_name="test") 60 | with patch( 61 | "google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient" 62 | ) as mock_client: 63 | mock_acknowledge = AsyncMock() 64 | mock_acknowledge.return_value = None 65 | mock_client.return_value.acknowledge = mock_acknowledge 66 | await client.ack_messages(["message1", "message2"]) 67 | 68 | # Verify that acknowledge was called with the correct arguments 69 | mock_acknowledge.assert_awaited_once_with( 70 | subscription="test", ack_ids=["message1", "message2"] 71 | ) 72 | 73 | 74 | async def test_streaming_pull() -> None: 75 | """Test streaming pull call.""" 76 | 77 | client = SubscriberClient(auth=AsyncMock(), subscription_name="test") 78 | with patch( 79 | "google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient" 80 | ) as mock_client: 81 | mock_streaming_pull = AsyncMock() 82 | mock_streaming_pull.return_value = None 83 | mock_client.return_value.streaming_pull = mock_streaming_pull 84 | await client.streaming_pull(lambda: []) 85 | 86 | # Verify the call was invoked with the correct arguments 87 | mock_streaming_pull.assert_awaited_once() 88 | 89 | 90 | @pytest.mark.parametrize( 91 | ("raised", "expected", "message"), 92 | [ 93 | ( 94 | NotFound("my error"), # type: ignore[no-untyped-call] 95 | ConfigurationException, 96 | "NotFound error calling streaming_pull: 404 my error", 97 | ), 98 | ( 99 | GoogleAPIError("my error"), 100 | SubscriberException, 101 | "API error when calling streaming_pull: my error", 102 | ), 103 | ( 104 | Unauthenticated("auth error"), # type: ignore[no-untyped-call] 105 | AuthException, 106 | "Failed to authenticate streaming_pull: 401 auth error", 107 | ), 108 | ( 109 | Exception("my error"), 110 | SubscriberException, 111 | "Unexpected error when calling streaming_pull: my error", 112 | ), 113 | ], 114 | ) 115 | async def test_streaming_pull_failure( 116 | raised: Exception, expected: Any, message: str 117 | ) -> None: 118 | """Test ack messages.""" 119 | 120 | client = SubscriberClient(auth=AsyncMock(), subscription_name="test") 121 | with patch( 122 | "google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient" 123 | ) as mock_client: 124 | mock_streaming_pull = AsyncMock() 125 | mock_streaming_pull.side_effect = raised 126 | mock_client.return_value.streaming_pull = mock_streaming_pull 127 | 128 | with pytest.raises(expected, match=message): 129 | await client.streaming_pull(lambda: []) 130 | 131 | 132 | async def test_request_generator() -> None: 133 | """Test the streaming pull request generator.""" 134 | ack_ids = [ 135 | ["ack-id-1", "ack-id-2"], 136 | ["ack-id-3", "ack-id-4"], 137 | [], 138 | ] 139 | with patch("asyncio.sleep", return_value=None): 140 | stream = pull_request_generator( 141 | "projects/some-project-id/subscriptions/sub-1", lambda: ack_ids.pop(0) 142 | ) 143 | stream_iter = aiter(stream) 144 | request = await anext(stream_iter) 145 | assert request.subscription == "projects/some-project-id/subscriptions/sub-1" 146 | assert request.stream_ack_deadline_seconds == 180 147 | assert not request.ack_ids 148 | 149 | request = await anext(stream_iter) 150 | assert request.subscription == "" 151 | assert request.stream_ack_deadline_seconds == 180 152 | assert 
request.ack_ids == ["ack-id-1", "ack-id-2"] 153 | 154 | request = await anext(stream_iter) 155 | assert request.subscription == "" 156 | assert request.stream_ack_deadline_seconds == 180 157 | assert request.ack_ids == ["ack-id-3", "ack-id-4"] 158 | 159 | await stream.aclose() 160 | -------------------------------------------------------------------------------- /tests/test_thermostat_traits.py: -------------------------------------------------------------------------------- 1 | """Tests for thermostat traits.""" 2 | 3 | from typing import Any, Awaitable, Callable, Dict 4 | 5 | import aiohttp 6 | import pytest 7 | 8 | from google_nest_sdm import google_nest_api 9 | from google_nest_sdm.device import Device 10 | 11 | from .conftest import DeviceHandler, NewHandler, Recorder 12 | 13 | 14 | def test_thermostat_eco_traits(fake_device: Callable[[Dict[str, Any]], Device]) -> None: 15 | device = fake_device( 16 | { 17 | "name": "my/device/name", 18 | "traits": { 19 | "sdm.devices.traits.ThermostatEco": { 20 | "availableModes": ["MANUAL_ECHO", "OFF"], 21 | "mode": "MANUAL_ECHO", 22 | "heatCelsius": 20.0, 23 | "coolCelsius": 22.0, 24 | }, 25 | }, 26 | } 27 | ) 28 | assert "sdm.devices.traits.ThermostatEco" in device.traits 29 | trait = device.traits["sdm.devices.traits.ThermostatEco"] 30 | assert ["MANUAL_ECHO", "OFF"] == trait.available_modes 31 | assert "MANUAL_ECHO" == trait.mode 32 | assert 20.0 == trait.heat_celsius 33 | assert 22.0 == trait.cool_celsius 34 | 35 | 36 | def test_thermostat_hvac_traits( 37 | fake_device: Callable[[Dict[str, Any]], Device] 38 | ) -> None: 39 | device = fake_device( 40 | { 41 | "name": "my/device/name", 42 | "traits": { 43 | "sdm.devices.traits.ThermostatHvac": { 44 | "status": "HEATING", 45 | }, 46 | }, 47 | } 48 | ) 49 | assert "sdm.devices.traits.ThermostatHvac" in device.traits 50 | trait = device.traits["sdm.devices.traits.ThermostatHvac"] 51 | assert "HEATING" == trait.status 52 | 53 | 54 | def test_thermostat_mode_traits( 55 | fake_device: Callable[[Dict[str, Any]], Device] 56 | ) -> None: 57 | device = fake_device( 58 | { 59 | "name": "my/device/name", 60 | "traits": { 61 | "sdm.devices.traits.ThermostatMode": { 62 | "availableModes": ["HEAT", "COOL", "HEATCOOL", "OFF"], 63 | "mode": "COOL", 64 | }, 65 | }, 66 | } 67 | ) 68 | assert "sdm.devices.traits.ThermostatMode" in device.traits 69 | trait = device.traits["sdm.devices.traits.ThermostatMode"] 70 | assert ["HEAT", "COOL", "HEATCOOL", "OFF"] == trait.available_modes 71 | assert "COOL" == trait.mode 72 | 73 | 74 | def test_thermostat_temperature_setpoint_traits( 75 | fake_device: Callable[[Dict[str, Any]], Device] 76 | ) -> None: 77 | device = fake_device( 78 | { 79 | "name": "my/device/name", 80 | "traits": { 81 | "sdm.devices.traits.ThermostatTemperatureSetpoint": { 82 | "heatCelsius": 20.0, 83 | "coolCelsius": 22.0, 84 | }, 85 | }, 86 | } 87 | ) 88 | assert "sdm.devices.traits.ThermostatTemperatureSetpoint" in device.traits 89 | trait = device.traits["sdm.devices.traits.ThermostatTemperatureSetpoint"] 90 | assert 20.0 == trait.heat_celsius 91 | assert 22.0 == trait.cool_celsius 92 | 93 | 94 | @pytest.mark.parametrize( 95 | "data", 96 | [ 97 | ({}), 98 | ({"heatCelsius": 20.0}), 99 | ({"coolCelsius": 22.0}), 100 | ({"heatCelsius": 20.0, "coolCelsius": 22.0}), 101 | ], 102 | ) 103 | def test_thermostat_temperature_setpoint_optional_fields( 104 | fake_device: Callable[[Dict[str, Any]], Device], data: dict[str, Any] 105 | ) -> None: 106 | device = fake_device( 107 | { 108 | "name": "my/device/name", 109 | 
"traits": {"sdm.devices.traits.ThermostatTemperatureSetpoint": data}, 110 | } 111 | ) 112 | assert "sdm.devices.traits.ThermostatTemperatureSetpoint" in device.traits 113 | assert device.thermostat_temperature_setpoint 114 | 115 | 116 | def test_thermostat_multiple_traits( 117 | fake_device: Callable[[Dict[str, Any]], Device] 118 | ) -> None: 119 | device = fake_device( 120 | { 121 | "name": "my/device/name", 122 | "traits": { 123 | "sdm.devices.traits.ThermostatEco": { 124 | "availableModes": ["MANUAL_ECHO", "OFF"], 125 | "mode": "MANUAL_ECHO", 126 | "heatCelsius": 21.0, 127 | "coolCelsius": 22.0, 128 | }, 129 | "sdm.devices.traits.ThermostatHvac": { 130 | "status": "HEATING", 131 | }, 132 | "sdm.devices.traits.ThermostatMode": { 133 | "availableModes": ["HEAT", "COOL", "HEATCOOL", "OFF"], 134 | "mode": "COOL", 135 | }, 136 | "sdm.devices.traits.ThermostatTemperatureSetpoint": { 137 | "heatCelsius": 23.0, 138 | "coolCelsius": 24.0, 139 | }, 140 | }, 141 | } 142 | ) 143 | assert "sdm.devices.traits.ThermostatEco" in device.traits 144 | assert "sdm.devices.traits.ThermostatHvac" in device.traits 145 | assert "sdm.devices.traits.ThermostatMode" in device.traits 146 | assert "sdm.devices.traits.ThermostatTemperatureSetpoint" in device.traits 147 | trait = device.traits["sdm.devices.traits.ThermostatEco"] 148 | assert ["MANUAL_ECHO", "OFF"] == trait.available_modes 149 | assert "MANUAL_ECHO" == trait.mode 150 | assert 21.0 == trait.heat_celsius 151 | assert 22.0 == trait.cool_celsius 152 | trait = device.traits["sdm.devices.traits.ThermostatHvac"] 153 | assert "HEATING" == trait.status 154 | trait = device.traits["sdm.devices.traits.ThermostatMode"] 155 | assert ["HEAT", "COOL", "HEATCOOL", "OFF"] == trait.available_modes 156 | assert "COOL" == trait.mode 157 | trait = device.traits["sdm.devices.traits.ThermostatTemperatureSetpoint"] 158 | assert 23.0 == trait.heat_celsius 159 | assert 24.0 == trait.cool_celsius 160 | 161 | 162 | @pytest.mark.parametrize( 163 | "data", 164 | [ 165 | ({}), 166 | ({"mode": "OFF"}), 167 | ], 168 | ) 169 | def test_thermostat_eco_optional_fields( 170 | fake_device: Callable[[Dict[str, Any]], Device], data: dict[str, Any] 171 | ) -> None: 172 | device = fake_device( 173 | { 174 | "name": "my/device/name", 175 | "traits": {"sdm.devices.traits.ThermostatEco": data}, 176 | } 177 | ) 178 | assert "sdm.devices.traits.ThermostatEco" in device.traits 179 | assert device.thermostat_eco 180 | assert device.thermostat_eco.mode == "OFF" 181 | 182 | 183 | async def test_fan_set_timer( 184 | app: aiohttp.web.Application, 185 | recorder: Recorder, 186 | device_handler: DeviceHandler, 187 | api_client: Callable[[], Awaitable[google_nest_api.GoogleNestAPI]], 188 | ) -> None: 189 | device_id = device_handler.add_device( 190 | traits={ 191 | "sdm.devices.traits.Fan": { 192 | "timerMode": "OFF", 193 | }, 194 | } 195 | ) 196 | post_handler = NewHandler(recorder, [{}]) 197 | app.router.add_post(f"/{device_id}:executeCommand", post_handler) 198 | 199 | api = await api_client() 200 | devices = await api.async_get_devices() 201 | assert len(devices) == 1 202 | device = devices[0] 203 | assert device_id == device.name 204 | trait = device.traits["sdm.devices.traits.Fan"] 205 | assert trait.timer_mode == "OFF" 206 | await trait.set_timer("ON", 3600) 207 | assert recorder.request == { 208 | "command": "sdm.devices.commands.Fan.SetTimer", 209 | "params": { 210 | "timerMode": "ON", 211 | "duration": "3600s", 212 | }, 213 | } 214 | 215 | 216 | async def test_thermostat_eco_set_mode( 217 | app: 
aiohttp.web.Application, 218 | recorder: Recorder, 219 | device_handler: DeviceHandler, 220 | api_client: Callable[[], Awaitable[google_nest_api.GoogleNestAPI]], 221 | ) -> None: 222 | device_id = device_handler.add_device( 223 | traits={ 224 | "sdm.devices.traits.ThermostatEco": { 225 | "availableModes": ["MANUAL_ECO", "OFF"], 226 | "mode": "MANUAL_ECO", 227 | "heatCelsius": 20.0, 228 | "coolCelsius": 22.0, 229 | }, 230 | } 231 | ) 232 | post_handler = NewHandler(recorder, [{}]) 233 | app.router.add_post(f"/{device_id}:executeCommand", post_handler) 234 | 235 | api = await api_client() 236 | devices = await api.async_get_devices() 237 | assert len(devices) == 1 238 | device = devices[0] 239 | assert device.name == device_id 240 | trait = device.traits["sdm.devices.traits.ThermostatEco"] 241 | assert trait.mode == "MANUAL_ECO" 242 | await trait.set_mode("OFF") 243 | assert recorder.request == { 244 | "command": "sdm.devices.commands.ThermostatEco.SetMode", 245 | "params": {"mode": "OFF"}, 246 | } 247 | 248 | 249 | async def test_thermostat_mode_set_mode( 250 | app: aiohttp.web.Application, 251 | recorder: Recorder, 252 | device_handler: DeviceHandler, 253 | api_client: Callable[[], Awaitable[google_nest_api.GoogleNestAPI]], 254 | ) -> None: 255 | device_id = device_handler.add_device( 256 | traits={ 257 | "sdm.devices.traits.ThermostatMode": { 258 | "availableModes": ["HEAT", "COOL", "HEATCOOL", "OFF"], 259 | "mode": "COOL", 260 | }, 261 | } 262 | ) 263 | post_handler = NewHandler(recorder, [{}]) 264 | app.router.add_post(f"/{device_id}:executeCommand", post_handler) 265 | 266 | api = await api_client() 267 | devices = await api.async_get_devices() 268 | assert len(devices) == 1 269 | device = devices[0] 270 | assert device.name == device_id 271 | trait = device.traits["sdm.devices.traits.ThermostatMode"] 272 | assert trait.mode == "COOL" 273 | await trait.set_mode("HEAT") 274 | assert recorder.request == { 275 | "command": "sdm.devices.commands.ThermostatMode.SetMode", 276 | "params": {"mode": "HEAT"}, 277 | } 278 | 279 | 280 | async def test_thermostat_temperature_set_point( 281 | app: aiohttp.web.Application, 282 | recorder: Recorder, 283 | device_handler: DeviceHandler, 284 | api_client: Callable[[], Awaitable[google_nest_api.GoogleNestAPI]], 285 | ) -> None: 286 | device_id = device_handler.add_device( 287 | traits={ 288 | "sdm.devices.traits.ThermostatTemperatureSetpoint": { 289 | "heatCelsius": 23.0, 290 | "coolCelsius": 24.0, 291 | }, 292 | } 293 | ) 294 | post_handler = NewHandler(recorder, [{}, {}, {}]) 295 | app.router.add_post(f"/{device_id}:executeCommand", post_handler) 296 | 297 | api = await api_client() 298 | devices = await api.async_get_devices() 299 | assert len(devices) == 1 300 | device = devices[0] 301 | assert device.name == device_id 302 | trait = device.traits["sdm.devices.traits.ThermostatTemperatureSetpoint"] 303 | assert trait.heat_celsius == 23.0 304 | assert trait.cool_celsius == 24.0 305 | await trait.set_heat(25.0) 306 | assert recorder.request == { 307 | "command": "sdm.devices.commands.ThermostatTemperatureSetpoint.SetHeat", 308 | "params": {"heatCelsius": 25.0}, 309 | } 310 | 311 | await trait.set_cool(26.0) 312 | assert recorder.request == { 313 | "command": "sdm.devices.commands.ThermostatTemperatureSetpoint.SetCool", 314 | "params": {"coolCelsius": 26.0}, 315 | } 316 | 317 | await trait.set_range(27.0, 28.0) 318 | assert recorder.request == { 319 | "command": "sdm.devices.commands.ThermostatTemperatureSetpoint.SetRange", 320 | "params": { 321 | 
"heatCelsius": 27.0, 322 | "coolCelsius": 28.0, 323 | }, 324 | } 325 | -------------------------------------------------------------------------------- /tests/test_transcoder.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from unittest.mock import Mock, patch 3 | 4 | import pytest 5 | 6 | from google_nest_sdm.exceptions import TranscodeException 7 | from google_nest_sdm.transcoder import Transcoder 8 | 9 | BINARY = "/bin/true" 10 | 11 | 12 | async def test_transcoder_file_not_exist(tmp_path: str) -> None: 13 | t = Transcoder(BINARY, path_prefix=tmp_path) 14 | with pytest.raises(TranscodeException): 15 | await t.transcode_clip("in_file.mp4", "out_file.gif") 16 | 17 | 18 | async def test_transcoder_output_already_exists(tmp_path: str) -> None: 19 | t = Transcoder(BINARY, path_prefix=tmp_path) 20 | with open(f"{tmp_path}/in_file.mp4", mode="w") as f: 21 | f.write("some-input") 22 | with open(f"{tmp_path}/out_file.gif", mode="w") as f: 23 | f.write("some-output") 24 | with pytest.raises(TranscodeException): 25 | await t.transcode_clip("in_file.mp4", "out_file.gif") 26 | 27 | 28 | async def test_transcoder(tmp_path: str) -> None: 29 | t = Transcoder(BINARY, path_prefix=tmp_path) 30 | with open(f"{tmp_path}/in_file.mp4", mode="w") as f: 31 | f.write("some-input") 32 | with patch( 33 | "google_nest_sdm.transcoder.asyncio.create_subprocess_shell" 34 | ) as mock_shell: 35 | process_mock = Mock() 36 | future: asyncio.Future = asyncio.Future() 37 | future.set_result(("", "")) 38 | process_mock.communicate.return_value = future 39 | process_mock.returncode = 0 40 | mock_shell.return_value = process_mock 41 | await t.transcode_clip("in_file.mp4", "out_file.gif") 42 | 43 | 44 | async def test_transcoder_failure(tmp_path: str) -> None: 45 | t = Transcoder("/bin/false", path_prefix=tmp_path) 46 | with open(f"{tmp_path}/in_file.mp4", mode="w") as f: 47 | f.write("some-input") 48 | with patch( 49 | "google_nest_sdm.transcoder.asyncio.create_subprocess_shell" 50 | ) as mock_shell, pytest.raises(TranscodeException): 51 | process_mock = Mock() 52 | future: asyncio.Future = asyncio.Future() 53 | future.set_result(("", "")) 54 | process_mock.communicate.return_value = future 55 | process_mock.returncode = 1 56 | mock_shell.return_value = process_mock 57 | await t.transcode_clip("in_file.mp4", "out_file.gif") 58 | -------------------------------------------------------------------------------- /tests/test_webrtc_util.py: -------------------------------------------------------------------------------- 1 | """Tests for WebRTC utility.""" 2 | 3 | from google_nest_sdm.webrtc_util import ( 4 | SDPDirection, 5 | SDPMediaKind, 6 | _add_foundation_to_candidates, 7 | _get_media_direction, 8 | _update_direction_in_answer, 9 | fix_mozilla_sdp_answer, 10 | ) 11 | 12 | 13 | def test_fix_mozilla_sdp_answer() -> None: 14 | """Test the fix in the SDP for Firefox.""" 15 | firefox_offer_sdp = ( 16 | "v=0\r\n" 17 | "o=mozilla...THIS_IS_SDPARTA-99.0 137092584186714854 0 IN IP4 0.0.0.0\r\n" 18 | "m=audio 9 UDP/TLS/RTP/SAVPF 109 9 0 8 101\r\n" 19 | "c=IN IP4 0.0.0.0\r\n" 20 | "a=recvonly\r\n" 21 | "m=video 9 UDP/TLS/RTP/SAVPF 120 124 121 125 126 127 97 98 123 122 119\r\n" 22 | "c=IN IP4 0.0.0.0\r\n" 23 | "a=recvonly\r\n" 24 | "m=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n" 25 | "c=IN IP4 0.0.0.0\r\n" 26 | "a=sendrecv\r\n" 27 | ) 28 | answer_sdp = ( 29 | "v=0\r\n" 30 | "o=- 0 2 IN IP4 127.0.0.1\r\n" 31 | "m=audio 19305 UDP/TLS/RTP/SAVPF 109\r\n" 32 | "c=IN IP4 
74.125.247.118\r\n" 33 | "a=rtcp:9 IN IP4 0.0.0.0\r\n" 34 | "a=candidate: 1 udp 2113939711 2001:4860:4864:4::118 19305 typ host generation 0\r\n" 35 | "a=candidate: 1 tcp 2113939710 2001:4860:4864:4::118 19305 typ host tcptype passive generation 0\r\n" 36 | "a=candidate: 1 ssltcp 2113939709 2001:4860:4864:4::118 443 typ host generation 0\r\n" 37 | "a=candidate: 1 udp 2113932031 74.125.247.118 19305 typ host generation 0\r\n" 38 | "a=candidate: 1 tcp 2113932030 74.125.247.118 19305 typ host tcptype passive generation 0\r\n" 39 | "a=candidate: 1 ssltcp 2113932029 74.125.247.118 443 typ host generation 0\r\n" 40 | "a=sendrecv\r\n" 41 | "m=video 9 UDP/TLS/RTP/SAVPF 126 127\r\n" 42 | "c=IN IP4 0.0.0.0\r\n" 43 | "a=rtcp:9 IN IP4 0.0.0.0\r\n" 44 | "a=sendrecv\r\n" 45 | "m=application 9 DTLS/SCTP 5000\r\n" 46 | "c=IN IP4 0.0.0.0\r\n" 47 | ) 48 | expected_answer_sdp = ( 49 | "v=0\r\n" 50 | "o=- 0 2 IN IP4 127.0.0.1\r\n" 51 | "m=audio 19305 UDP/TLS/RTP/SAVPF 109\r\n" 52 | "c=IN IP4 74.125.247.118\r\n" 53 | "a=rtcp:9 IN IP4 0.0.0.0\r\n" 54 | "a=candidate:1 1 udp 2113939711 2001:4860:4864:4::118 19305 typ host generation 0\r\n" 55 | "a=candidate:2 1 tcp 2113939710 2001:4860:4864:4::118 19305 typ host tcptype passive generation 0\r\n" 56 | "a=candidate:3 1 ssltcp 2113939709 2001:4860:4864:4::118 443 typ host generation 0\r\n" 57 | "a=candidate:4 1 udp 2113932031 74.125.247.118 19305 typ host generation 0\r\n" 58 | "a=candidate:5 1 tcp 2113932030 74.125.247.118 19305 typ host tcptype passive generation 0\r\n" 59 | "a=candidate:6 1 ssltcp 2113932029 74.125.247.118 443 typ host generation 0\r\n" 60 | "a=sendonly\r\n" 61 | "m=video 9 UDP/TLS/RTP/SAVPF 126 127\r\n" 62 | "c=IN IP4 0.0.0.0\r\n" 63 | "a=rtcp:9 IN IP4 0.0.0.0\r\n" 64 | "a=sendonly\r\n" 65 | "m=application 9 DTLS/SCTP 5000\r\n" 66 | "c=IN IP4 0.0.0.0\r\n" 67 | ) 68 | chrome_offer_sdp = ( 69 | "v=0\r\n" 70 | "o=- 6714414228100263102 2 IN IP4 127.0.0.1\r\n" 71 | "m=audio 9 UDP/TLS/RTP/SAVPF 111 63 9 0 8 13 110 126\r\n" 72 | "c=IN IP4 0.0.0.0\r\n" 73 | "a=recvonly\r\n" 74 | "m=video 9 UDP/TLS/RTP/SAVPF 96 97 98 99 100 101 35 36 37 38 102 103 104 105 106 107 108 109 127 125 39 40 41 42 43 44 45 46 47 48 112 113 114 115 116 117 118 49\r\n" 75 | "c=IN IP4 0.0.0.0\r\n" 76 | "a=recvonly\r\n" 77 | "m=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n" 78 | "c=IN IP4 0.0.0.0\r\n" 79 | ) 80 | fixed_sdp = fix_mozilla_sdp_answer(firefox_offer_sdp, answer_sdp) 81 | assert fixed_sdp == expected_answer_sdp 82 | 83 | fixed_sdp = fix_mozilla_sdp_answer(chrome_offer_sdp, answer_sdp) 84 | assert fixed_sdp == answer_sdp 85 | 86 | 87 | def test_get_media_direction() -> None: 88 | """Test getting the direction in the SDP.""" 89 | sdp = ( 90 | "v=0\r\n" 91 | "o=- 123456 654321 IN IP4 127.0.0.1\r\n" 92 | "s=Test\r\n" 93 | "c=IN IP4 127.0.0.1\r\n" 94 | "t=0 0\r\n" 95 | "m=audio 49170 RTP/AVP 0\r\n" 96 | "a=rtpmap:0 PCMU/8000\r\n" 97 | "a=sendrecv\r\n" 98 | "m=video 51372 RTP/AVP 96\r\n" 99 | "a=rtpmap:96 H264/90000\r\n" 100 | "a=sendonly\r\n" 101 | ) 102 | 103 | direction = _get_media_direction(sdp, SDPMediaKind.AUDIO) 104 | assert direction == SDPDirection.SENDRECV 105 | direction = _get_media_direction(sdp, SDPMediaKind.VIDEO) 106 | assert direction == SDPDirection.SENDONLY 107 | direction = _get_media_direction(sdp, SDPMediaKind.APPLICATION) 108 | assert direction is None 109 | 110 | 111 | def test_update_direction_in_answer() -> None: 112 | """Test updating the direction in the SDP answer.""" 113 | original_sdp = ( 114 | "v=0\r\n" 115 | "o=- 123456 654321 IN IP4 
127.0.0.1\r\n" 116 | "s=Test\r\n" 117 | "c=IN IP4 127.0.0.1\r\n" 118 | "t=0 0\r\n" 119 | "m=audio 49170 RTP/AVP 0\r\n" 120 | "a=rtpmap:0 PCMU/8000\r\n" 121 | "a=sendrecv\r\n" # Existing direction 122 | "m=video 51372 RTP/AVP 96\r\n" 123 | "a=rtpmap:96 H264/90000\r\n" 124 | "a=sendrecv\r\n" 125 | ) 126 | 127 | # Expected result after changing the audio direction 128 | expected_sdp = ( 129 | "v=0\r\n" 130 | "o=- 123456 654321 IN IP4 127.0.0.1\r\n" 131 | "s=Test\r\n" 132 | "c=IN IP4 127.0.0.1\r\n" 133 | "t=0 0\r\n" 134 | "m=audio 49170 RTP/AVP 0\r\n" 135 | "a=rtpmap:0 PCMU/8000\r\n" 136 | "a=sendonly\r\n" # Updated direction 137 | "m=video 51372 RTP/AVP 96\r\n" 138 | "a=rtpmap:96 H264/90000\r\n" 139 | "a=sendrecv\r\n" 140 | ) 141 | 142 | new_sdp = _update_direction_in_answer( 143 | original_sdp, SDPMediaKind.AUDIO, SDPDirection.SENDRECV, SDPDirection.SENDONLY 144 | ) 145 | 146 | assert new_sdp == expected_sdp 147 | 148 | 149 | def test_add_foundation_to_candidates() -> None: 150 | """Test adding a foundation value to ICE candidates.""" 151 | original_sdp = ( 152 | "v=0\r\n" 153 | "o=- 123456 654321 IN IP4 127.0.0.1\r\n" 154 | "s=Test\r\n" 155 | "c=IN IP4 127.0.0.1\r\n" 156 | "t=0 0\r\n" 157 | "a=candidate: 1 UDP 2122260223 192.168.0.1 49170 typ host\r\n" 158 | "a=candidate: 1 UDP 2122260223 192.168.0.1 51372 typ host\r\n" 159 | ) 160 | 161 | expected_sdp = ( 162 | "v=0\r\n" 163 | "o=- 123456 654321 IN IP4 127.0.0.1\r\n" 164 | "s=Test\r\n" 165 | "c=IN IP4 127.0.0.1\r\n" 166 | "t=0 0\r\n" 167 | "a=candidate:1 1 UDP 2122260223 192.168.0.1 49170 typ host\r\n" 168 | "a=candidate:2 1 UDP 2122260223 192.168.0.1 51372 typ host\r\n" 169 | ) 170 | 171 | new_sdp = _add_foundation_to_candidates(original_sdp) 172 | 173 | assert new_sdp == expected_sdp 174 | --------------------------------------------------------------------------------