├── .devcontainer ├── devcontainer.json └── podman │ └── devcontainer.json ├── .github ├── dependabot.yml └── workflows │ ├── automatic_generation.yml │ ├── generate_package.yml │ ├── make_release.yml │ ├── publish.yml │ └── pytest.yml ├── .gitignore ├── .gitpod.Dockerfile ├── .gitpod.yml ├── CHANGELOG.md ├── LICENSE ├── LICENSE_HA_CORE.md ├── README.md ├── custom_components ├── __init__.py └── simple_integration │ ├── __init__.py │ ├── config_flow.py │ ├── const.py │ ├── diagnostics.py │ ├── manifest.json │ ├── sensor.py │ ├── strings.json │ └── translations │ └── en.json ├── generate_phacc ├── __init__.py ├── const.py ├── generate_phacc.py └── ha.py ├── ha_version ├── requirements_dev.txt ├── requirements_generate.txt ├── requirements_test.txt ├── setup.cfg ├── setup.py ├── src └── pytest_homeassistant_custom_component │ ├── __init__.py │ ├── asyncio_legacy.py │ ├── common.py │ ├── components │ ├── __init__.py │ ├── diagnostics │ │ └── __init__.py │ └── recorder │ │ ├── __init__.py │ │ ├── common.py │ │ └── db_schema_0.py │ ├── const.py │ ├── ignore_uncaught_exceptions.py │ ├── patch_json.py │ ├── patch_recorder.py │ ├── patch_time.py │ ├── plugins.py │ ├── syrupy.py │ ├── test_util │ ├── __init__.py │ └── aiohttp.py │ ├── testing_config │ ├── __init__.py │ └── custom_components │ │ ├── __init__.py │ │ └── test_constant_deprecation │ │ └── __init__.py │ └── typing.py ├── tests ├── __init__.py ├── conftest.py ├── fixtures │ ├── test_array.json │ └── test_data.json ├── snapshots │ └── test_diagnostics.ambr ├── test_common.py ├── test_config_flow.py ├── test_diagnostics.py └── test_sensor.py └── version /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "image": "mcr.microsoft.com/devcontainers/python:3.13", 3 | "postCreateCommand": "pip3 install -r requirements_generate.txt" 4 | } 5 | -------------------------------------------------------------------------------- 
/.devcontainer/podman/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "image": "mcr.microsoft.com/devcontainers/python:3.13", 3 | "postCreateCommand": "pip install --upgrade pip && pip3 install -r requirements_generate.txt", 4 | // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 5 | "runArgs": [ 6 | "--userns=keep-id" 7 | ], 8 | "containerUser": "vscode", 9 | "updateRemoteUserUID": true, 10 | "containerEnv": { 11 | "HOME": "/home/vscode" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Set update schedule for GitHub Actions 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "monthly" 8 | -------------------------------------------------------------------------------- /.github/workflows/automatic_generation.yml: -------------------------------------------------------------------------------- 1 | name: Automatic Generate 2 | 3 | on: 4 | schedule: 5 | - cron: "0 5 * * *" 6 | workflow_dispatch: 7 | 8 | jobs: 9 | generate_package: 10 | runs-on: "ubuntu-latest" 11 | outputs: 12 | current_ha_version: ${{ steps.current-ha-version.outputs.current-ha-version }} 13 | new_ha_version: ${{ steps.new-ha-version.outputs.new-ha-version }} 14 | need_to_release: ${{ steps.need-to-release.outputs.need-to-release }} 15 | steps: 16 | - name: checkout repo content 17 | uses: actions/checkout@v4 18 | - name: store current ha version 19 | id: current-ha-version 20 | run: echo "::set-output name=current-ha-version::$(cat ha_version)" 21 | - name: setup python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: '3.13' 25 | - name: install dependencies 26 | run: pip install -r requirements_generate.txt 27 | - name: Install phacc for current 
versions 28 | run: pip install -e . 29 | - name: execute generate package 30 | run: | 31 | export PYTHONPATH=$PYTHONPATH:$(pwd) 32 | python generate_phacc/generate_phacc.py 33 | - name: store new ha version 34 | id: new-ha-version 35 | run: echo "::set-output name=new-ha-version::$(cat ha_version)" 36 | - name: check need to release 37 | id: need-to-release 38 | run: | 39 | if [[ "${{ steps.current-ha-version.outputs.current-ha-version}}" == "${{ steps.new-ha-version.outputs.new-ha-version }}" ]]; then 40 | echo "::set-output name=need-to-release::false" 41 | else 42 | echo "::set-output name=need-to-release::true" 43 | fi 44 | - name: list files 45 | run: ls -a 46 | - name: publish artifact 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: generated-package 50 | path: | 51 | ./ 52 | !**/*.pyc 53 | !tmp_dir/ 54 | !.git/ 55 | if-no-files-found: error 56 | test: 57 | needs: generate_package 58 | runs-on: "ubuntu-latest" 59 | if: needs.generate_package.outputs.need_to_release == 'true' 60 | strategy: 61 | matrix: 62 | python-version: ['3.13'] 63 | steps: 64 | - name: checkout repo content 65 | uses: actions/checkout@v4 66 | - name: download artifact 67 | uses: actions/download-artifact@v4 68 | with: 69 | name: generated-package 70 | - name: Set up Python ${{ matrix.python-version }} 71 | uses: actions/setup-python@v5 72 | with: 73 | python-version: ${{ matrix.python-version }} 74 | - name: Install dependencies 75 | run: | 76 | python -m pip install --upgrade pip 77 | pip install -e . 
78 | - name: Test with pytest 79 | run: | 80 | pytest 81 | make_release: 82 | needs: [generate_package, test] 83 | runs-on: "ubuntu-latest" 84 | if: needs.generate_package.outputs.need_to_release == 'true' 85 | steps: 86 | - uses: actions/checkout@v4 87 | - name: download artifact 88 | uses: actions/download-artifact@v4 89 | with: 90 | name: generated-package 91 | - name: need_to_release_print 92 | run: "echo ${{ needs.generate_package.outputs.need_to_release }}" 93 | - id: next_version 94 | uses: zwaldowski/semver-release-action@v4 95 | with: 96 | dry_run: true 97 | bump: patch 98 | github_token: ${{ secrets.REPO_SCOPED_TOKEN }} 99 | - run: echo "${{ steps.next_version.outputs.version }}" > version 100 | - run: echo "${{ steps.next_version.outputs.version }}" 101 | - id: git_commit 102 | run: | 103 | git config user.name 'Matthew Flamm' 104 | git config user.email 'MatthewFlamm@users.noreply.github.com' 105 | git add . 106 | git commit -m "Bump version" 107 | git push 108 | echo "::set-output name=sha::$(git rev-parse HEAD)" 109 | - uses: zwaldowski/semver-release-action@v4 110 | with: 111 | github_token: ${{ secrets.REPO_SCOPED_TOKEN }} 112 | sha: ${{ steps.git_commit.outputs.sha }} 113 | - name: Create Release 114 | id: create_release 115 | uses: actions/create-release@v1 116 | env: 117 | GITHUB_TOKEN: ${{ secrets.REPO_SCOPED_TOKEN }} 118 | with: 119 | tag_name: ${{ steps.next_version.outputs.version }} 120 | release_name: Release ${{ steps.next_version.outputs.version }} 121 | body: | 122 | Automatic release 123 | homeassistant version: ${{ needs.generate_package.outputs.new_ha_version }} 124 | draft: false 125 | prerelease: false 126 | - name: Set up Python 127 | uses: actions/setup-python@v5 128 | with: 129 | python-version: '3.13' 130 | - name: Install dependencies 131 | run: | 132 | python -m pip install --upgrade pip 133 | pip install setuptools wheel twine 134 | - name: Build 135 | run: | 136 | python setup.py sdist bdist_wheel 137 | - name: Publish 
distribution 📦 to PyPI 138 | uses: pypa/gh-action-pypi-publish@release/v1 139 | with: 140 | password: ${{ secrets.PYPI_TOKEN }} 141 | -------------------------------------------------------------------------------- /.github/workflows/generate_package.yml: -------------------------------------------------------------------------------- 1 | name: Generate Package 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | generate_package: 8 | runs-on: "ubuntu-latest" 9 | steps: 10 | - uses: "actions/checkout@v4" 11 | - name: checkout repo content 12 | uses: actions/checkout@v4 13 | - name: setup python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.13' 17 | - name: install dependencies 18 | run: pip install -r requirements_generate.txt 19 | - name: Install phacc for current versions 20 | run: pip install -e . 21 | - name: execute generate package 22 | run: | 23 | export PYTHONPATH=$PYTHONPATH:$(pwd) 24 | python generate_phacc/generate_phacc.py --regen 25 | - name: Create Pull Request 26 | uses: peter-evans/create-pull-request@v7 27 | with: 28 | token: ${{ secrets.REPO_SCOPED_TOKEN }} 29 | -------------------------------------------------------------------------------- /.github/workflows/make_release.yml: -------------------------------------------------------------------------------- 1 | name: Make Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | type: 7 | description: 'Type of version to increment. major, minor, or patch.' 
8 | required: true 9 | default: 'patch' 10 | jobs: 11 | make_release: 12 | runs-on: "ubuntu-latest" 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: store current ha version 16 | id: current-ha-version 17 | run: echo "::set-output name=current-ha-version::$(cat ha_version)" 18 | - id: next_version 19 | uses: zwaldowski/semver-release-action@v4 20 | with: 21 | dry_run: true 22 | bump: ${{ github.event.inputs.type }} 23 | github_token: ${{ secrets.REPO_SCOPED_TOKEN }} 24 | - run: echo "${{ steps.next_version.outputs.version }}" > version 25 | - run: echo "${{ steps.next_version.outputs.version }}" 26 | - id: git_commit 27 | run: | 28 | git config user.name 'Matthew Flamm' 29 | git config user.email 'MatthewFlamm@users.noreply.github.com' 30 | git add . 31 | git commit -m "Bump version" 32 | git push 33 | echo "::set-output name=sha::$(git rev-parse HEAD)" 34 | - uses: zwaldowski/semver-release-action@v4 35 | with: 36 | github_token: ${{ secrets.REPO_SCOPED_TOKEN }} 37 | sha: ${{ steps.git_commit.outputs.sha }} 38 | - name: Create Release 39 | id: create_release 40 | uses: actions/create-release@v1 41 | env: 42 | GITHUB_TOKEN: ${{ secrets.REPO_SCOPED_TOKEN }} 43 | with: 44 | tag_name: ${{ steps.next_version.outputs.version }} 45 | release_name: Release ${{ steps.next_version.outputs.version }} 46 | body: | 47 | Automatic release 48 | homeassistant version: ${{ steps.current-ha-version.outputs.current-ha-version }} 49 | draft: false 50 | prerelease: false 51 | - name: Set up Python 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: '3.13' 55 | - name: Install dependencies 56 | run: | 57 | python -m pip install --upgrade pip 58 | pip install setuptools wheel twine 59 | - name: Build and publish 60 | env: 61 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 62 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 63 | run: | 64 | python setup.py sdist bdist_wheel 65 | twine upload dist/* 66 | 
-------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: 5 | jobs: 6 | make_release: 7 | runs-on: "ubuntu-latest" 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: '3.13' 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install setuptools wheel twine 18 | - name: Build 19 | run: | 20 | python setup.py sdist bdist_wheel 21 | - name: Publish distribution 📦 to PyPI 22 | uses: pypa/gh-action-pypi-publish@release/v1 23 | with: 24 | password: ${{ secrets.PYPI_TOKEN }} 25 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Pytest 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | schedule: 13 | - cron: "0 5 * * *" 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | python-version: ["3.13"] 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies generate 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements_generate.txt 34 | - name: Install phacc for current versions 35 | run: pip install -e . 
36 | - name: execute generate package 37 | run: | 38 | export PYTHONPATH=$PYTHONPATH:$(pwd) 39 | python generate_phacc/generate_phacc.py --regen 40 | - name: list files 41 | run: ls -a 42 | - name: publish artifact 43 | uses: actions/upload-artifact@v4 44 | with: 45 | name: generated-package-${{ matrix.python-version }} 46 | path: | 47 | ./ 48 | !**/*.pyc 49 | !tmp_dir/ 50 | !.git/ 51 | if-no-files-found: error 52 | - name: Install dependencies test 53 | run: | 54 | pip install -e . 55 | - name: Test with pytest 56 | run: | 57 | pytest 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # this packages temporary folder 132 | tmp_dir/ -------------------------------------------------------------------------------- /.gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-python-3.11 2 | 3 | USER gitpod 4 | COPY requirements_generate.txt requirements_generate.txt 5 | RUN pip install -r requirements_generate.txt 6 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | # This configuration file was automatically generated by Gitpod. 2 | # Please adjust to your needs (see https://www.gitpod.io/docs/config-gitpod-file) 3 | # and commit this file to your remote git repository to share the goodness with others. 
4 | image: 5 |   file: .gitpod.Dockerfile 6 | tasks: 7 |   - before: printf 'export PATH="%s:$PATH"\n' "/workspace/pytest-homeassistant-custom-component" >> $HOME/.bashrc && exit 8 | github: 9 |   prebuilds: 10 |     master: true 11 |     branches: false 12 |     pullRequests: false 13 |     pullRequestsFromForks: false 14 |     addCheck: false 15 |     addComment: true 16 |     addBadge: false 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | This changelog only includes changes directly related to the structure of this project. Changes in testing behavior may still occur from changes in homeassistant/core. 3 | 4 | Changes to minor version indicate a change structurally in this package. Changes in patch indicate changes solely from homeassistant/core. The latter does not imply no breaking changes are introduced. 5 | 6 | ## 0.13.0 7 | * bump minimum Python version to Python 3.10 8 | 9 | ## 0.8.0 10 | * recorder dependencies required for tests 11 | 12 | ## 0.7.0 13 | * paho-mqtt now required for tests 14 | 15 | ## 0.6.0 16 | * Python 3.8 dropped with homeassistant requirement 17 | * Minor change to generation of package for new homeassistant code 18 | 19 | ## 0.4.0 20 | * `enable_custom_integrations` now required by ha 21 | * sqlalchemy version now pinned to ha version 22 | 23 | ## 0.3.0 24 | * Generate package only on homeassistant release versions 25 | * Use latest homeassistant release version including beta 26 | * homeassistant/core tags are used to determine latest release 27 | * Pin homeassistant version in requirements 28 | 29 | ## 0.2.0 30 | * fix `load_fixture` 31 | 32 | ## 0.1.0 33 | * remove Python 3.7 and add Python 3.9 34 | * remove `async_test` 35 | * move non-testing dependencies to separate `requirements_dev.txt` 36 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Matthew Flamm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE_HA_CORE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytest-homeassistant-custom-component 2 | 3 | ![HA core version](https://img.shields.io/static/v1?label=HA+core+version&message=2025.6.0b5&labelColor=blue) 4 | 5 | [![Open in Gitpod](https://gitpod.io/button/open-in-gitpod.svg)](https://gitpod.io/#https://github.com/MatthewFlamm/pytest-homeassistant-custom-component) 6 | 7 | Package to automatically extract testing plugins from Home Assistant for custom component testing. 8 | The goal is to provide the same functionality as the tests in home-assistant/core. 9 | pytest-homeassistant-custom-component is updated daily according to the latest homeassistant release including beta. 
10 | 11 | ## Usage: 12 | * All pytest fixtures can be used as normal, like `hass` 13 | * For helpers: 14 | * home-assistant/core native test: `from tests.common import MockConfigEntry` 15 | * custom component test: `from pytest_homeassistant_custom_component.common import MockConfigEntry` 16 | * If your integration is inside a `custom_components` folder, a `custom_components/__init__.py` file or changes to `sys.path` may be required. 17 | * `enable_custom_integrations` fixture is required (versions >=2021.6.0b0) 18 | * Some fixtures, e.g. `recorder_mock`, need to be initialized before `enable_custom_integrations`. See https://github.com/MatthewFlamm/pytest-homeassistant-custom-component/issues/132. 19 | * pytest-asyncio might now require `asyncio_mode = auto` config, see #129. 20 | * If using `load_fixture`, the files need to be in a `fixtures` folder colocated with the tests. For example, a test in `test_sensor.py` can load data from `some_data.json` using `load_fixture` from this structure: 21 | 22 | ``` 23 | tests/ 24 | fixtures/ 25 | some_data.json 26 | test_sensor.py 27 | ``` 28 | 29 | * When using syrupy snapshots, add a `snapshot` fixture to conftest.py to make sure the snapshots are loaded from snapshot folder colocated with the tests. 30 | 31 | ```py 32 | from pytest_homeassistant_custom_component.syrupy import HomeAssistantSnapshotExtension 33 | from syrupy.assertion import SnapshotAssertion 34 | 35 | 36 | @pytest.fixture 37 | def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: 38 | """Return snapshot assertion fixture with the Home Assistant extension.""" 39 | return snapshot.use_extension(HomeAssistantSnapshotExtension) 40 | ``` 41 | 42 | ## Examples: 43 | * See [list of custom components](https://github.com/MatthewFlamm/pytest-homeassistant-custom-component/network/dependents) as examples that use this package. 44 | * Also see tests for `simple_integration` in this repository. 
45 | * Use [cookiecutter-homeassistant-custom-component](https://github.com/oncleben31/cookiecutter-homeassistant-custom-component) to create a custom component with tests by using [cookiecutter](https://github.com/cookiecutter/cookiecutter). 46 | * The [github-custom-component-tutorial](https://github.com/boralyl/github-custom-component-tutorial) explains in detail how to create a custom component with a test suite using this package. 47 | 48 | ## More Info 49 | This repository is set up to be nearly fully automatic. 50 | 51 | * Version of home-assistant/core is given in `ha_version`, `pytest_homeassistant_custom_component.const`, and in the README above. 52 | * This package is generated against published releases of homeassistant and updated daily. 53 | * PRs should not include changes to the `pytest_homeassistant_custom_component` files. CI testing will automatically generate the new files. 54 | 55 | ### Version Strategy 56 | * When changes in extraction are required, there will be a change in the minor version. 57 | * A change in the patch version indicates that it was an automatic update with a homeassistant version. 58 | * This enables tracking back to which versions of pytest-homeassistant-custom-component can be used for 59 | extracting testing utilities from which version of homeassistant. 60 | 61 | This package was inspired by [pytest-homeassistant](https://github.com/boralyl/pytest-homeassistant) by @boralyl, but is intended to more closely and automatically track the home-assistant/core library.
62 | -------------------------------------------------------------------------------- /custom_components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatthewFlamm/pytest-homeassistant-custom-component/6642e5a77ac1b7156eef094387e65fed39fce116/custom_components/__init__.py -------------------------------------------------------------------------------- /custom_components/simple_integration/__init__.py: -------------------------------------------------------------------------------- 1 | """The Simple Integration integration.""" 2 | 3 | from homeassistant.config_entries import ConfigEntry 4 | from homeassistant.core import HomeAssistant 5 | 6 | 7 | PLATFORMS = ["sensor"] 8 | 9 | 10 | async def async_setup(hass: HomeAssistant, config: dict): 11 | """Set up the Simple Integration component.""" 12 | return True 13 | 14 | 15 | async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): 16 | """Set up Simple Integration from a config entry.""" 17 | 18 | await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) 19 | 20 | return True 21 | 22 | 23 | async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): 24 | """Unload a config entry.""" 25 | 26 | unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) 27 | 28 | return unload_ok 29 | -------------------------------------------------------------------------------- /custom_components/simple_integration/config_flow.py: -------------------------------------------------------------------------------- 1 | """Config flow for Simple Integration integration.""" 2 | import logging 3 | 4 | import voluptuous as vol 5 | 6 | from homeassistant import config_entries, core, exceptions 7 | 8 | from .const import DOMAIN # pylint:disable=unused-import 9 | 10 | _LOGGER = logging.getLogger(__name__) 11 | 12 | DATA_SCHEMA = vol.Schema({"name": str,}) 13 | 14 | 15 | class 
ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): 16 | """Handle a config flow for Simple Integration.""" 17 | 18 | VERSION = 1 19 | 20 | CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL 21 | 22 | async def async_step_user(self, user_input=None): 23 | """Handle the initial step.""" 24 | errors = {} 25 | if user_input is not None: 26 | try: 27 | return self.async_create_entry(title=user_input["name"], data=user_input) 28 | except Exception: # pylint: disable=broad-except 29 | _LOGGER.exception("Unexpected exception") 30 | errors["base"] = "unknown" 31 | 32 | return self.async_show_form( 33 | step_id="user", data_schema=DATA_SCHEMA, errors=errors 34 | ) 35 | -------------------------------------------------------------------------------- /custom_components/simple_integration/const.py: -------------------------------------------------------------------------------- 1 | """Constants for the Simple Integration integration.""" 2 | 3 | DOMAIN = "simple_integration" 4 | -------------------------------------------------------------------------------- /custom_components/simple_integration/diagnostics.py: -------------------------------------------------------------------------------- 1 | """diagnostics for Simple Integration integration.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | from homeassistant.components.diagnostics import async_redact_data 8 | from homeassistant.config_entries import ConfigEntry 9 | 10 | from homeassistant.core import HomeAssistant 11 | 12 | TO_REDACT = {} 13 | 14 | 15 | async def async_get_config_entry_diagnostics( 16 | hass: HomeAssistant, entry: ConfigEntry 17 | ) -> dict[str, Any]: 18 | """Return diagnostics for a config entry.""" 19 | 20 | diagnostic_data: dict[str, Any] = { 21 | "config_entry": async_redact_data(entry.as_dict(), TO_REDACT), 22 | } 23 | 24 | return diagnostic_data 25 | -------------------------------------------------------------------------------- 
/custom_components/simple_integration/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "domain": "simple_integration", 3 | "name": "Simple Integration", 4 | "config_flow": true, 5 | "documentation": "NoWhere", 6 | "version": "0.0.0", 7 | "codeowners": [ 8 | "@NoOne" 9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /custom_components/simple_integration/sensor.py: -------------------------------------------------------------------------------- 1 | """Platform for sensor integration.""" 2 | from homeassistant.const import UnitOfTemperature 3 | from homeassistant.helpers.entity import Entity 4 | 5 | 6 | async def async_setup_entry(hass, config_entry, async_add_devices): 7 | """Set up entry.""" 8 | async_add_devices([ExampleSensor(),]) 9 | 10 | 11 | class ExampleSensor(Entity): 12 | """Representation of a Sensor.""" 13 | 14 | def __init__(self): 15 | """Initialize the sensor.""" 16 | self._state = 23 17 | 18 | @property 19 | def should_poll(self): 20 | """Whether entity polls.""" 21 | return False 22 | 23 | @property 24 | def name(self): 25 | """Return the name of the sensor.""" 26 | return 'Example Temperature' 27 | 28 | @property 29 | def state(self): 30 | """Return the state of the sensor.""" 31 | return self._state 32 | 33 | @property 34 | def unit_of_measurement(self): 35 | """Return the unit of measurement.""" 36 | return UnitOfTemperature.CELSIUS 37 | -------------------------------------------------------------------------------- /custom_components/simple_integration/strings.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Simple Integration", 3 | "config": { 4 | "step": { 5 | "user": { 6 | "data": { 7 | "name": "Name" 8 | } 9 | } 10 | }, 11 | "error": { 12 | "unknown": "[%key:common::config_flow::error::unknown%]" 13 | }, 14 | "abort": { 15 | "already_configured": 
"[%key:common::config_flow::abort::already_configured_device%]" 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /custom_components/simple_integration/translations/en.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Simple Integration", 3 | "config": { 4 | "step": { 5 | "user": { 6 | "data": { 7 | "name": "Name" 8 | } 9 | } 10 | }, 11 | "error": { 12 | "unknown": "[%key:common::config_flow::error::unknown%]" 13 | }, 14 | "abort": { 15 | "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /generate_phacc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MatthewFlamm/pytest-homeassistant-custom-component/6642e5a77ac1b7156eef094387e65fed39fce116/generate_phacc/__init__.py -------------------------------------------------------------------------------- /generate_phacc/const.py: -------------------------------------------------------------------------------- 1 | """Constants for phacc.""" 2 | TMP_DIR = "tmp_dir" 3 | PACKAGE_DIR = "src/pytest_homeassistant_custom_component" 4 | REQUIREMENTS_FILE = "requirements_test.txt" 5 | CONST_FILE = "const.py" 6 | 7 | REQUIREMENTS_FILE_DEV = "requirements_dev.txt" 8 | 9 | path = "." 
10 | clone = "git clone https://github.com/home-assistant/core.git tmp_dir" 11 | diff = "git diff --exit-code" 12 | 13 | files = [ 14 | "__init__.py", 15 | "common.py", 16 | "conftest.py", 17 | "ignore_uncaught_exceptions.py", 18 | "components/recorder/common.py", 19 | "patch_time.py", 20 | "syrupy.py", 21 | "typing.py", 22 | "patch_json.py", 23 | "patch_recorder.py", 24 | ] 25 | 26 | # remove requirements for development only, i.e not related to homeassistant tests 27 | requirements_remove = [ 28 | "codecov", 29 | "mypy", 30 | "mypy-dev", 31 | "pre-commit", 32 | "pylint", 33 | "astroid", 34 | ] 35 | 36 | LICENSE_FILE_HA = "LICENSE.md" 37 | LICENSE_FILE_NEW = "LICENSE_HA_CORE.md" 38 | 39 | HA_VERSION_FILE = "ha_version" 40 | -------------------------------------------------------------------------------- /generate_phacc/generate_phacc.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import pathlib 4 | import re 5 | import shutil 6 | import os 7 | 8 | import click 9 | import git 10 | 11 | from ha import prepare_homeassistant 12 | from const import ( 13 | TMP_DIR, 14 | PACKAGE_DIR, 15 | REQUIREMENTS_FILE, 16 | CONST_FILE, 17 | REQUIREMENTS_FILE_DEV, 18 | LICENSE_FILE_HA, 19 | LICENSE_FILE_NEW, 20 | path, 21 | files, 22 | requirements_remove, 23 | HA_VERSION_FILE, 24 | ) 25 | 26 | @click.command 27 | @click.option("--regen/--no-regen", default=False, help="Whether to regenerate despite version") 28 | def cli(regen): 29 | if os.path.isdir(PACKAGE_DIR): 30 | shutil.rmtree(PACKAGE_DIR) 31 | if os.path.isfile(REQUIREMENTS_FILE): 32 | os.remove(REQUIREMENTS_FILE) 33 | 34 | ha_version = prepare_homeassistant() 35 | 36 | with open(HA_VERSION_FILE, "r") as f: 37 | current_version = f.read() 38 | print(f"Current Version: {current_version}") 39 | 40 | 41 | def process_files(): 42 | os.mkdir(PACKAGE_DIR) 43 | os.mkdir(os.path.join(PACKAGE_DIR, "test_util")) 44 | os.makedirs(os.path.join(PACKAGE_DIR, 
"components", "recorder")) 45 | os.makedirs(os.path.join(PACKAGE_DIR, "components", "diagnostics")) 46 | os.makedirs(os.path.join(PACKAGE_DIR, "testing_config", "custom_components", "test_constant_deprecation")) 47 | shutil.copy2(os.path.join(TMP_DIR, REQUIREMENTS_FILE), REQUIREMENTS_FILE) 48 | shutil.copy2( 49 | os.path.join(TMP_DIR, "homeassistant", CONST_FILE), 50 | os.path.join(PACKAGE_DIR, CONST_FILE), 51 | ) 52 | shutil.copy2( 53 | os.path.join(TMP_DIR, "tests", "test_util", "aiohttp.py"), 54 | os.path.join(PACKAGE_DIR, "test_util", "aiohttp.py"), 55 | ) 56 | shutil.copy2( 57 | os.path.join(TMP_DIR, "tests", "test_util", "__init__.py"), 58 | os.path.join(PACKAGE_DIR, "test_util", "__init__.py"), 59 | ) 60 | shutil.copy2( 61 | os.path.join(TMP_DIR, "tests", "components", "recorder", "common.py"), 62 | os.path.join(PACKAGE_DIR, "components", "recorder", "common.py"), 63 | ) 64 | shutil.copy2( 65 | os.path.join(TMP_DIR, "tests", "components", "recorder", "db_schema_0.py"), 66 | os.path.join(PACKAGE_DIR, "components", "recorder", "db_schema_0.py"), 67 | ) 68 | shutil.copy2( 69 | os.path.join(TMP_DIR, "tests", "components", "recorder", "__init__.py"), 70 | os.path.join(PACKAGE_DIR, "components", "recorder", "__init__.py"), 71 | ) 72 | shutil.copy2( 73 | os.path.join(TMP_DIR, "tests", "components", "diagnostics", "__init__.py"), 74 | os.path.join(PACKAGE_DIR, "components", "diagnostics", "__init__.py"), 75 | ) 76 | shutil.copy2( 77 | os.path.join(TMP_DIR, "tests", "components", "__init__.py"), 78 | os.path.join(PACKAGE_DIR, "components", "__init__.py"), 79 | ) 80 | shutil.copy2( 81 | os.path.join(TMP_DIR, "tests", "testing_config", "__init__.py"), 82 | os.path.join(PACKAGE_DIR, "testing_config", "__init__.py"), 83 | ) 84 | shutil.copy2( 85 | os.path.join(TMP_DIR, "tests", "testing_config", "custom_components", "__init__.py"), 86 | os.path.join(PACKAGE_DIR, "testing_config", "custom_components", "__init__.py"), 87 | ) 88 | shutil.copy2( 89 | os.path.join(TMP_DIR, 
"tests", "testing_config", "custom_components", "test_constant_deprecation", "__init__.py"), 90 | os.path.join(PACKAGE_DIR, "testing_config", "custom_components", "test_constant_deprecation", "__init__.py"), 91 | ) 92 | shutil.copy2( 93 | os.path.join(TMP_DIR, LICENSE_FILE_HA), 94 | LICENSE_FILE_NEW, 95 | ) 96 | 97 | for f in files: 98 | shutil.copy2(os.path.join(TMP_DIR, "tests", f), os.path.join(PACKAGE_DIR, f)) 99 | 100 | filename = os.path.join(PACKAGE_DIR, f) 101 | 102 | with open(filename, "r") as file: 103 | filedata = file.read() 104 | 105 | filedata = filedata.replace( 106 | "tests.", "." * (f.count("/") + 1) 107 | ) # Add dots depending on depth 108 | 109 | with open(filename, "w") as file: 110 | file.write(filedata) 111 | 112 | os.rename( 113 | os.path.join(PACKAGE_DIR, "conftest.py"), 114 | os.path.join(PACKAGE_DIR, "plugins.py"), 115 | ) 116 | 117 | with open(os.path.join(PACKAGE_DIR, CONST_FILE), "r") as original_file: 118 | data = original_file.readlines() 119 | new_data = [d for d in data[:100] if "version" in d.lower() or "from typing" in d] 120 | new_data.insert(0, data[0]) 121 | 122 | with open(os.path.join(PACKAGE_DIR, CONST_FILE), "w") as new_file: 123 | new_file.write("".join(new_data)) 124 | 125 | added_text = "This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component.\n" 126 | triple_quote = '"""\n' 127 | 128 | for f in pathlib.Path(PACKAGE_DIR).rglob("*.py"): 129 | with open(f, "r") as original_file: 130 | data = original_file.readlines() 131 | 132 | multiline_docstring = not data[0].endswith(triple_quote) 133 | line_after_docstring = 1 134 | old_docstring = "" 135 | if not multiline_docstring: 136 | old_docstring = data[0][3:][:-4] 137 | else: 138 | old_docstring = data[0][3:] 139 | while data[line_after_docstring] != triple_quote: 140 | old_docstring += data[line_after_docstring] 141 | line_after_docstring += 1 142 | line_after_docstring += 1 # Skip last triplequote 143 | 144 | new_docstring = 
f"{triple_quote}{old_docstring}\n\n{added_text}{triple_quote}" 145 | body = "".join(data[line_after_docstring:]) 146 | with open(f, "w") as new_file: 147 | new_file.write("".join([new_docstring, body])) 148 | 149 | added_text = "# This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component.\n" 150 | 151 | with open(REQUIREMENTS_FILE, "r") as original_file: 152 | data = original_file.readlines() 153 | 154 | def is_test_requirement(requirement): 155 | # if == not in d this is either a comment or unkown package, include 156 | if "==" not in requirement: 157 | return True 158 | 159 | regex = re.compile("types-.+") 160 | if re.match(regex, requirement): 161 | return False 162 | 163 | if d.split("==")[0] in requirements_remove: 164 | return False 165 | 166 | return True 167 | 168 | new_data = [] 169 | removed_data = [] 170 | for d in data: 171 | if is_test_requirement(d): 172 | new_data.append(d) 173 | else: 174 | removed_data.append(d) 175 | new_data.append(f"homeassistant=={ha_version}\n") 176 | new_data.insert(0, added_text) 177 | 178 | def find_dependency(dependency, data): 179 | for d in data: 180 | if dependency in d.lower(): 181 | return d 182 | raise ValueError(f"could not find {dependency}") 183 | 184 | with open(os.path.join(TMP_DIR, "requirements_all.txt"), "r") as f: 185 | data = f.readlines() 186 | 187 | def add_dependency(dependency, ha_data, new_data): 188 | dep = find_dependency(dependency, data) 189 | if not "\n" == dep[-2:]: 190 | dep = f"{dep}\n" 191 | new_data.append(dep) 192 | 193 | add_dependency("sqlalchemy", data, new_data) 194 | add_dependency("paho-mqtt", data, new_data) 195 | add_dependency("numpy", data, new_data) 196 | 197 | removed_data.insert(0, added_text) 198 | 199 | with open(REQUIREMENTS_FILE, "w") as new_file: 200 | new_file.writelines(new_data) 201 | 202 | with open(REQUIREMENTS_FILE_DEV, "w") as new_file: 203 | new_file.writelines(removed_data) 204 | 205 | from 
pytest_homeassistant_custom_component.const import __version__ 206 | 207 | with open("README.md", "r") as original_file: 208 | data = original_file.readlines() 209 | 210 | data[ 211 | 2 212 | ] = f"![HA core version](https://img.shields.io/static/v1?label=HA+core+version&message={__version__}&labelColor=blue)\n" 213 | 214 | with open("README.md", "w") as new_file: 215 | new_file.write("".join(data)) 216 | 217 | print(f"New Version: {__version__}") 218 | 219 | # modify load_fixture 220 | with open(os.path.join(PACKAGE_DIR, "common.py"), "r") as original_file: 221 | data = original_file.readlines() 222 | 223 | import_time_lineno = [i for i, line in enumerate(data) if "import time" in line] 224 | assert len(import_time_lineno) == 1 225 | data.insert(import_time_lineno[0] + 1, "import traceback\n") 226 | 227 | fixture_path_lineno = [ 228 | i for i, line in enumerate(data) if "def get_fixture_path" in line 229 | ] 230 | assert len(fixture_path_lineno) == 1 231 | data.insert( 232 | fixture_path_lineno[0] + 2, 233 | " start_path = (current_file := traceback.extract_stack()[idx:=-1].filename)\n", 234 | ) 235 | data.insert( 236 | fixture_path_lineno[0] + 3, 237 | " while start_path == current_file:\n", 238 | ) 239 | data.insert( 240 | fixture_path_lineno[0] + 4, 241 | " start_path = traceback.extract_stack()[idx:=idx-1].filename\n", 242 | ) 243 | data[fixture_path_lineno[0] + 9] = data[fixture_path_lineno[0] + 9].replace( 244 | "__file__", "start_path" 245 | ) 246 | data[fixture_path_lineno[0] + 11] = data[fixture_path_lineno[0] + 11].replace( 247 | "__file__", "start_path" 248 | ) 249 | 250 | with open(os.path.join(PACKAGE_DIR, "common.py"), "w") as new_file: 251 | new_file.writelines(data) 252 | 253 | # modify diagnostics file 254 | with open(os.path.join(PACKAGE_DIR, "components", "diagnostics", "__init__.py"), "r") as original_file: 255 | data = original_file.readlines() 256 | 257 | diagnostics_lineno = [ 258 | i for i, line in enumerate(data) if "from tests.typing" in 
line 259 | ] 260 | assert len(diagnostics_lineno) == 1 261 | data[diagnostics_lineno[0]] = data[diagnostics_lineno[0]].replace( 262 | "tests.typing","pytest_homeassistant_custom_component.typing" 263 | ) 264 | 265 | with open(os.path.join(PACKAGE_DIR, "components", "diagnostics", "__init__.py"), "w") as new_file: 266 | new_file.writelines(data) 267 | 268 | 269 | if ha_version != current_version or regen: 270 | process_files() 271 | with open(HA_VERSION_FILE, "w") as f: 272 | f.write(ha_version) 273 | else: 274 | print("Already up to date") 275 | 276 | if __name__=="__main__": 277 | cli() -------------------------------------------------------------------------------- /generate_phacc/ha.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import git 4 | 5 | from const import clone, TMP_DIR 6 | 7 | 8 | class HAVersion: 9 | def __init__(self, version): 10 | self._version = version 11 | split_version = version.split(".") 12 | try: 13 | self.major = int(split_version[0]) 14 | except ValueError: 15 | self.major = 0 16 | self.minor = 0 17 | self.patch = 0 18 | self.beta = None 19 | if self.major < 2021: 20 | return 21 | 22 | if len(split_version)>=2: 23 | self.minor = int(split_version[1]) 24 | if len(split_version)>=3: 25 | patch = split_version[2].split("b") 26 | self.patch = int(patch[0]) 27 | if len(patch) == 2: 28 | self.beta = int(patch[1]) 29 | 30 | 31 | def __eq__(self, other): 32 | if ( 33 | self.major==other.major 34 | and self.minor==self.minor 35 | and self.patch==self.patch 36 | and self.beta==self.beta 37 | ): 38 | return True 39 | return False 40 | 41 | 42 | def __gt__(self, other): 43 | if self.major > other.major: 44 | return True 45 | elif self.major < other.major: 46 | return False 47 | 48 | 49 | if self.minor > other.minor: 50 | return True 51 | elif self.minor < other.minor: 52 | return False 53 | 54 | 55 | if self.patch > other.patch: 56 | return True 57 | elif self.patch < other.patch: 58 | 
return False 59 | 60 | 61 | if self.beta is not None and other.beta is None: 62 | return False 63 | elif self.beta is None and other.beta is not None: 64 | return True 65 | elif self.beta is None and other.beta is None: 66 | return False 67 | elif self.beta > other.beta: 68 | return True 69 | return False 70 | 71 | def prepare_homeassistant(ref=None): 72 | if not os.path.isdir(TMP_DIR): 73 | os.system(clone) # Cloning 74 | 75 | if ref is None: 76 | repo = git.Repo(TMP_DIR) 77 | versions = {str(tag): HAVersion(str(tag)) for tag in repo.tags} 78 | latest_version = HAVersion("0.0.0") 79 | for key, version in versions.items(): 80 | if version > latest_version: 81 | latest_version = version 82 | ref = key 83 | 84 | repo.head.reference = repo.refs[ref] 85 | repo.head.reset(index=True, working_tree=True) 86 | return ref 87 | -------------------------------------------------------------------------------- /ha_version: -------------------------------------------------------------------------------- 1 | 2025.6.0b5 -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | # This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
2 | astroid==3.3.10 3 | mypy-dev==1.16.0a8 4 | pre-commit==4.0.0 5 | pylint==3.3.7 6 | types-aiofiles==24.1.0.20250326 7 | types-atomicwrites==1.4.5.1 8 | types-croniter==6.0.0.20250411 9 | types-caldav==1.3.0.20241107 10 | types-chardet==0.1.5 11 | types-decorator==5.2.0.20250324 12 | types-pexpect==4.9.0.20241208 13 | types-protobuf==5.29.1.20250403 14 | types-psutil==7.0.0.20250401 15 | types-pyserial==3.5.0.20250326 16 | types-python-dateutil==2.9.0.20241206 17 | types-python-slugify==8.0.2.20240310 18 | types-pytz==2025.2.0.20250326 19 | types-PyYAML==6.0.12.20250402 20 | types-requests==2.31.0.3 21 | types-xmltodict==0.13.0.3 22 | -------------------------------------------------------------------------------- /requirements_generate.txt: -------------------------------------------------------------------------------- 1 | click==8.1.3 2 | GitPython==3.1.14 3 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | # This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 2 | # linters such as pylint should be pinned, as new releases 3 | # make new things fail. 
Manually update these pins when pulling in a 4 | # new version 5 | 6 | # types-* that have versions roughly corresponding to the packages they 7 | # contain hints for available should be kept in sync with them 8 | 9 | -c homeassistant/package_constraints.txt 10 | -r requirements_test_pre_commit.txt 11 | coverage==7.6.12 12 | freezegun==1.5.1 13 | go2rtc-client==0.2.1 14 | license-expression==30.4.1 15 | mock-open==1.4.0 16 | pydantic==2.11.3 17 | pylint-per-file-ignores==1.4.0 18 | pipdeptree==2.26.1 19 | pytest-asyncio==0.26.0 20 | pytest-aiohttp==1.1.0 21 | pytest-cov==6.0.0 22 | pytest-freezer==0.4.9 23 | pytest-github-actions-annotate-failures==0.3.0 24 | pytest-socket==0.7.0 25 | pytest-sugar==1.0.0 26 | pytest-timeout==2.3.1 27 | pytest-unordered==0.6.1 28 | pytest-picked==0.5.1 29 | pytest-xdist==3.6.1 30 | pytest==8.3.5 31 | requests-mock==1.12.1 32 | respx==0.22.0 33 | syrupy==4.8.1 34 | tqdm==4.67.1 35 | homeassistant==2025.6.0b5 36 | SQLAlchemy==2.0.40 37 | 38 | paho-mqtt==2.1.0 39 | 40 | numpy==2.2.2 41 | 42 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = tests 3 | asyncio_mode = auto 4 | asyncio_default_fixture_loop_scope = function -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from setuptools import setup, find_packages 4 | 5 | requirements = [ 6 | "sqlalchemy", 7 | ] 8 | with open("requirements_test.txt","r") as f: 9 | for line in f: 10 | if "txt" not in line and "#" not in line: 11 | requirements.append(line) 12 | 13 | with open("version", "r") as f: 14 | __version__ = f.read() 15 | 16 | setup( 17 | author="Matthew Flamm", 18 | name="pytest-homeassistant-custom-component", 19 | version=__version__, 20 | 
packages=find_packages(where="src"), 21 | package_dir={"": "src"}, 22 | python_requires=">=3.13", 23 | install_requires=requirements, 24 | license="MIT license", 25 | url="https://github.com/MatthewFlamm/pytest-homeassistant-custom-component", 26 | author_email="matthewflamm0@gmail.com", 27 | description="Experimental package to automatically extract test plugins for Home Assistant custom components", 28 | long_description=open('README.md').read(), 29 | long_description_content_type='text/markdown', 30 | classifiers=[ 31 | "Development Status :: 3 - Alpha", 32 | "Framework :: Pytest", 33 | "Intended Audience :: Developers", 34 | "License :: OSI Approved :: MIT License", 35 | "Programming Language :: Python", 36 | "Programming Language :: Python :: 3.13", 37 | "Topic :: Software Development :: Testing", 38 | ], 39 | entry_points={"pytest11": ["homeassistant = pytest_homeassistant_custom_component.plugins"]}, 40 | ) 41 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for Home Assistant. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/asyncio_legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal legacy asyncio.coroutine. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | # flake8: noqa 8 | # stubbing out for integrations that have 9 | # not yet been updated for python 3.11 10 | # but can still run on python 3.10 11 | # 12 | # Remove this once rflink, fido, and blackbird 13 | # have had their libraries updated to remove 14 | # asyncio.coroutine 15 | from asyncio import base_futures, constants, format_helpers 16 | from asyncio.coroutines import _is_coroutine 17 | import collections.abc 18 | import functools 19 | import inspect 20 | import logging 21 | import traceback 22 | import types 23 | import warnings 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class CoroWrapper: 29 | # Wrapper for coroutine object in _DEBUG mode. 30 | 31 | def __init__(self, gen, func=None): 32 | assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen 33 | self.gen = gen 34 | self.func = func # Used to unwrap @coroutine decorator 35 | self._source_traceback = format_helpers.extract_stack(sys._getframe(1)) 36 | self.__name__ = getattr(gen, "__name__", None) 37 | self.__qualname__ = getattr(gen, "__qualname__", None) 38 | 39 | def __iter__(self): 40 | return self 41 | 42 | def __next__(self): 43 | return self.gen.send(None) 44 | 45 | def send(self, value): 46 | return self.gen.send(value) 47 | 48 | def throw(self, type, value=None, traceback=None): 49 | return self.gen.throw(type, value, traceback) 50 | 51 | def close(self): 52 | return self.gen.close() 53 | 54 | @property 55 | def gi_frame(self): 56 | return self.gen.gi_frame 57 | 58 | @property 59 | def gi_running(self): 60 | return self.gen.gi_running 61 | 62 | @property 63 | def gi_code(self): 64 | return self.gen.gi_code 65 | 66 | def __await__(self): 67 | return self 68 | 69 | @property 70 | def gi_yieldfrom(self): 71 | return self.gen.gi_yieldfrom 72 | 73 | def __del__(self): 74 | # Be careful accessing self.gen.frame -- self.gen might not exist. 
75 | gen = getattr(self, "gen", None) 76 | frame = getattr(gen, "gi_frame", None) 77 | if frame is not None and frame.f_lasti == -1: 78 | msg = f"{self!r} was never yielded from" 79 | tb = getattr(self, "_source_traceback", ()) 80 | if tb: 81 | tb = "".join(traceback.format_list(tb)) 82 | msg += ( 83 | f"\nCoroutine object created at " 84 | f"(most recent call last, truncated to " 85 | f"{constants.DEBUG_STACK_DEPTH} last lines):\n" 86 | ) 87 | msg += tb.rstrip() 88 | logger.error(msg) 89 | 90 | 91 | def legacy_coroutine(func): 92 | """Decorator to mark coroutines. 93 | If the coroutine is not yielded from before it is destroyed, 94 | an error message is logged. 95 | """ 96 | warnings.warn( 97 | '"@coroutine" decorator is deprecated since Python 3.8, use "async def" instead', 98 | DeprecationWarning, 99 | stacklevel=2, 100 | ) 101 | if inspect.iscoroutinefunction(func): 102 | # In Python 3.5 that's all we need to do for coroutines 103 | # defined with "async def". 104 | return func 105 | 106 | if inspect.isgeneratorfunction(func): 107 | coro = func 108 | else: 109 | 110 | @functools.wraps(func) 111 | def coro(*args, **kw): 112 | res = func(*args, **kw) 113 | if ( 114 | base_futures.isfuture(res) 115 | or inspect.isgenerator(res) 116 | or isinstance(res, CoroWrapper) 117 | ): 118 | res = yield from res 119 | else: 120 | # If 'res' is an awaitable, run it. 121 | try: 122 | await_meth = res.__await__ 123 | except AttributeError: 124 | pass 125 | else: 126 | if isinstance(res, collections.abc.Awaitable): 127 | res = yield from await_meth() 128 | return res 129 | 130 | wrapper = types.coroutine(coro) 131 | wrapper._is_coroutine = _is_coroutine # For iscoroutinefunction(). 132 | return wrapper 133 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the helper method for writing . 
3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import asyncio 10 | from collections.abc import ( 11 | AsyncGenerator, 12 | Callable, 13 | Coroutine, 14 | Generator, 15 | Iterable, 16 | Iterator, 17 | Mapping, 18 | Sequence, 19 | ) 20 | from contextlib import asynccontextmanager, contextmanager, suppress 21 | from datetime import UTC, datetime, timedelta 22 | from enum import Enum, StrEnum 23 | import functools as ft 24 | from functools import lru_cache 25 | from io import StringIO 26 | import json 27 | import logging 28 | import os 29 | import pathlib 30 | import time 31 | import traceback 32 | from types import FrameType, ModuleType 33 | from typing import Any, Literal, NoReturn 34 | from unittest.mock import AsyncMock, Mock, patch 35 | 36 | from aiohttp.test_utils import unused_port as get_test_instance_port 37 | from annotatedyaml import load_yaml_dict, loader as yaml_loader 38 | import attr 39 | import pytest 40 | from syrupy.assertion import SnapshotAssertion 41 | import voluptuous as vol 42 | 43 | from homeassistant import auth, bootstrap, config_entries, loader 44 | from homeassistant.auth import ( 45 | auth_store, 46 | models as auth_models, 47 | permissions as auth_permissions, 48 | providers as auth_providers, 49 | ) 50 | from homeassistant.auth.permissions import system_policies 51 | from homeassistant.components import device_automation, persistent_notification as pn 52 | from homeassistant.components.device_automation import ( 53 | _async_get_device_automation_capabilities as async_get_device_automation_capabilities, 54 | ) 55 | from homeassistant.components.logger import ( 56 | DOMAIN as LOGGER_DOMAIN, 57 | SERVICE_SET_LEVEL, 58 | _clear_logger_overwrites, 59 | ) 60 | from homeassistant.config import IntegrationConfigInfo, async_process_component_config 61 | from homeassistant.config_entries import ConfigEntry, ConfigFlow, 
ConfigFlowResult 62 | from homeassistant.const import ( 63 | DEVICE_DEFAULT_NAME, 64 | EVENT_HOMEASSISTANT_CLOSE, 65 | EVENT_HOMEASSISTANT_STOP, 66 | EVENT_STATE_CHANGED, 67 | STATE_OFF, 68 | STATE_ON, 69 | ) 70 | from homeassistant.core import ( 71 | CoreState, 72 | Event, 73 | HomeAssistant, 74 | ServiceCall, 75 | ServiceResponse, 76 | State, 77 | SupportsResponse, 78 | callback, 79 | ) 80 | from homeassistant.helpers import ( 81 | area_registry as ar, 82 | category_registry as cr, 83 | device_registry as dr, 84 | entity, 85 | entity_platform, 86 | entity_registry as er, 87 | event, 88 | floor_registry as fr, 89 | intent, 90 | issue_registry as ir, 91 | label_registry as lr, 92 | restore_state as rs, 93 | storage, 94 | translation, 95 | ) 96 | from homeassistant.helpers.dispatcher import ( 97 | async_dispatcher_connect, 98 | async_dispatcher_send, 99 | ) 100 | from homeassistant.helpers.entity import Entity 101 | from homeassistant.helpers.entity_platform import ( 102 | AddConfigEntryEntitiesCallback, 103 | AddEntitiesCallback, 104 | ) 105 | from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, json_dumps 106 | from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType 107 | from homeassistant.util import dt as dt_util, ulid as ulid_util, uuid as uuid_util 108 | from homeassistant.util.async_ import ( 109 | _SHUTDOWN_RUN_CALLBACK_THREADSAFE, 110 | get_scheduled_timer_handles, 111 | run_callback_threadsafe, 112 | ) 113 | from homeassistant.util.event_type import EventType 114 | from homeassistant.util.json import ( 115 | JsonArrayType, 116 | JsonObjectType, 117 | JsonValueType, 118 | json_loads, 119 | json_loads_array, 120 | json_loads_object, 121 | ) 122 | from homeassistant.util.signal_type import SignalType 123 | from homeassistant.util.unit_system import METRIC_SYSTEM 124 | 125 | from .testing_config.custom_components.test_constant_deprecation import ( 126 | import_deprecated_constant, 127 | ) 128 | 129 | __all__ = [ 130 | 
"async_get_device_automation_capabilities", 131 | "get_test_instance_port", 132 | ] 133 | 134 | _LOGGER = logging.getLogger(__name__) 135 | INSTANCES = [] 136 | CLIENT_ID = "https://example.com/app" 137 | CLIENT_REDIRECT_URI = "https://example.com/app/callback" 138 | 139 | 140 | class QualityScaleStatus(StrEnum): 141 | """Source of core configuration.""" 142 | 143 | DONE = "done" 144 | EXEMPT = "exempt" 145 | TODO = "todo" 146 | 147 | 148 | async def async_get_device_automations( 149 | hass: HomeAssistant, 150 | automation_type: device_automation.DeviceAutomationType, 151 | device_id: str, 152 | ) -> Any: 153 | """Get a device automation for a single device id.""" 154 | automations = await device_automation.async_get_device_automations( 155 | hass, automation_type, [device_id] 156 | ) 157 | return automations.get(device_id) 158 | 159 | 160 | def threadsafe_callback_factory(func): 161 | """Create threadsafe functions out of callbacks. 162 | 163 | Callback needs to have `hass` as first argument. 164 | """ 165 | 166 | @ft.wraps(func) 167 | def threadsafe(*args, **kwargs): 168 | """Call func threadsafe.""" 169 | hass = args[0] 170 | return run_callback_threadsafe( 171 | hass.loop, ft.partial(func, *args, **kwargs) 172 | ).result() 173 | 174 | return threadsafe 175 | 176 | 177 | def threadsafe_coroutine_factory(func): 178 | """Create threadsafe functions out of coroutine. 179 | 180 | Callback needs to have `hass` as first argument. 
181 | """ 182 | 183 | @ft.wraps(func) 184 | def threadsafe(*args, **kwargs): 185 | """Call func threadsafe.""" 186 | hass = args[0] 187 | return asyncio.run_coroutine_threadsafe( 188 | func(*args, **kwargs), hass.loop 189 | ).result() 190 | 191 | return threadsafe 192 | 193 | 194 | def get_test_config_dir(*add_path): 195 | """Return a path to a test config dir.""" 196 | return os.path.join(os.path.dirname(__file__), "testing_config", *add_path) 197 | 198 | 199 | class StoreWithoutWriteLoad[_T: (Mapping[str, Any] | Sequence[Any])](storage.Store[_T]): 200 | """Fake store that does not write or load. Used for testing.""" 201 | 202 | async def async_save(self, *args: Any, **kwargs: Any) -> None: 203 | """Save the data. 204 | 205 | This function is mocked out in . 206 | """ 207 | 208 | @callback 209 | def async_save_delay(self, *args: Any, **kwargs: Any) -> None: 210 | """Save data with an optional delay. 211 | 212 | This function is mocked out in . 213 | """ 214 | 215 | 216 | @asynccontextmanager 217 | async def async_test_home_assistant( 218 | event_loop: asyncio.AbstractEventLoop | None = None, 219 | load_registries: bool = True, 220 | config_dir: str | None = None, 221 | initial_state: CoreState = CoreState.running, 222 | ) -> AsyncGenerator[HomeAssistant]: 223 | """Return a Home Assistant object pointing at test config dir.""" 224 | hass = HomeAssistant(config_dir or get_test_config_dir()) 225 | store = auth_store.AuthStore(hass) 226 | hass.auth = auth.AuthManager(hass, store, {}, {}) 227 | ensure_auth_manager_loaded(hass.auth) 228 | INSTANCES.append(hass) 229 | 230 | orig_async_add_job = hass.async_add_job 231 | orig_async_add_executor_job = hass.async_add_executor_job 232 | orig_async_create_task_internal = hass.async_create_task_internal 233 | orig_tz = dt_util.get_default_time_zone() 234 | 235 | def async_add_job(target, *args, eager_start: bool = False): 236 | """Add job.""" 237 | check_target = target 238 | while isinstance(check_target, ft.partial): 239 | 
            check_target = check_target.func

        # Plain Mocks (but not AsyncMocks) are called synchronously and their
        # result is wrapped in an already-resolved Future instead of scheduling.
        if isinstance(check_target, Mock) and not isinstance(target, AsyncMock):
            fut = asyncio.Future()
            fut.set_result(target(*args))
            return fut

        return orig_async_add_job(target, *args, eager_start=eager_start)

    def async_add_executor_job(target, *args):
        """Add executor job."""
        # Unwrap functools.partial layers to find the real target.
        check_target = target
        while isinstance(check_target, ft.partial):
            check_target = check_target.func

        # Mocks resolve immediately instead of going through the executor.
        if isinstance(check_target, Mock):
            fut = asyncio.Future()
            fut.set_result(target(*args))
            return fut

        return orig_async_add_executor_job(target, *args)

    def async_create_task_internal(coroutine, name=None, eager_start=True):
        """Create task."""
        # Mock "coroutines" become an already-resolved Future with result None.
        if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
            fut = asyncio.Future()
            fut.set_result(None)
            return fut

        return orig_async_create_task_internal(coroutine, name, eager_start)

    # Install the Mock-aware wrappers on this hass instance.
    hass.async_add_job = async_add_job
    hass.async_add_executor_job = async_add_executor_job
    hass.async_create_task_internal = async_create_task_internal

    hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

    # Deterministic test configuration (fixed location, metric units, no pip).
    hass.config.location_name = "test home"
    hass.config.latitude = 32.87336
    hass.config.longitude = -117.22743
    hass.config.elevation = 0
    await hass.config.async_set_time_zone("US/Pacific")
    hass.config.units = METRIC_SYSTEM
    hass.config.media_dirs = {"local": get_test_config_dir("media")}
    hass.config.skip_pip = True
    hass.config.skip_pip_packages = []

    hass.config_entries = config_entries.ConfigEntries(
        hass,
        {
            "_": (
                "Not empty or else some bad checks for hass config in discovery.py"
                " breaks"
            )
        },
    )
    hass.bus.async_listen_once(
        EVENT_HOMEASSISTANT_STOP,
        hass.config_entries._async_shutdown,
    )

    # Load the registries
    entity.async_setup(hass)
    loader.async_setup(hass)

    # setup translation cache instead of calling translation.async_setup(hass)
    hass.data[translation.TRANSLATION_FLATTEN_CACHE] = translation._TranslationCache(
        hass
    )
    if load_registries:
        # Patch every registry store to StoreWithoutWriteLoad so nothing
        # touches disk while the registries are loaded.
        with (
            patch.object(StoreWithoutWriteLoad, "async_load", return_value=None),
            patch(
                "homeassistant.helpers.area_registry.AreaRegistryStore",
                StoreWithoutWriteLoad,
            ),
            patch(
                "homeassistant.helpers.device_registry.DeviceRegistryStore",
                StoreWithoutWriteLoad,
            ),
            patch(
                "homeassistant.helpers.entity_registry.EntityRegistryStore",
                StoreWithoutWriteLoad,
            ),
            patch(
                "homeassistant.helpers.storage.Store",  # Floor & label registry are different
                StoreWithoutWriteLoad,
            ),
            patch(
                "homeassistant.helpers.issue_registry.IssueRegistryStore",
                StoreWithoutWriteLoad,
            ),
            patch(
                "homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump",
                return_value=None,
            ),
            patch(
                "homeassistant.helpers.restore_state.start.async_at_start",
            ),
        ):
            await ar.async_load(hass)
            await cr.async_load(hass)
            await dr.async_load(hass)
            await er.async_load(hass)
            await fr.async_load(hass)
            await ir.async_load(hass)
            await lr.async_load(hass)
            await rs.async_load(hass)
        hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None

    hass.set_state(initial_state)

    @callback
    def clear_instance(event):
        """Clear global instance."""
        # Give aiohttp one loop iteration to close
        hass.loop.call_soon(INSTANCES.remove, hass)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance)

    try:
        yield hass
    finally:
        # Restore timezone, it is set when creating the hass object
        dt_util.set_default_time_zone(orig_tz)
        # Remove loop shutdown indicator to not interfere with additional hass objects
        with suppress(AttributeError):
            delattr(hass.loop, _SHUTDOWN_RUN_CALLBACK_THREADSAFE)


def async_mock_service(
    hass: HomeAssistant,
    domain: str,
    service: str,
    schema: vol.Schema | None = None,
    response: ServiceResponse = None,
    supports_response: SupportsResponse | None = None,
    raise_exception: Exception | None = None,
) -> list[ServiceCall]:
    """Set up a fake service & return a calls log list to this service."""
    calls = []

    @callback
    def mock_service_log(call):
        """Mock service call."""
        calls.append(call)
        if raise_exception is not None:
            raise raise_exception
        return response

    # Infer response support from whether a response value was supplied.
    if supports_response is None:
        if response is not None:
            supports_response = SupportsResponse.OPTIONAL
        else:
            supports_response = SupportsResponse.NONE

    hass.services.async_register(
        domain,
        service,
        mock_service_log,
        schema=schema,
        supports_response=supports_response,
    )

    return calls


# Thread-safe variant of async_mock_service for use outside the event loop.
mock_service = threadsafe_callback_factory(async_mock_service)


@callback
def async_mock_intent(hass: HomeAssistant, intent_typ: str) -> list[intent.Intent]:
    """Set up a fake intent handler."""
    intents: list[intent.Intent] = []

    class MockIntentHandler(intent.IntentHandler):
        # Handle exactly the requested intent type.
        intent_type = intent_typ

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            """Handle the intent."""
            intents.append(intent_obj)
            return intent_obj.create_response()

    intent.async_register(hass, MockIntentHandler())

    return intents


class MockMqttReasonCode:
    """Class to fake a MQTT ReasonCode."""

    value: int
    is_failure: bool
    def __init__(
        self, value: int = 0, is_failure: bool = False, name: str = "Success"
    ) -> None:
        """Initialize the mock reason code."""
        self.value = value
        self.is_failure = is_failure
        self._name = name

    def getName(self) -> str:
        """Return the name of the reason code."""
        return self._name


@callback
def async_fire_mqtt_message(
    hass: HomeAssistant,
    topic: str,
    payload: bytes | str,
    qos: int = 0,
    retain: bool = False,
) -> None:
    """Fire the MQTT message."""
    # Local import to avoid processing MQTT modules when running a testcase
    # which does not use MQTT.

    # pylint: disable-next=import-outside-toplevel
    from paho.mqtt.client import MQTTMessage

    # pylint: disable-next=import-outside-toplevel
    from homeassistant.components.mqtt import MqttData

    if isinstance(payload, str):
        payload = payload.encode("utf-8")

    msg = MQTTMessage(topic=topic.encode("utf-8"))
    msg.payload = payload
    msg.qos = qos
    msg.retain = retain
    msg.timestamp = time.monotonic()

    # Feed the message straight into the MQTT client's on-message callback.
    mqtt_data: MqttData = hass.data["mqtt"]
    assert mqtt_data.client
    mqtt_data.client._async_mqtt_on_message(Mock(), None, msg)


# Thread-safe variant of async_fire_mqtt_message.
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)


@callback
def async_fire_time_changed_exact(
    hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
) -> None:
    """Fire a time changed event at an exact microsecond.

    Consider that it is not possible to actually achieve an exact
    microsecond in production as the event loop is not precise enough.
    If your code relies on this level of precision, consider a different
    approach, as this is only for testing.
    """
    if datetime_ is None:
        utc_datetime = datetime.now(UTC)
    else:
        utc_datetime = dt_util.as_utc(datetime_)

    _async_fire_time_changed(hass, utc_datetime, fire_all)


@callback
def async_fire_time_changed(
    hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
) -> None:
    """Fire a time changed event.

    If called within the first 500 ms of a second, time will be bumped to exactly
    500 ms to match the async_track_utc_time_change event listeners and
    DataUpdateCoordinator which spreads all updates between 0.05..0.50.
    Background in PR https://github.com/home-assistant/core/pull/82233

    As asyncio is cooperative, we can't guarantee that the event loop will
    run an event at the exact time we want. If you need to fire time changed
    for an exact microsecond, use async_fire_time_changed_exact.
    """
    if datetime_ is None:
        utc_datetime = datetime.now(UTC)
    else:
        utc_datetime = dt_util.as_utc(datetime_)

    # Increase the mocked time by 0.5 s to account for up to 0.5 s delay
    # added to events scheduled by update_coordinator and async_track_time_interval
    utc_datetime += timedelta(microseconds=event.RANDOM_MICROSECOND_MAX)

    _async_fire_time_changed(hass, utc_datetime, fire_all)


_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution


@callback
def _async_fire_time_changed(
    hass: HomeAssistant, utc_datetime: datetime | None, fire_all: bool
) -> None:
    """Run and cancel scheduled timer handles that are due at utc_datetime.

    Patches the event helpers' time trackers so fired listeners observe the
    mocked time instead of the real clock.
    """
    timestamp = utc_datetime.timestamp()
    for task in list(get_scheduled_timer_handles(hass.loop)):
        if not isinstance(task, asyncio.TimerHandle):
            continue
        if task.cancelled():
            continue

        mock_seconds_into_future = timestamp - time.time()
        future_seconds = task.when() - (hass.loop.time() + _MONOTONIC_RESOLUTION)

        if fire_all or mock_seconds_into_future >= future_seconds:
            with (
                patch(
                    "homeassistant.helpers.event.time_tracker_utcnow",
                    return_value=utc_datetime,
                ),
                patch(
                    "homeassistant.helpers.event.time_tracker_timestamp",
                    return_value=timestamp,
                ),
            ):
                task._run()
            # Prevent the loop from running the handle a second time.
            task.cancel()


# Thread-safe variant of async_fire_time_changed.
fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)


def get_fixture_path(filename: str, integration: str | None = None) -> pathlib.Path:
    """Get path of fixture."""
    # Walk up the call stack to the first frame in a *different* file than
    # this one: that is the calling test module, and fixtures are resolved
    # relative to it.
    start_path = (current_file := traceback.extract_stack()[idx:=-1].filename)
    while start_path == current_file:
        start_path = traceback.extract_stack()[idx:=idx-1].filename
    if integration is None and "/" in filename and not filename.startswith("helpers/"):
        integration, filename = filename.split("/", 1)

    if integration is None:
        return pathlib.Path(start_path).parent.joinpath("fixtures", filename)

    return pathlib.Path(start_path).parent.joinpath(
        "components", integration, "fixtures", filename
    )


# NOTE(review): results are cached by (filename, integration) even though
# get_fixture_path resolves relative to the *calling* test file — confirm this
# is intended when the same fixture name exists for two different test modules.
@lru_cache
def load_fixture(filename: str, integration: str | None = None) -> str:
    """Load a fixture."""
    return get_fixture_path(filename, integration).read_text(encoding="utf8")


async def async_load_fixture(
    hass: HomeAssistant, filename: str, integration: str | None = None
) -> str:
    """Load a fixture."""
    return await hass.async_add_executor_job(load_fixture, filename, integration)


def load_json_value_fixture(
    filename: str, integration: str | None = None
) -> JsonValueType:
    """Load a JSON value from a fixture."""
    return json_loads(load_fixture(filename, integration))


def load_json_array_fixture(
    filename: str, integration: str | None = None
) -> JsonArrayType:
    """Load a JSON array from a fixture."""
fixture.""" 604 | return json_loads_array(load_fixture(filename, integration)) 605 | 606 | 607 | async def async_load_json_array_fixture( 608 | hass: HomeAssistant, filename: str, integration: str | None = None 609 | ) -> JsonArrayType: 610 | """Load a JSON object from a fixture.""" 611 | return json_loads_array(await async_load_fixture(hass, filename, integration)) 612 | 613 | 614 | def load_json_object_fixture( 615 | filename: str, integration: str | None = None 616 | ) -> JsonObjectType: 617 | """Load a JSON object from a fixture.""" 618 | return json_loads_object(load_fixture(filename, integration)) 619 | 620 | 621 | async def async_load_json_object_fixture( 622 | hass: HomeAssistant, filename: str, integration: str | None = None 623 | ) -> JsonObjectType: 624 | """Load a JSON object from a fixture.""" 625 | return json_loads_object(await async_load_fixture(hass, filename, integration)) 626 | 627 | 628 | def json_round_trip(obj: Any) -> Any: 629 | """Round trip an object to JSON.""" 630 | return json_loads(json_dumps(obj)) 631 | 632 | 633 | def mock_state_change_event( 634 | hass: HomeAssistant, new_state: State, old_state: State | None = None 635 | ) -> None: 636 | """Mock state change event.""" 637 | event_data = { 638 | "entity_id": new_state.entity_id, 639 | "new_state": new_state, 640 | "old_state": old_state, 641 | } 642 | hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context) 643 | 644 | 645 | @callback 646 | def mock_component(hass: HomeAssistant, component: str) -> None: 647 | """Mock a component is setup.""" 648 | if component in hass.config.components: 649 | raise AssertionError(f"Integration {component} is already setup") 650 | 651 | hass.config.components.add(component) 652 | 653 | 654 | def mock_registry( 655 | hass: HomeAssistant, 656 | mock_entries: dict[str, er.RegistryEntry] | None = None, 657 | ) -> er.EntityRegistry: 658 | """Mock the Entity Registry. 
659 | 660 | This should only be used if you need to mock/re-stage a clean mocked 661 | entity registry in your current hass object. It can be useful to, 662 | for example, pre-load the registry with items. 663 | 664 | This mock will thus replace the existing registry in the running hass. 665 | 666 | If you just need to access the existing registry, use the `entity_registry` 667 | fixture instead. 668 | """ 669 | registry = er.EntityRegistry(hass) 670 | if mock_entries is None: 671 | mock_entries = {} 672 | registry.deleted_entities = {} 673 | registry.entities = er.EntityRegistryItems() 674 | registry._entities_data = registry.entities.data 675 | for key, entry in mock_entries.items(): 676 | registry.entities[key] = entry 677 | 678 | hass.data[er.DATA_REGISTRY] = registry 679 | er.async_get.cache_clear() 680 | return registry 681 | 682 | 683 | @attr.s(frozen=True, kw_only=True, slots=True) 684 | class RegistryEntryWithDefaults(er.RegistryEntry): 685 | """Helper to create a registry entry with defaults.""" 686 | 687 | capabilities: Mapping[str, Any] | None = attr.ib(default=None) 688 | config_entry_id: str | None = attr.ib(default=None) 689 | config_subentry_id: str | None = attr.ib(default=None) 690 | created_at: datetime = attr.ib(factory=dt_util.utcnow) 691 | device_id: str | None = attr.ib(default=None) 692 | disabled_by: er.RegistryEntryDisabler | None = attr.ib(default=None) 693 | entity_category: er.EntityCategory | None = attr.ib(default=None) 694 | hidden_by: er.RegistryEntryHider | None = attr.ib(default=None) 695 | id: str = attr.ib( 696 | default=None, 697 | converter=attr.converters.default_if_none(factory=uuid_util.random_uuid_hex), # type: ignore[misc] 698 | ) 699 | has_entity_name: bool = attr.ib(default=False) 700 | options: er.ReadOnlyEntityOptionsType = attr.ib( 701 | default=None, converter=er._protect_entity_options 702 | ) 703 | original_device_class: str | None = attr.ib(default=None) 704 | original_icon: str | None = attr.ib(default=None) 705 
| original_name: str | None = attr.ib(default=None) 706 | suggested_object_id: str | None = attr.ib(default=None) 707 | supported_features: int = attr.ib(default=0) 708 | translation_key: str | None = attr.ib(default=None) 709 | unit_of_measurement: str | None = attr.ib(default=None) 710 | 711 | 712 | def mock_area_registry( 713 | hass: HomeAssistant, mock_entries: dict[str, ar.AreaEntry] | None = None 714 | ) -> ar.AreaRegistry: 715 | """Mock the Area Registry. 716 | 717 | This should only be used if you need to mock/re-stage a clean mocked 718 | area registry in your current hass object. It can be useful to, 719 | for example, pre-load the registry with items. 720 | 721 | This mock will thus replace the existing registry in the running hass. 722 | 723 | If you just need to access the existing registry, use the `area_registry` 724 | fixture instead. 725 | """ 726 | registry = ar.AreaRegistry(hass) 727 | registry.areas = ar.AreaRegistryItems() 728 | for key, entry in mock_entries.items(): 729 | registry.areas[key] = entry 730 | 731 | hass.data[ar.DATA_REGISTRY] = registry 732 | ar.async_get.cache_clear() 733 | return registry 734 | 735 | 736 | def mock_device_registry( 737 | hass: HomeAssistant, 738 | mock_entries: dict[str, dr.DeviceEntry] | None = None, 739 | ) -> dr.DeviceRegistry: 740 | """Mock the Device Registry. 741 | 742 | This should only be used if you need to mock/re-stage a clean mocked 743 | device registry in your current hass object. It can be useful to, 744 | for example, pre-load the registry with items. 745 | 746 | This mock will thus replace the existing registry in the running hass. 747 | 748 | If you just need to access the existing registry, use the `device_registry` 749 | fixture instead. 
750 | """ 751 | registry = dr.DeviceRegistry(hass) 752 | registry.devices = dr.ActiveDeviceRegistryItems() 753 | registry._device_data = registry.devices.data 754 | if mock_entries is None: 755 | mock_entries = {} 756 | for key, entry in mock_entries.items(): 757 | registry.devices[key] = entry 758 | registry.deleted_devices = dr.DeviceRegistryItems() 759 | 760 | hass.data[dr.DATA_REGISTRY] = registry 761 | dr.async_get.cache_clear() 762 | return registry 763 | 764 | 765 | class MockGroup(auth_models.Group): 766 | """Mock a group in Home Assistant.""" 767 | 768 | def __init__(self, id: str | None = None, name: str | None = "Mock Group") -> None: 769 | """Mock a group.""" 770 | kwargs = {"name": name, "policy": system_policies.ADMIN_POLICY} 771 | if id is not None: 772 | kwargs["id"] = id 773 | 774 | super().__init__(**kwargs) 775 | 776 | def add_to_hass(self, hass: HomeAssistant) -> MockGroup: 777 | """Test helper to add entry to hass.""" 778 | return self.add_to_auth_manager(hass.auth) 779 | 780 | def add_to_auth_manager(self, auth_mgr: auth.AuthManager) -> MockGroup: 781 | """Test helper to add entry to hass.""" 782 | ensure_auth_manager_loaded(auth_mgr) 783 | auth_mgr._store._groups[self.id] = self 784 | return self 785 | 786 | 787 | class MockUser(auth_models.User): 788 | """Mock a user in Home Assistant.""" 789 | 790 | def __init__( 791 | self, 792 | id: str | None = None, 793 | is_owner: bool = False, 794 | is_active: bool = True, 795 | name: str | None = "Mock User", 796 | system_generated: bool = False, 797 | groups: list[auth_models.Group] | None = None, 798 | ) -> None: 799 | """Initialize mock user.""" 800 | kwargs = { 801 | "is_owner": is_owner, 802 | "is_active": is_active, 803 | "name": name, 804 | "system_generated": system_generated, 805 | "groups": groups or [], 806 | "perm_lookup": None, 807 | } 808 | if id is not None: 809 | kwargs["id"] = id 810 | super().__init__(**kwargs) 811 | 812 | def add_to_hass(self, hass: HomeAssistant) -> MockUser: 813 
| """Test helper to add entry to hass.""" 814 | return self.add_to_auth_manager(hass.auth) 815 | 816 | def add_to_auth_manager(self, auth_mgr: auth.AuthManager) -> MockUser: 817 | """Test helper to add entry to hass.""" 818 | ensure_auth_manager_loaded(auth_mgr) 819 | auth_mgr._store._users[self.id] = self 820 | return self 821 | 822 | def mock_policy(self, policy: auth_permissions.PolicyType) -> None: 823 | """Mock a policy for a user.""" 824 | self.permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup) 825 | 826 | 827 | async def register_auth_provider( 828 | hass: HomeAssistant, config: ConfigType 829 | ) -> auth_providers.AuthProvider: 830 | """Register an auth provider.""" 831 | provider = await auth_providers.auth_provider_from_config( 832 | hass, hass.auth._store, config 833 | ) 834 | assert provider is not None, "Invalid config specified" 835 | key = (provider.type, provider.id) 836 | providers = hass.auth._providers 837 | 838 | if key in providers: 839 | raise ValueError("Provider already registered") 840 | 841 | providers[key] = provider 842 | return provider 843 | 844 | 845 | @callback 846 | def ensure_auth_manager_loaded(auth_mgr: auth.AuthManager) -> None: 847 | """Ensure an auth manager is considered loaded.""" 848 | store = auth_mgr._store 849 | if store._users is None: 850 | store._set_defaults() 851 | 852 | 853 | class MockModule: 854 | """Representation of a fake module.""" 855 | 856 | def __init__( 857 | self, 858 | domain: str | None = None, 859 | *, 860 | dependencies: list[str] | None = None, 861 | setup: Callable[[HomeAssistant, ConfigType], bool] | None = None, 862 | requirements: list[str] | None = None, 863 | config_schema: vol.Schema | None = None, 864 | platform_schema: vol.Schema | None = None, 865 | platform_schema_base: vol.Schema | None = None, 866 | async_setup: Callable[[HomeAssistant, ConfigType], Coroutine[Any, Any, bool]] 867 | | None = None, 868 | async_setup_entry: Callable[ 869 | [HomeAssistant, 
                ConfigEntry], Coroutine[Any, Any, bool]
        ]
        | None = None,
        async_unload_entry: Callable[
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, bool]
        ]
        | None = None,
        async_migrate_entry: Callable[
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, bool]
        ]
        | None = None,
        async_remove_entry: Callable[
            [HomeAssistant, ConfigEntry], Coroutine[Any, Any, None]
        ]
        | None = None,
        partial_manifest: dict[str, Any] | None = None,
        async_remove_config_entry_device: Callable[
            [HomeAssistant, ConfigEntry, dr.DeviceEntry], Coroutine[Any, Any, bool]
        ]
        | None = None,
    ) -> None:
        """Initialize the mock module."""
        self.__name__ = f"homeassistant.components.{domain}"
        self.__file__ = f"homeassistant/components/{domain}"
        self.DOMAIN = domain
        self.DEPENDENCIES = dependencies or []
        self.REQUIREMENTS = requirements or []
        # Overlay to be used when generating manifest from this module
        self._partial_manifest = partial_manifest

        # Only set schema attributes that were supplied, so loader checks
        # based on hasattr() behave like they would for a real module.
        if config_schema is not None:
            self.CONFIG_SCHEMA = config_schema

        if platform_schema is not None:
            self.PLATFORM_SCHEMA = platform_schema

        if platform_schema_base is not None:
            self.PLATFORM_SCHEMA_BASE = platform_schema_base

        if setup:
            # We run this in executor, wrap it in function
            # pylint: disable-next=unnecessary-lambda
            self.setup = lambda *args: setup(*args)

        if async_setup is not None:
            self.async_setup = async_setup

        # Default to a successful async setup when neither variant is given.
        if setup is None and async_setup is None:
            self.async_setup = AsyncMock(return_value=True)

        if async_setup_entry is not None:
            self.async_setup_entry = async_setup_entry

        if async_unload_entry is not None:
            self.async_unload_entry = async_unload_entry

        if async_migrate_entry is not None:
            self.async_migrate_entry = async_migrate_entry

        if async_remove_entry is not None:
            self.async_remove_entry = async_remove_entry

        if async_remove_config_entry_device is not None:
            self.async_remove_config_entry_device = async_remove_config_entry_device

    def mock_manifest(self):
        """Generate a mock manifest to represent this module."""
        return {
            **loader.manifest_from_legacy_module(self.DOMAIN, self),
            **(self._partial_manifest or {}),
        }


class MockPlatform:
    """Provide a fake platform."""

    __name__ = "homeassistant.components.light.bla"
    __file__ = "homeassistant/components/blah/light"

    def __init__(
        self,
        *,
        setup_platform: Callable[
            [HomeAssistant, ConfigType, AddEntitiesCallback, DiscoveryInfoType | None],
            None,
        ]
        | None = None,
        dependencies: list[str] | None = None,
        platform_schema: vol.Schema | None = None,
        async_setup_platform: Callable[
            [HomeAssistant, ConfigType, AddEntitiesCallback, DiscoveryInfoType | None],
            Coroutine[Any, Any, None],
        ]
        | None = None,
        async_setup_entry: Callable[
            [HomeAssistant, ConfigEntry, AddEntitiesCallback], Coroutine[Any, Any, None]
        ]
        | None = None,
        scan_interval: timedelta | None = None,
    ) -> None:
        """Initialize the platform."""
        self.DEPENDENCIES = dependencies or []

        if platform_schema is not None:
            self.PLATFORM_SCHEMA = platform_schema

        if scan_interval is not None:
            self.SCAN_INTERVAL = scan_interval

        if setup_platform is not None:
            # We run this in executor, wrap it in function
            # pylint: disable-next=unnecessary-lambda
            self.setup_platform = lambda *args: setup_platform(*args)

        if async_setup_platform is not None:
            self.async_setup_platform = async_setup_platform

        if async_setup_entry is not None:
            self.async_setup_entry = async_setup_entry
988 | 989 | if setup_platform is None and async_setup_platform is None: 990 | self.async_setup_platform = AsyncMock(return_value=None) 991 | 992 | 993 | class MockEntityPlatform(entity_platform.EntityPlatform): 994 | """Mock class with some mock defaults.""" 995 | 996 | def __init__( 997 | self, 998 | hass: HomeAssistant, 999 | logger=None, 1000 | domain="test_domain", 1001 | platform_name="test_platform", 1002 | platform=None, 1003 | scan_interval=timedelta(seconds=15), 1004 | entity_namespace=None, 1005 | ) -> None: 1006 | """Initialize a mock entity platform.""" 1007 | if logger is None: 1008 | logger = logging.getLogger("homeassistant.helpers.entity_platform") 1009 | 1010 | # Otherwise the constructor will blow up. 1011 | if isinstance(platform, Mock) and isinstance(platform.PARALLEL_UPDATES, Mock): 1012 | platform.PARALLEL_UPDATES = 0 1013 | 1014 | super().__init__( 1015 | hass=hass, 1016 | logger=logger, 1017 | domain=domain, 1018 | platform_name=platform_name, 1019 | platform=platform, 1020 | scan_interval=scan_interval, 1021 | entity_namespace=entity_namespace, 1022 | ) 1023 | 1024 | @callback 1025 | def _async_on_stop(_: Event) -> None: 1026 | self.async_shutdown() 1027 | 1028 | hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_on_stop) 1029 | 1030 | 1031 | class MockToggleEntity(entity.ToggleEntity): 1032 | """Provide a mock toggle device.""" 1033 | 1034 | def __init__(self, name: str | None, state: Literal["on", "off"] | None) -> None: 1035 | """Initialize the mock entity.""" 1036 | self._name = name or DEVICE_DEFAULT_NAME 1037 | self._state = state 1038 | self.calls: list[tuple[str, dict[str, Any]]] = [] 1039 | 1040 | @property 1041 | def name(self) -> str: 1042 | """Return the name of the entity if any.""" 1043 | self.calls.append(("name", {})) 1044 | return self._name 1045 | 1046 | @property 1047 | def state(self) -> Literal["on", "off"] | None: 1048 | """Return the state of the entity if any.""" 1049 | self.calls.append(("state", {})) 1050 
| return self._state 1051 | 1052 | @property 1053 | def is_on(self) -> bool: 1054 | """Return true if entity is on.""" 1055 | self.calls.append(("is_on", {})) 1056 | return self._state == STATE_ON 1057 | 1058 | def turn_on(self, **kwargs: Any) -> None: 1059 | """Turn the entity on.""" 1060 | self.calls.append(("turn_on", kwargs)) 1061 | self._state = STATE_ON 1062 | 1063 | def turn_off(self, **kwargs: Any) -> None: 1064 | """Turn the entity off.""" 1065 | self.calls.append(("turn_off", kwargs)) 1066 | self._state = STATE_OFF 1067 | 1068 | def last_call(self, method: str | None = None) -> tuple[str, dict[str, Any]]: 1069 | """Return the last call.""" 1070 | if not self.calls: 1071 | return None 1072 | if method is None: 1073 | return self.calls[-1] 1074 | try: 1075 | return next(call for call in reversed(self.calls) if call[0] == method) 1076 | except StopIteration: 1077 | return None 1078 | 1079 | 1080 | class MockConfigEntry(config_entries.ConfigEntry): 1081 | """Helper for creating config entries that adds some defaults.""" 1082 | 1083 | def __init__( 1084 | self, 1085 | *, 1086 | data=None, 1087 | disabled_by=None, 1088 | discovery_keys=None, 1089 | domain="test", 1090 | entry_id=None, 1091 | minor_version=1, 1092 | options=None, 1093 | pref_disable_new_entities=None, 1094 | pref_disable_polling=None, 1095 | reason=None, 1096 | source=config_entries.SOURCE_USER, 1097 | state=None, 1098 | subentries_data=None, 1099 | title="Mock Title", 1100 | unique_id=None, 1101 | version=1, 1102 | ) -> None: 1103 | """Initialize a mock config entry.""" 1104 | discovery_keys = discovery_keys or {} 1105 | kwargs = { 1106 | "data": data or {}, 1107 | "disabled_by": disabled_by, 1108 | "discovery_keys": discovery_keys, 1109 | "domain": domain, 1110 | "entry_id": entry_id or ulid_util.ulid_now(), 1111 | "minor_version": minor_version, 1112 | "options": options or {}, 1113 | "pref_disable_new_entities": pref_disable_new_entities, 1114 | "pref_disable_polling": pref_disable_polling, 
1115 | "subentries_data": subentries_data or (), 1116 | "title": title, 1117 | "unique_id": unique_id, 1118 | "version": version, 1119 | } 1120 | if source is not None: 1121 | kwargs["source"] = source 1122 | if state is not None: 1123 | kwargs["state"] = state 1124 | super().__init__(**kwargs) 1125 | if reason is not None: 1126 | object.__setattr__(self, "reason", reason) 1127 | 1128 | def add_to_hass(self, hass: HomeAssistant) -> None: 1129 | """Test helper to add entry to hass.""" 1130 | hass.config_entries._entries[self.entry_id] = self 1131 | 1132 | def add_to_manager(self, manager: config_entries.ConfigEntries) -> None: 1133 | """Test helper to add entry to entry manager.""" 1134 | manager._entries[self.entry_id] = self 1135 | 1136 | def mock_state( 1137 | self, 1138 | hass: HomeAssistant, 1139 | state: config_entries.ConfigEntryState, 1140 | reason: str | None = None, 1141 | ) -> None: 1142 | """Mock the state of a config entry to be used in . 1143 | 1144 | Currently this is a wrapper around _async_set_state, but it may 1145 | change in the future. 1146 | 1147 | It is preferable to get the config entry into the desired state 1148 | by using the normal config entry methods, and this helper 1149 | is only intended to be used in cases where that is not possible. 1150 | 1151 | When in doubt, this helper should not be used in new code 1152 | and is only intended for backwards compatibility with existing 1153 | . 
1154 | """ 1155 | self._async_set_state(hass, state, reason) 1156 | 1157 | async def start_reauth_flow( 1158 | self, 1159 | hass: HomeAssistant, 1160 | context: dict[str, Any] | None = None, 1161 | data: dict[str, Any] | None = None, 1162 | ) -> ConfigFlowResult: 1163 | """Start a reauthentication flow.""" 1164 | if self.entry_id not in hass.config_entries._entries: 1165 | raise ValueError("Config entry must be added to hass to start reauth flow") 1166 | return await start_reauth_flow(hass, self, context, data) 1167 | 1168 | async def start_reconfigure_flow( 1169 | self, 1170 | hass: HomeAssistant, 1171 | *, 1172 | show_advanced_options: bool = False, 1173 | ) -> ConfigFlowResult: 1174 | """Start a reconfiguration flow.""" 1175 | if self.entry_id not in hass.config_entries._entries: 1176 | raise ValueError( 1177 | "Config entry must be added to hass to start reconfiguration flow" 1178 | ) 1179 | return await hass.config_entries.flow.async_init( 1180 | self.domain, 1181 | context={ 1182 | "source": config_entries.SOURCE_RECONFIGURE, 1183 | "entry_id": self.entry_id, 1184 | "show_advanced_options": show_advanced_options, 1185 | }, 1186 | ) 1187 | 1188 | async def start_subentry_reconfigure_flow( 1189 | self, 1190 | hass: HomeAssistant, 1191 | subentry_flow_type: str, 1192 | subentry_id: str, 1193 | *, 1194 | show_advanced_options: bool = False, 1195 | ) -> ConfigFlowResult: 1196 | """Start a subentry reconfiguration flow.""" 1197 | if self.entry_id not in hass.config_entries._entries: 1198 | raise ValueError( 1199 | "Config entry must be added to hass to start reconfiguration flow" 1200 | ) 1201 | return await hass.config_entries.subentries.async_init( 1202 | (self.entry_id, subentry_flow_type), 1203 | context={ 1204 | "source": config_entries.SOURCE_RECONFIGURE, 1205 | "subentry_id": subentry_id, 1206 | "show_advanced_options": show_advanced_options, 1207 | }, 1208 | ) 1209 | 1210 | 1211 | async def start_reauth_flow( 1212 | hass: HomeAssistant, 1213 | entry: 
ConfigEntry, 1214 | context: dict[str, Any] | None = None, 1215 | data: dict[str, Any] | None = None, 1216 | ) -> ConfigFlowResult: 1217 | """Start a reauthentication flow for a config entry. 1218 | 1219 | This helper method should be aligned with `ConfigEntry._async_init_reauth`. 1220 | """ 1221 | return await hass.config_entries.flow.async_init( 1222 | entry.domain, 1223 | context={ 1224 | "source": config_entries.SOURCE_REAUTH, 1225 | "entry_id": entry.entry_id, 1226 | "title_placeholders": {"name": entry.title}, 1227 | "unique_id": entry.unique_id, 1228 | } 1229 | | (context or {}), 1230 | data=entry.data | (data or {}), 1231 | ) 1232 | 1233 | 1234 | def patch_yaml_files(files_dict, endswith=True): 1235 | """Patch load_yaml with a dictionary of yaml files.""" 1236 | # match using endswith, start search with longest string 1237 | matchlist = sorted(files_dict.keys(), key=len) if endswith else [] 1238 | 1239 | def mock_open_f(fname, **_): 1240 | """Mock open() in the yaml module, used by load_yaml.""" 1241 | # Return the mocked file on full match 1242 | if isinstance(fname, pathlib.Path): 1243 | fname = str(fname) 1244 | 1245 | if fname in files_dict: 1246 | _LOGGER.debug("patch_yaml_files match %s", fname) 1247 | res = StringIO(files_dict[fname]) 1248 | setattr(res, "name", fname) 1249 | return res 1250 | 1251 | # Match using endswith 1252 | for ends in matchlist: 1253 | if fname.endswith(ends): 1254 | _LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname) 1255 | res = StringIO(files_dict[ends]) 1256 | setattr(res, "name", fname) 1257 | return res 1258 | 1259 | # Fallback for hass.components (i.e. 
services.yaml) 1260 | if "homeassistant/components" in fname: 1261 | _LOGGER.debug("patch_yaml_files using real file: %s", fname) 1262 | return open(fname, encoding="utf-8") 1263 | 1264 | # Not found 1265 | raise FileNotFoundError(f"File not found: {fname}") 1266 | 1267 | return patch.object(yaml_loader, "open", mock_open_f, create=True) 1268 | 1269 | 1270 | @contextmanager 1271 | def assert_setup_component(count, domain=None): 1272 | """Collect valid configuration from setup_component. 1273 | 1274 | - count: The amount of valid platforms that should be setup 1275 | - domain: The domain to count is optional. It can be automatically 1276 | determined most of the time 1277 | 1278 | Use as a context manager around setup.setup_component 1279 | with assert_setup_component(0) as result_config: 1280 | setup_component(hass, domain, start_config) 1281 | # using result_config is optional 1282 | """ 1283 | config = {} 1284 | 1285 | async def mock_psc( 1286 | hass: HomeAssistant, 1287 | config_input: ConfigType, 1288 | integration: loader.Integration, 1289 | component: loader.ComponentProtocol | None = None, 1290 | ) -> IntegrationConfigInfo: 1291 | """Mock the prepare_setup_component to capture config.""" 1292 | domain_input = integration.domain 1293 | integration_config_info = await async_process_component_config( 1294 | hass, config_input, integration, component 1295 | ) 1296 | res = integration_config_info.config 1297 | config[domain_input] = None if res is None else res.get(domain_input) 1298 | _LOGGER.debug( 1299 | "Configuration for %s, Validated: %s, Original %s", 1300 | domain_input, 1301 | config[domain_input], 1302 | config_input.get(domain_input), 1303 | ) 1304 | return integration_config_info 1305 | 1306 | assert isinstance(config, dict) 1307 | with patch("homeassistant.config.async_process_component_config", mock_psc): 1308 | yield config 1309 | 1310 | if domain is None: 1311 | assert len(config) == 1, ( 1312 | f"assert_setup_component requires DOMAIN: 
{list(config.keys())}" 1313 | ) 1314 | domain = list(config.keys())[0] 1315 | 1316 | res = config.get(domain) 1317 | res_len = 0 if res is None else len(res) 1318 | assert res_len == count, ( 1319 | f"setup_component failed, expected {count} got {res_len}: {res}" 1320 | ) 1321 | 1322 | 1323 | def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None: 1324 | """Mock the DATA_RESTORE_CACHE.""" 1325 | key = rs.DATA_RESTORE_STATE 1326 | data = rs.RestoreStateData(hass) 1327 | now = dt_util.utcnow() 1328 | 1329 | last_states = {} 1330 | for state in states: 1331 | restored_state = state.as_dict() 1332 | restored_state = { 1333 | **restored_state, 1334 | "attributes": json.loads( 1335 | json.dumps(restored_state["attributes"], cls=JSONEncoder) 1336 | ), 1337 | } 1338 | last_states[state.entity_id] = rs.StoredState.from_dict( 1339 | {"state": restored_state, "last_seen": now} 1340 | ) 1341 | data.last_states = last_states 1342 | _LOGGER.debug("Restore cache: %s", data.last_states) 1343 | assert len(data.last_states) == len(states), f"Duplicate entity_id? 
{states}" 1344 | 1345 | rs.async_get.cache_clear() 1346 | hass.data[key] = data 1347 | 1348 | 1349 | def mock_restore_cache_with_extra_data( 1350 | hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]] 1351 | ) -> None: 1352 | """Mock the DATA_RESTORE_CACHE.""" 1353 | key = rs.DATA_RESTORE_STATE 1354 | data = rs.RestoreStateData(hass) 1355 | now = dt_util.utcnow() 1356 | 1357 | last_states = {} 1358 | for state, extra_data in states: 1359 | restored_state = state.as_dict() 1360 | restored_state = { 1361 | **restored_state, 1362 | "attributes": json.loads( 1363 | json.dumps(restored_state["attributes"], cls=JSONEncoder) 1364 | ), 1365 | } 1366 | last_states[state.entity_id] = rs.StoredState.from_dict( 1367 | {"state": restored_state, "extra_data": extra_data, "last_seen": now} 1368 | ) 1369 | data.last_states = last_states 1370 | _LOGGER.debug("Restore cache: %s", data.last_states) 1371 | assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" 1372 | 1373 | rs.async_get.cache_clear() 1374 | hass.data[key] = data 1375 | 1376 | 1377 | async def async_mock_restore_state_shutdown_restart( 1378 | hass: HomeAssistant, 1379 | ) -> rs.RestoreStateData: 1380 | """Mock shutting down and saving restore state and restoring.""" 1381 | data = rs.async_get(hass) 1382 | await data.async_dump_states() 1383 | await async_mock_load_restore_state_from_storage(hass) 1384 | return data 1385 | 1386 | 1387 | async def async_mock_load_restore_state_from_storage( 1388 | hass: HomeAssistant, 1389 | ) -> None: 1390 | """Mock loading restore state from storage. 1391 | 1392 | hass_storage must already be mocked. 
1393 | """ 1394 | await rs.async_get(hass).async_load() 1395 | 1396 | 1397 | class MockEntity(entity.Entity): 1398 | """Mock Entity class.""" 1399 | 1400 | def __init__(self, **values: Any) -> None: 1401 | """Initialize an entity.""" 1402 | self._values = values 1403 | 1404 | if "entity_id" in values: 1405 | self.entity_id = values["entity_id"] 1406 | 1407 | @property 1408 | def available(self) -> bool: 1409 | """Return True if entity is available.""" 1410 | return self._handle("available") 1411 | 1412 | @property 1413 | def capability_attributes(self) -> Mapping[str, Any] | None: 1414 | """Info about capabilities.""" 1415 | return self._handle("capability_attributes") 1416 | 1417 | @property 1418 | def device_class(self) -> str | None: 1419 | """Info how device should be classified.""" 1420 | return self._handle("device_class") 1421 | 1422 | @property 1423 | def device_info(self) -> dr.DeviceInfo | None: 1424 | """Info how it links to a device.""" 1425 | return self._handle("device_info") 1426 | 1427 | @property 1428 | def entity_category(self) -> entity.EntityCategory | None: 1429 | """Return the entity category.""" 1430 | return self._handle("entity_category") 1431 | 1432 | @property 1433 | def extra_state_attributes(self) -> Mapping[str, Any] | None: 1434 | """Return entity specific state attributes.""" 1435 | return self._handle("extra_state_attributes") 1436 | 1437 | @property 1438 | def has_entity_name(self) -> bool: 1439 | """Return the has_entity_name name flag.""" 1440 | return self._handle("has_entity_name") 1441 | 1442 | @property 1443 | def entity_registry_enabled_default(self) -> bool: 1444 | """Return if the entity should be enabled when first added to the entity registry.""" 1445 | return self._handle("entity_registry_enabled_default") 1446 | 1447 | @property 1448 | def entity_registry_visible_default(self) -> bool: 1449 | """Return if the entity should be visible when first added to the entity registry.""" 1450 | return 
self._handle("entity_registry_visible_default") 1451 | 1452 | @property 1453 | def icon(self) -> str | None: 1454 | """Return the suggested icon.""" 1455 | return self._handle("icon") 1456 | 1457 | @property 1458 | def name(self) -> str | None: 1459 | """Return the name of the entity.""" 1460 | return self._handle("name") 1461 | 1462 | @property 1463 | def should_poll(self) -> bool: 1464 | """Return the ste of the polling.""" 1465 | return self._handle("should_poll") 1466 | 1467 | @property 1468 | def supported_features(self) -> int | None: 1469 | """Info about supported features.""" 1470 | return self._handle("supported_features") 1471 | 1472 | @property 1473 | def translation_key(self) -> str | None: 1474 | """Return the translation key.""" 1475 | return self._handle("translation_key") 1476 | 1477 | @property 1478 | def unique_id(self) -> str | None: 1479 | """Return the unique ID of the entity.""" 1480 | return self._handle("unique_id") 1481 | 1482 | @property 1483 | def unit_of_measurement(self) -> str | None: 1484 | """Info on the units the entity state is in.""" 1485 | return self._handle("unit_of_measurement") 1486 | 1487 | def _handle(self, attr: str) -> Any: 1488 | """Return attribute value.""" 1489 | if attr in self._values: 1490 | return self._values[attr] 1491 | return getattr(super(), attr) 1492 | 1493 | 1494 | @contextmanager 1495 | def mock_storage(data: dict[str, Any] | None = None) -> Generator[dict[str, Any]]: 1496 | """Mock storage. 1497 | 1498 | Data is a dict {'key': {'version': version, 'data': data}} 1499 | 1500 | Written data will be converted to JSON to ensure JSON parsing works. 
1501 | """ 1502 | if data is None: 1503 | data = {} 1504 | 1505 | orig_load = storage.Store._async_load 1506 | 1507 | async def mock_async_load( 1508 | store: storage.Store, 1509 | ) -> dict[str, Any] | list[Any] | None: 1510 | """Mock version of load.""" 1511 | if store._data is None: 1512 | # No data to load 1513 | if store.key not in data: 1514 | # Make sure the next attempt will still load 1515 | store._load_task = None 1516 | return None 1517 | 1518 | mock_data = data.get(store.key) 1519 | 1520 | if "data" not in mock_data or "version" not in mock_data: 1521 | _LOGGER.error('Mock data needs "version" and "data"') 1522 | raise ValueError('Mock data needs "version" and "data"') 1523 | 1524 | store._data = mock_data 1525 | 1526 | # Route through original load so that we trigger migration 1527 | loaded = await orig_load(store) 1528 | _LOGGER.debug("Loading data for %s: %s", store.key, loaded) 1529 | return loaded 1530 | 1531 | async def mock_write_data( 1532 | store: storage.Store, path: str, data_to_write: dict[str, Any] 1533 | ) -> None: 1534 | """Mock version of write data.""" 1535 | # To ensure that the data can be serialized 1536 | _LOGGER.debug("Writing data to %s: %s", store.key, data_to_write) 1537 | raise_contains_mocks(data_to_write) 1538 | 1539 | if "data_func" in data_to_write: 1540 | data_to_write["data"] = data_to_write.pop("data_func")() 1541 | 1542 | encoder = store._encoder 1543 | if encoder and encoder is not JSONEncoder: 1544 | # If they pass a custom encoder that is not the 1545 | # default JSONEncoder, we use the slow path of json.dumps 1546 | dump = ft.partial(json.dumps, cls=store._encoder) 1547 | else: 1548 | dump = _orjson_default_encoder 1549 | data[store.key] = json_loads(dump(data_to_write)) 1550 | 1551 | async def mock_remove(store: storage.Store) -> None: 1552 | """Remove data.""" 1553 | data.pop(store.key, None) 1554 | 1555 | with ( 1556 | patch( 1557 | "homeassistant.helpers.storage.Store._async_load", 1558 | 
side_effect=mock_async_load, 1559 | autospec=True, 1560 | ), 1561 | patch( 1562 | "homeassistant.helpers.storage.Store._async_write_data", 1563 | side_effect=mock_write_data, 1564 | autospec=True, 1565 | ), 1566 | patch( 1567 | "homeassistant.helpers.storage.Store.async_remove", 1568 | side_effect=mock_remove, 1569 | autospec=True, 1570 | ), 1571 | ): 1572 | yield data 1573 | 1574 | 1575 | async def flush_store(store: storage.Store) -> None: 1576 | """Make sure all delayed writes of a store are written.""" 1577 | if store._data is None: 1578 | return 1579 | 1580 | store._async_cleanup_final_write_listener() 1581 | store._async_cleanup_delay_listener() 1582 | await store._async_handle_write_data() 1583 | 1584 | 1585 | async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, Any]: 1586 | """Get system health info.""" 1587 | return await hass.data["system_health"][domain].info_callback(hass) 1588 | 1589 | 1590 | @contextmanager 1591 | def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> Iterator[None]: 1592 | """Mock a config flow handler.""" 1593 | original_handler = config_entries.HANDLERS.get(domain) 1594 | config_entries.HANDLERS[domain] = config_flow 1595 | _LOGGER.info("Adding mock config flow: %s", domain) 1596 | yield 1597 | config_entries.HANDLERS.pop(domain) 1598 | if original_handler: 1599 | config_entries.HANDLERS[domain] = original_handler 1600 | 1601 | 1602 | def mock_integration( 1603 | hass: HomeAssistant, 1604 | module: MockModule, 1605 | built_in: bool = True, 1606 | top_level_files: set[str] | None = None, 1607 | ) -> loader.Integration: 1608 | """Mock an integration.""" 1609 | integration = loader.Integration( 1610 | hass, 1611 | f"{loader.PACKAGE_BUILTIN}.{module.DOMAIN}" 1612 | if built_in 1613 | else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}", 1614 | pathlib.Path(""), 1615 | module.mock_manifest(), 1616 | top_level_files, 1617 | ) 1618 | 1619 | def mock_import_platform(platform_name: str) -> 
NoReturn: 1620 | raise ImportError( 1621 | f"Mocked unable to import platform '{integration.pkg_path}.{platform_name}'", 1622 | name=f"{integration.pkg_path}.{platform_name}", 1623 | ) 1624 | 1625 | integration._import_platform = mock_import_platform 1626 | 1627 | _LOGGER.info("Adding mock integration: %s", module.DOMAIN) 1628 | integration_cache = hass.data[loader.DATA_INTEGRATIONS] 1629 | integration_cache[module.DOMAIN] = integration 1630 | 1631 | module_cache = hass.data[loader.DATA_COMPONENTS] 1632 | module_cache[module.DOMAIN] = module 1633 | 1634 | return integration 1635 | 1636 | 1637 | def mock_platform( 1638 | hass: HomeAssistant, 1639 | platform_path: str, 1640 | module: Mock | MockPlatform | None = None, 1641 | built_in=True, 1642 | ) -> None: 1643 | """Mock a platform. 1644 | 1645 | platform_path is in form hue.config_flow. 1646 | """ 1647 | domain, _, platform_name = platform_path.partition(".") 1648 | integration_cache = hass.data[loader.DATA_INTEGRATIONS] 1649 | module_cache = hass.data[loader.DATA_COMPONENTS] 1650 | 1651 | if domain not in integration_cache: 1652 | mock_integration(hass, MockModule(domain), built_in=built_in) 1653 | 1654 | integration_cache[domain]._top_level_files.add(f"{platform_name}.py") 1655 | _LOGGER.info("Adding mock integration platform: %s", platform_path) 1656 | module_cache[platform_path] = module or Mock() 1657 | 1658 | 1659 | def async_capture_events[_DataT: Mapping[str, Any] = dict[str, Any]]( 1660 | hass: HomeAssistant, event_name: EventType[_DataT] | str 1661 | ) -> list[Event[_DataT]]: 1662 | """Create a helper that captures events.""" 1663 | events: list[Event[_DataT]] = [] 1664 | 1665 | @callback 1666 | def capture_events(event: Event[_DataT]) -> None: 1667 | events.append(event) 1668 | 1669 | hass.bus.async_listen(event_name, capture_events) 1670 | 1671 | return events 1672 | 1673 | 1674 | @callback 1675 | def async_mock_signal[*_Ts]( 1676 | hass: HomeAssistant, signal: SignalType[*_Ts] | str 1677 | ) -> 
list[tuple[*_Ts]]: 1678 | """Catch all dispatches to a signal.""" 1679 | calls: list[tuple[*_Ts]] = [] 1680 | 1681 | @callback 1682 | def mock_signal_handler(*args: *_Ts) -> None: 1683 | """Mock service call.""" 1684 | calls.append(args) 1685 | 1686 | async_dispatcher_connect(hass, signal, mock_signal_handler) 1687 | 1688 | return calls 1689 | 1690 | 1691 | _SENTINEL = object() 1692 | 1693 | 1694 | class _HA_ANY: 1695 | """A helper object that compares equal to everything. 1696 | 1697 | Based on unittest.mock.ANY, but modified to not show up in pytest's equality 1698 | assertion diffs. 1699 | """ 1700 | 1701 | _other = _SENTINEL 1702 | 1703 | def __eq__(self, other: object) -> bool: 1704 | """Test equal.""" 1705 | self._other = other 1706 | return True 1707 | 1708 | def __ne__(self, other: object) -> bool: 1709 | """Test not equal.""" 1710 | self._other = other 1711 | return False 1712 | 1713 | def __repr__(self) -> str: 1714 | """Return repr() other to not show up in pytest quality diffs.""" 1715 | if self._other is _SENTINEL: 1716 | return "" 1717 | return repr(self._other) 1718 | 1719 | 1720 | ANY = _HA_ANY() 1721 | 1722 | 1723 | def raise_contains_mocks(val: Any) -> None: 1724 | """Raise for mocks.""" 1725 | if isinstance(val, Mock): 1726 | raise TypeError(val) 1727 | 1728 | if isinstance(val, dict): 1729 | for dict_value in val.values(): 1730 | raise_contains_mocks(dict_value) 1731 | 1732 | if isinstance(val, list): 1733 | for dict_value in val: 1734 | raise_contains_mocks(dict_value) 1735 | 1736 | 1737 | @callback 1738 | def async_get_persistent_notifications( 1739 | hass: HomeAssistant, 1740 | ) -> dict[str, pn.Notification]: 1741 | """Get the current persistent notifications.""" 1742 | return pn._async_get_or_create_notifications(hass) 1743 | 1744 | 1745 | def async_mock_cloud_connection_status(hass: HomeAssistant, connected: bool) -> None: 1746 | """Mock a signal the cloud disconnected.""" 1747 | # pylint: disable-next=import-outside-toplevel 1748 | from 
homeassistant.components.cloud import ( 1749 | SIGNAL_CLOUD_CONNECTION_STATE, 1750 | CloudConnectionState, 1751 | ) 1752 | 1753 | if connected: 1754 | state = CloudConnectionState.CLOUD_CONNECTED 1755 | else: 1756 | state = CloudConnectionState.CLOUD_DISCONNECTED 1757 | async_dispatcher_send(hass, SIGNAL_CLOUD_CONNECTION_STATE, state) 1758 | 1759 | 1760 | @asynccontextmanager 1761 | async def async_call_logger_set_level( 1762 | logger: str, 1763 | level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "CRITICAL"], 1764 | *, 1765 | hass: HomeAssistant, 1766 | caplog: pytest.LogCaptureFixture, 1767 | ) -> AsyncGenerator[None]: 1768 | """Context manager to reset loggers after logger.set_level call.""" 1769 | assert LOGGER_DOMAIN in hass.data, "'logger' integration not setup" 1770 | with caplog.at_level(logging.NOTSET, logger): 1771 | await hass.services.async_call( 1772 | LOGGER_DOMAIN, 1773 | SERVICE_SET_LEVEL, 1774 | {logger: level}, 1775 | blocking=True, 1776 | ) 1777 | await hass.async_block_till_done() 1778 | yield 1779 | _clear_logger_overwrites(hass) 1780 | 1781 | 1782 | def import_and_test_deprecated_constant_enum( 1783 | caplog: pytest.LogCaptureFixture, 1784 | module: ModuleType, 1785 | replacement: Enum, 1786 | constant_prefix: str, 1787 | breaks_in_ha_version: str, 1788 | ) -> None: 1789 | """Import and test deprecated constant replaced by a enum. 
1790 | 1791 | - Import deprecated enum 1792 | - Assert value is the same as the replacement 1793 | - Assert a warning is logged 1794 | - Assert the deprecated constant is included in the modules.__dir__() 1795 | - Assert the deprecated constant is included in the modules.__all__() 1796 | """ 1797 | import_and_test_deprecated_constant( 1798 | caplog, 1799 | module, 1800 | constant_prefix + replacement.name, 1801 | f"{replacement.__class__.__name__}.{replacement.name}", 1802 | replacement, 1803 | breaks_in_ha_version, 1804 | ) 1805 | 1806 | 1807 | def import_and_test_deprecated_constant( 1808 | caplog: pytest.LogCaptureFixture, 1809 | module: ModuleType, 1810 | constant_name: str, 1811 | replacement_name: str, 1812 | replacement: Any, 1813 | breaks_in_ha_version: str, 1814 | ) -> None: 1815 | """Import and test deprecated constant replaced by a value. 1816 | 1817 | - Import deprecated constant 1818 | - Assert value is the same as the replacement 1819 | - Assert a warning is logged 1820 | - Assert the deprecated constant is included in the modules.__dir__() 1821 | - Assert the deprecated constant is included in the modules.__all__() 1822 | """ 1823 | value = import_deprecated_constant(module, constant_name) 1824 | assert value == replacement 1825 | assert ( 1826 | module.__name__, 1827 | logging.WARNING, 1828 | ( 1829 | f"{constant_name} was used from test_constant_deprecation," 1830 | f" this is a deprecated constant which will be removed in HA Core {breaks_in_ha_version}. 
" 1831 | f"Use {replacement_name} instead, please report " 1832 | "it to the author of the 'test_constant_deprecation' custom integration" 1833 | ), 1834 | ) in caplog.record_tuples 1835 | 1836 | # verify deprecated constant is included in dir() 1837 | assert constant_name in dir(module) 1838 | assert constant_name in module.__all__ 1839 | 1840 | 1841 | def import_and_test_deprecated_alias( 1842 | caplog: pytest.LogCaptureFixture, 1843 | module: ModuleType, 1844 | alias_name: str, 1845 | replacement: Any, 1846 | breaks_in_ha_version: str, 1847 | ) -> None: 1848 | """Import and test deprecated alias replaced by a value. 1849 | 1850 | - Import deprecated alias 1851 | - Assert value is the same as the replacement 1852 | - Assert a warning is logged 1853 | - Assert the deprecated alias is included in the modules.__dir__() 1854 | - Assert the deprecated alias is included in the modules.__all__() 1855 | """ 1856 | replacement_name = f"{replacement.__module__}.{replacement.__name__}" 1857 | value = import_deprecated_constant(module, alias_name) 1858 | assert value == replacement 1859 | assert ( 1860 | module.__name__, 1861 | logging.WARNING, 1862 | ( 1863 | f"{alias_name} was used from test_constant_deprecation," 1864 | f" this is a deprecated alias which will be removed in HA Core {breaks_in_ha_version}. 
" 1865 | f"Use {replacement_name} instead, please report " 1866 | "it to the author of the 'test_constant_deprecation' custom integration" 1867 | ), 1868 | ) in caplog.record_tuples 1869 | 1870 | # verify deprecated alias is included in dir() 1871 | assert alias_name in dir(module) 1872 | assert alias_name in module.__all__ 1873 | 1874 | 1875 | def help_test_all(module: ModuleType) -> None: 1876 | """Test module.__all__ is correctly set.""" 1877 | assert set(module.__all__) == { 1878 | itm for itm in dir(module) if not itm.startswith("_") 1879 | } 1880 | 1881 | 1882 | def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType: 1883 | """Convert an extract stack to a frame list.""" 1884 | stack = list(extract_stack) 1885 | _globals = globals() 1886 | for frame in stack: 1887 | frame.f_back = None 1888 | frame.f_globals = _globals 1889 | frame.f_code.co_filename = frame.filename 1890 | frame.f_lineno = int(frame.lineno) 1891 | 1892 | top_frame = stack.pop() 1893 | current_frame = top_frame 1894 | while stack and (next_frame := stack.pop()): 1895 | current_frame.f_back = next_frame 1896 | current_frame = next_frame 1897 | 1898 | return top_frame 1899 | 1900 | 1901 | def setup_test_component_platform( 1902 | hass: HomeAssistant, 1903 | domain: str, 1904 | entities: Iterable[Entity], 1905 | from_config_entry: bool = False, 1906 | built_in: bool = True, 1907 | ) -> MockPlatform: 1908 | """Mock a test component platform for .""" 1909 | 1910 | async def _async_setup_platform( 1911 | hass: HomeAssistant, 1912 | config: ConfigType, 1913 | async_add_entities: AddEntitiesCallback, 1914 | discovery_info: DiscoveryInfoType | None = None, 1915 | ) -> None: 1916 | """Set up a test component platform.""" 1917 | async_add_entities(entities) 1918 | 1919 | platform = MockPlatform( 1920 | async_setup_platform=_async_setup_platform, 1921 | ) 1922 | 1923 | # avoid creating config entry setup if not needed 1924 | if from_config_entry: 1925 | 1926 | async def _async_setup_entry( 
1927 | hass: HomeAssistant, 1928 | entry: ConfigEntry, 1929 | async_add_entities: AddConfigEntryEntitiesCallback, 1930 | ) -> None: 1931 | """Set up a test component platform.""" 1932 | async_add_entities(entities) 1933 | 1934 | platform.async_setup_entry = _async_setup_entry 1935 | platform.async_setup_platform = None 1936 | 1937 | mock_platform(hass, f"test.{domain}", platform, built_in=built_in) 1938 | return platform 1939 | 1940 | 1941 | async def snapshot_platform( 1942 | hass: HomeAssistant, 1943 | entity_registry: er.EntityRegistry, 1944 | snapshot: SnapshotAssertion, 1945 | config_entry_id: str, 1946 | ) -> None: 1947 | """Snapshot a platform.""" 1948 | entity_entries = er.async_entries_for_config_entry(entity_registry, config_entry_id) 1949 | assert entity_entries 1950 | assert len({entity_entry.domain for entity_entry in entity_entries}) == 1, ( 1951 | "Please limit the loaded platforms to 1 platform." 1952 | ) 1953 | for entity_entry in entity_entries: 1954 | assert entity_entry == snapshot(name=f"{entity_entry.entity_id}-entry") 1955 | assert entity_entry.disabled_by is None, "Please enable all entities." 
1956 | state = hass.states.get(entity_entry.entity_id) 1957 | assert state, f"State not found for {entity_entry.entity_id}" 1958 | assert state == snapshot(name=f"{entity_entry.entity_id}-state") 1959 | 1960 | 1961 | @lru_cache 1962 | def get_quality_scale(integration: str) -> dict[str, QualityScaleStatus]: 1963 | """Load quality scale for integration.""" 1964 | quality_scale_file = pathlib.Path( 1965 | f"homeassistant/components/{integration}/quality_scale.yaml" 1966 | ) 1967 | if not quality_scale_file.exists(): 1968 | return {} 1969 | raw = load_yaml_dict(quality_scale_file) 1970 | return { 1971 | rule: ( 1972 | QualityScaleStatus(details) 1973 | if isinstance(details, str) 1974 | else QualityScaleStatus(details["status"]) 1975 | ) 1976 | for rule, details in raw["rules"].items() 1977 | } 1978 | 1979 | 1980 | def get_schema_suggested_value(schema: vol.Schema, key: str) -> Any | None: 1981 | """Get suggested value for key in voluptuous schema.""" 1982 | for schema_key in schema: 1983 | if schema_key == key: 1984 | if ( 1985 | schema_key.description is None 1986 | or "suggested_value" not in schema_key.description 1987 | ): 1988 | return None 1989 | return schema_key.description["suggested_value"] 1990 | return None 1991 | 1992 | 1993 | def get_sensor_display_state( 1994 | hass: HomeAssistant, entity_registry: er.EntityRegistry, entity_id: str 1995 | ) -> str: 1996 | """Return the state rounded for presentation.""" 1997 | state = hass.states.get(entity_id) 1998 | assert state 1999 | value = state.state 2000 | 2001 | entity_entry = entity_registry.async_get(entity_id) 2002 | if entity_entry is None: 2003 | return value 2004 | 2005 | if ( 2006 | precision := entity_entry.options.get("sensor", {}).get( 2007 | "suggested_display_precision" 2008 | ) 2009 | ) is None: 2010 | return value 2011 | 2012 | with suppress(TypeError, ValueError): 2013 | numerical_value = float(value) 2014 | value = f"{numerical_value:z.{precision}f}" 2015 | return value 2016 | 
-------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The tests for components. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/diagnostics/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the Diagnostics integration. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from http import HTTPStatus 8 | from typing import cast 9 | 10 | from homeassistant.config_entries import ConfigEntry 11 | from homeassistant.core import HomeAssistant 12 | from homeassistant.helpers.device_registry import DeviceEntry 13 | from homeassistant.setup import async_setup_component 14 | from homeassistant.util.json import JsonObjectType 15 | 16 | from pytest_homeassistant_custom_component.typing import ClientSessionGenerator 17 | 18 | 19 | async def _get_diagnostics_for_config_entry( 20 | hass: HomeAssistant, 21 | hass_client: ClientSessionGenerator, 22 | config_entry: ConfigEntry, 23 | ) -> JsonObjectType: 24 | """Return the diagnostics config entry for the specified domain.""" 25 | assert await async_setup_component(hass, "diagnostics", {}) 26 | await hass.async_block_till_done() 27 | 28 | client = await hass_client() 29 | response = await client.get( 30 | f"/api/diagnostics/config_entry/{config_entry.entry_id}" 31 | ) 32 | assert response.status == HTTPStatus.OK 33 | return cast(JsonObjectType, await response.json()) 34 | 35 | 36 | async def get_diagnostics_for_config_entry( 37 | hass: HomeAssistant, 38 | hass_client: ClientSessionGenerator, 39 | 
config_entry: ConfigEntry, 40 | ) -> JsonObjectType: 41 | """Return the diagnostics config entry for the specified domain.""" 42 | data = await _get_diagnostics_for_config_entry(hass, hass_client, config_entry) 43 | return cast(JsonObjectType, data["data"]) 44 | 45 | 46 | async def _get_diagnostics_for_device( 47 | hass: HomeAssistant, 48 | hass_client: ClientSessionGenerator, 49 | config_entry: ConfigEntry, 50 | device: DeviceEntry, 51 | ) -> JsonObjectType: 52 | """Return the diagnostics for the specified device.""" 53 | assert await async_setup_component(hass, "diagnostics", {}) 54 | 55 | client = await hass_client() 56 | response = await client.get( 57 | f"/api/diagnostics/config_entry/{config_entry.entry_id}/device/{device.id}" 58 | ) 59 | assert response.status == HTTPStatus.OK 60 | return cast(JsonObjectType, await response.json()) 61 | 62 | 63 | async def get_diagnostics_for_device( 64 | hass: HomeAssistant, 65 | hass_client: ClientSessionGenerator, 66 | config_entry: ConfigEntry, 67 | device: DeviceEntry, 68 | ) -> JsonObjectType: 69 | """Return the diagnostics for the specified device.""" 70 | data = await _get_diagnostics_for_device(hass, hass_client, config_entry, device) 71 | return cast(JsonObjectType, data["data"]) 72 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/recorder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for Recorder component. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | import pytest 8 | 9 | pytest.register_assert_rewrite("tests.components.recorder.common") 10 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/recorder/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common test utils for working with recorder. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import asyncio 10 | from collections.abc import Iterable, Iterator 11 | from contextlib import contextmanager 12 | from dataclasses import dataclass 13 | from datetime import datetime, timedelta 14 | from functools import partial 15 | import importlib 16 | import sys 17 | import time 18 | from typing import Any, Literal, cast 19 | from unittest.mock import MagicMock, patch, sentinel 20 | 21 | from freezegun import freeze_time 22 | from sqlalchemy import create_engine, event as sqlalchemy_event 23 | from sqlalchemy.orm.session import Session 24 | 25 | from homeassistant import core as ha 26 | from homeassistant.components import recorder 27 | from homeassistant.components.recorder import ( 28 | Recorder, 29 | core, 30 | get_instance, 31 | migration, 32 | statistics, 33 | ) 34 | from homeassistant.components.recorder.db_schema import ( 35 | Events, 36 | EventTypes, 37 | RecorderRuns, 38 | States, 39 | StatesMeta, 40 | ) 41 | from homeassistant.components.recorder.tasks import RecorderTask, StatisticsTask 42 | from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass 43 | from homeassistant.const import DEGREE, UnitOfTemperature 44 | from homeassistant.core import Event, HomeAssistant, State 45 | from homeassistant.helpers import recorder as recorder_helper 46 | from homeassistant.util import dt as dt_util 47 | 48 | from . 
import db_schema_0 49 | 50 | DEFAULT_PURGE_TASKS = 3 51 | CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine" 52 | 53 | 54 | @dataclass 55 | class BlockRecorderTask(RecorderTask): 56 | """A task to block the recorder for testing only.""" 57 | 58 | event: asyncio.Event 59 | seconds: float 60 | 61 | def run(self, instance: Recorder) -> None: 62 | """Block the recorders event loop.""" 63 | instance.hass.loop.call_soon_threadsafe(self.event.set) 64 | time.sleep(self.seconds) 65 | 66 | 67 | @dataclass 68 | class ForceReturnConnectionToPool(RecorderTask): 69 | """Force return connection to pool.""" 70 | 71 | def run(self, instance: Recorder) -> None: 72 | """Handle the task.""" 73 | instance.event_session.commit() 74 | 75 | 76 | async def async_block_recorder(hass: HomeAssistant, seconds: float) -> None: 77 | """Block the recorders event loop for testing. 78 | 79 | Returns as soon as the recorder has started the block. 80 | 81 | Does not wait for the block to finish. 
82 | """ 83 | event = asyncio.Event() 84 | get_instance(hass).queue_task(BlockRecorderTask(event, seconds)) 85 | await event.wait() 86 | 87 | 88 | async def async_wait_recorder(hass: HomeAssistant) -> bool: 89 | """Wait for recorder to initialize and return connection status.""" 90 | return await hass.data[recorder_helper.DATA_RECORDER].db_connected 91 | 92 | 93 | def get_start_time(start: datetime) -> datetime: 94 | """Calculate a valid start time for statistics.""" 95 | start_minutes = start.minute - start.minute % 5 96 | return start.replace(minute=start_minutes, second=0, microsecond=0) 97 | 98 | 99 | def do_adhoc_statistics(hass: HomeAssistant, **kwargs: Any) -> None: 100 | """Trigger an adhoc statistics run.""" 101 | if not (start := kwargs.get("start")): 102 | start = statistics.get_start_time() 103 | elif (start.minute % 5) != 0 or start.second != 0 or start.microsecond != 0: 104 | raise ValueError(f"Statistics must start on 5 minute boundary got {start}") 105 | get_instance(hass).queue_task(StatisticsTask(start, False)) 106 | 107 | 108 | def wait_recording_done(hass: HomeAssistant) -> None: 109 | """Block till recording is done.""" 110 | hass.block_till_done() 111 | trigger_db_commit(hass) 112 | hass.block_till_done() 113 | recorder.get_instance(hass).block_till_done() 114 | hass.block_till_done() 115 | 116 | 117 | def trigger_db_commit(hass: HomeAssistant) -> None: 118 | """Force the recorder to commit.""" 119 | recorder.get_instance(hass)._async_commit(dt_util.utcnow()) 120 | 121 | 122 | async def async_wait_recording_done(hass: HomeAssistant) -> None: 123 | """Async wait until recording is done.""" 124 | await hass.async_block_till_done() 125 | async_trigger_db_commit(hass) 126 | await hass.async_block_till_done() 127 | await async_recorder_block_till_done(hass) 128 | await hass.async_block_till_done() 129 | 130 | 131 | async def async_wait_purge_done( 132 | hass: HomeAssistant, max_number: int | None = None 133 | ) -> None: 134 | """Wait for max number 
of purge events. 135 | 136 | Because a purge may insert another PurgeTask into 137 | the queue after the WaitTask finishes, we need up to 138 | a maximum number of WaitTasks that we will put into the 139 | queue. 140 | """ 141 | if not max_number: 142 | max_number = DEFAULT_PURGE_TASKS 143 | for _ in range(max_number + 1): 144 | await async_wait_recording_done(hass) 145 | 146 | 147 | @ha.callback 148 | def async_trigger_db_commit(hass: HomeAssistant) -> None: 149 | """Force the recorder to commit. Async friendly.""" 150 | recorder.get_instance(hass)._async_commit(dt_util.utcnow()) 151 | 152 | 153 | async def async_recorder_block_till_done(hass: HomeAssistant) -> None: 154 | """Non blocking version of recorder.block_till_done().""" 155 | await hass.async_add_executor_job(recorder.get_instance(hass).block_till_done) 156 | 157 | 158 | def corrupt_db_file(test_db_file): 159 | """Corrupt an sqlite3 database file.""" 160 | with open(test_db_file, "w+", encoding="utf8") as fhandle: 161 | fhandle.seek(200) 162 | fhandle.write("I am a corrupt db" * 100) 163 | 164 | 165 | def create_engine_test(*args, **kwargs): 166 | """Test version of create_engine that initializes with old schema. 167 | 168 | This simulates an existing db with the old schema. 
169 | """ 170 | engine = create_engine(*args, **kwargs) 171 | db_schema_0.Base.metadata.create_all(engine) 172 | return engine 173 | 174 | 175 | def run_information_with_session( 176 | session: Session, point_in_time: datetime | None = None 177 | ) -> RecorderRuns | None: 178 | """Return information about current run from the database.""" 179 | recorder_runs = RecorderRuns 180 | 181 | query = session.query(recorder_runs) 182 | if point_in_time: 183 | query = query.filter( 184 | (recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time) 185 | ) 186 | 187 | if (res := query.first()) is not None: 188 | session.expunge(res) 189 | return cast(RecorderRuns, res) 190 | return res 191 | 192 | 193 | def statistics_during_period( 194 | hass: HomeAssistant, 195 | start_time: datetime, 196 | end_time: datetime | None = None, 197 | statistic_ids: set[str] | None = None, 198 | period: Literal["5minute", "day", "hour", "week", "month"] = "hour", 199 | units: dict[str, str] | None = None, 200 | types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]] 201 | | None = None, 202 | ) -> dict[str, list[dict[str, Any]]]: 203 | """Call statistics_during_period with defaults for simpler ...""" 204 | if statistic_ids is not None and not isinstance(statistic_ids, set): 205 | statistic_ids = set(statistic_ids) 206 | if types is None: 207 | types = {"last_reset", "max", "mean", "min", "state", "sum"} 208 | return statistics.statistics_during_period( 209 | hass, start_time, end_time, statistic_ids, period, units, types 210 | ) 211 | 212 | 213 | def assert_states_equal_without_context(state: State, other: State) -> None: 214 | """Assert that two states are equal, ignoring context.""" 215 | assert_states_equal_without_context_and_last_changed(state, other) 216 | assert state.last_changed == other.last_changed 217 | assert state.last_reported == other.last_reported 218 | 219 | 220 | def assert_states_equal_without_context_and_last_changed( 221 | state: State, other: 
State 222 | ) -> None: 223 | """Assert that two states are equal, ignoring context and last_changed.""" 224 | assert state.state == other.state 225 | assert state.attributes == other.attributes 226 | assert state.last_updated == other.last_updated 227 | 228 | 229 | def assert_multiple_states_equal_without_context_and_last_changed( 230 | states: Iterable[State], others: Iterable[State] 231 | ) -> None: 232 | """Assert that multiple states are equal, ignoring context and last_changed.""" 233 | states_list = list(states) 234 | others_list = list(others) 235 | assert len(states_list) == len(others_list) 236 | for i, state in enumerate(states_list): 237 | assert_states_equal_without_context_and_last_changed(state, others_list[i]) 238 | 239 | 240 | def assert_multiple_states_equal_without_context( 241 | states: Iterable[State], others: Iterable[State] 242 | ) -> None: 243 | """Assert that multiple states are equal, ignoring context.""" 244 | states_list = list(states) 245 | others_list = list(others) 246 | assert len(states_list) == len(others_list) 247 | for i, state in enumerate(states_list): 248 | assert_states_equal_without_context(state, others_list[i]) 249 | 250 | 251 | def assert_events_equal_without_context(event: Event, other: Event) -> None: 252 | """Assert that two events are equal, ignoring context.""" 253 | assert event.data == other.data 254 | assert event.event_type == other.event_type 255 | assert event.origin == other.origin 256 | assert event.time_fired == other.time_fired 257 | 258 | 259 | def assert_dict_of_states_equal_without_context( 260 | states: dict[str, list[State]], others: dict[str, list[State]] 261 | ) -> None: 262 | """Assert that two dicts of states are equal, ignoring context.""" 263 | assert len(states) == len(others) 264 | for entity_id, state in states.items(): 265 | assert_multiple_states_equal_without_context(state, others[entity_id]) 266 | 267 | 268 | def assert_dict_of_states_equal_without_context_and_last_changed( 269 | states: 
dict[str, list[State]], others: dict[str, list[State]] 270 | ) -> None: 271 | """Assert that two dicts of states are equal, ignoring context and last_changed.""" 272 | assert len(states) == len(others) 273 | for entity_id, state in states.items(): 274 | assert_multiple_states_equal_without_context_and_last_changed( 275 | state, others[entity_id] 276 | ) 277 | 278 | 279 | async def async_record_states( 280 | hass: HomeAssistant, 281 | ) -> tuple[datetime, datetime, dict[str, list[State | None]]]: 282 | """Record some test states.""" 283 | return await hass.async_add_executor_job(record_states, hass) 284 | 285 | 286 | def record_states( 287 | hass: HomeAssistant, 288 | ) -> tuple[datetime, datetime, dict[str, list[State | None]]]: 289 | """Record some test states. 290 | 291 | We inject a bunch of state updates temperature sensors. 292 | """ 293 | mp = "media_player.test" 294 | sns1 = "sensor.test1" 295 | sns2 = "sensor.test2" 296 | sns3 = "sensor.test3" 297 | sns4 = "sensor.test4" 298 | sns5 = "sensor.wind_direction" 299 | sns1_attr = { 300 | "device_class": "temperature", 301 | "state_class": "measurement", 302 | "unit_of_measurement": UnitOfTemperature.CELSIUS, 303 | } 304 | sns2_attr = { 305 | "device_class": "humidity", 306 | "state_class": "measurement", 307 | "unit_of_measurement": "%", 308 | } 309 | sns3_attr = {"device_class": "temperature"} 310 | sns4_attr = {} 311 | sns5_attr = { 312 | "device_class": SensorDeviceClass.WIND_DIRECTION, 313 | "state_class": SensorStateClass.MEASUREMENT_ANGLE, 314 | "unit_of_measurement": DEGREE, 315 | } 316 | 317 | def set_state(entity_id, state, **kwargs): 318 | """Set the state.""" 319 | hass.states.set(entity_id, state, **kwargs) 320 | wait_recording_done(hass) 321 | return hass.states.get(entity_id) 322 | 323 | zero = get_start_time(dt_util.utcnow()) 324 | one = zero + timedelta(seconds=1 * 5) 325 | two = one + timedelta(seconds=15 * 5) 326 | three = two + timedelta(seconds=30 * 5) 327 | four = three + 
timedelta(seconds=14 * 5) 328 | 329 | states = {mp: [], sns1: [], sns2: [], sns3: [], sns4: [], sns5: []} 330 | with freeze_time(one) as freezer: 331 | states[mp].append( 332 | set_state(mp, "idle", attributes={"media_title": str(sentinel.mt1)}) 333 | ) 334 | states[sns1].append(set_state(sns1, "10", attributes=sns1_attr)) 335 | states[sns2].append(set_state(sns2, "10", attributes=sns2_attr)) 336 | states[sns3].append(set_state(sns3, "10", attributes=sns3_attr)) 337 | states[sns4].append(set_state(sns4, "10", attributes=sns4_attr)) 338 | states[sns5].append(set_state(sns5, "10", attributes=sns5_attr)) 339 | 340 | freezer.move_to(one + timedelta(microseconds=1)) 341 | states[mp].append( 342 | set_state(mp, "YouTube", attributes={"media_title": str(sentinel.mt2)}) 343 | ) 344 | 345 | freezer.move_to(two) 346 | states[sns1].append(set_state(sns1, "15", attributes=sns1_attr)) 347 | states[sns2].append(set_state(sns2, "15", attributes=sns2_attr)) 348 | states[sns3].append(set_state(sns3, "15", attributes=sns3_attr)) 349 | states[sns4].append(set_state(sns4, "15", attributes=sns4_attr)) 350 | states[sns5].append(set_state(sns5, "350", attributes=sns5_attr)) 351 | 352 | freezer.move_to(three) 353 | states[sns1].append(set_state(sns1, "20", attributes=sns1_attr)) 354 | states[sns2].append(set_state(sns2, "20", attributes=sns2_attr)) 355 | states[sns3].append(set_state(sns3, "20", attributes=sns3_attr)) 356 | states[sns4].append(set_state(sns4, "20", attributes=sns4_attr)) 357 | states[sns5].append(set_state(sns5, "5", attributes=sns5_attr)) 358 | 359 | return zero, four, states 360 | 361 | 362 | def convert_pending_states_to_meta(instance: Recorder, session: Session) -> None: 363 | """Convert pending states to use states_metadata.""" 364 | entity_ids: set[str] = set() 365 | states: set[States] = set() 366 | states_meta_objects: dict[str, StatesMeta] = {} 367 | for session_object in session: 368 | if isinstance(session_object, States): 369 | 
            entity_ids.add(session_object.entity_id)
            states.add(session_object)

    entity_id_to_metadata_ids = instance.states_meta_manager.get_many(
        entity_ids, session, True
    )

    for state in states:
        entity_id = state.entity_id
        # Clear legacy columns that are superseded by metadata_id.
        state.entity_id = None
        state.attributes = None
        state.event_id = None
        if metadata_id := entity_id_to_metadata_ids.get(entity_id):
            state.metadata_id = metadata_id
            continue
        # No metadata row yet: create one StatesMeta per entity_id and share it.
        if entity_id not in states_meta_objects:
            states_meta_objects[entity_id] = StatesMeta(entity_id=entity_id)
        state.states_meta_rel = states_meta_objects[entity_id]


def convert_pending_events_to_event_types(instance: Recorder, session: Session) -> None:
    """Convert pending events to use event_type_ids."""
    event_types: set[str] = set()
    events: set[Events] = set()
    event_types_objects: dict[str, EventTypes] = {}
    # Collect the Events objects currently pending in the session.
    for session_object in session:
        if isinstance(session_object, Events):
            event_types.add(session_object.event_type)
            events.add(session_object)

    event_type_to_event_type_ids = instance.event_type_manager.get_many(
        event_types, session, True
    )
    manually_added_event_types: list[str] = []

    for event in events:
        event_type = event.event_type
        # Clear legacy columns that are superseded by event_type_id.
        event.event_type = None
        event.event_data = None
        event.origin = None
        if event_type_id := event_type_to_event_type_ids.get(event_type):
            event.event_type_id = event_type_id
            continue
        if event_type not in event_types_objects:
            event_types_objects[event_type] = EventTypes(event_type=event_type)
            manually_added_event_types.append(event_type)
        event.event_type_rel = event_types_objects[event_type]

    # Remove manually added types from the manager's negative cache so later
    # lookups do not keep treating them as non-existent.
    for event_type in manually_added_event_types:
        instance.event_type_manager._non_existent_event_types.pop(event_type, None)


def create_engine_test_for_schema_version_postfix(
    *args, schema_version_postfix: str, **kwargs
):
    """Test version of create_engine that initializes with old schema.

    This simulates an existing db with the old schema.
    """
    schema_module = get_schema_module_path(schema_version_postfix)
    importlib.import_module(schema_module)
    old_db_schema = sys.modules[schema_module]
    instance: Recorder | None = None
    if "hass" in kwargs:
        # hass is consumed here; create_engine must not receive it.
        hass: HomeAssistant = kwargs.pop("hass")
        instance = recorder.get_instance(hass)
    engine = create_engine(*args, **kwargs)
    if instance is not None:
        instance = recorder.get_instance(hass)
        instance.engine = engine
        # Mirror the recorder's own per-connection setup on this test engine.
        sqlalchemy_event.listen(engine, "connect", instance._setup_recorder_connection)
    old_db_schema.Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(
            recorder.db_schema.StatisticsRuns(start=statistics.get_start_time())
        )
        session.add(
            recorder.db_schema.SchemaChanges(
                schema_version=old_db_schema.SCHEMA_VERSION
            )
        )
        session.commit()
    return engine


def get_schema_module_path(schema_version_postfix: str) -> str:
    """Return the path to the schema module."""
    return f"...components.recorder.db_schema_{schema_version_postfix}"


@contextmanager
def old_db_schema(hass: HomeAssistant, schema_version_postfix: str) -> Iterator[None]:
    """Fixture to initialize the db with the old schema."""
    schema_module = get_schema_module_path(schema_version_postfix)
    importlib.import_module(schema_module)
    old_db_schema = sys.modules[schema_module]

    # Patch every model class the recorder core touches so writes during the
    # test go through the old schema; CREATE_ENGINE_TARGET makes the recorder
    # create its engine pre-populated with that schema.
    with (
        patch.object(recorder, "db_schema", old_db_schema),
        patch.object(migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION),
        patch.object(migration, "non_live_data_migration_needed", return_value=False),
        patch.object(core, "StatesMeta", old_db_schema.StatesMeta),
        patch.object(core, "EventTypes", old_db_schema.EventTypes),
        patch.object(core, "EventData", old_db_schema.EventData),
        patch.object(core, "States", old_db_schema.States),
        patch.object(core, "Events", old_db_schema.Events),
        patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
        patch(
            CREATE_ENGINE_TARGET,
            new=partial(
                create_engine_test_for_schema_version_postfix,
                hass=hass,
                schema_version_postfix=schema_version_postfix,
            ),
        ),
    ):
        yield


async def async_attach_db_engine(hass: HomeAssistant) -> None:
    """Attach a database engine to the recorder."""
    instance = recorder.get_instance(hass)

    def _mock_setup_recorder_connection():
        with instance.engine.connect() as connection:
            instance._setup_recorder_connection(
                connection._dbapi_connection, MagicMock()
            )

    # Database access must happen off the event loop.
    await instance.async_add_executor_job(_mock_setup_recorder_connection)
--------------------------------------------------------------------------------
/src/pytest_homeassistant_custom_component/components/recorder/db_schema_0.py:
--------------------------------------------------------------------------------
"""
Models for SQLAlchemy.

This file contains the original models definitions before schema tracking was
implemented. It is used to test the schema migration logic.


This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component.
9 | """ 10 | 11 | import json 12 | import logging 13 | 14 | from sqlalchemy import ( 15 | Boolean, 16 | Column, 17 | DateTime, 18 | ForeignKey, 19 | Index, 20 | Integer, 21 | String, 22 | Text, 23 | distinct, 24 | ) 25 | from sqlalchemy.orm import declarative_base 26 | from sqlalchemy.orm.session import Session 27 | 28 | from homeassistant.core import Event, EventOrigin, State, split_entity_id 29 | from homeassistant.helpers.json import JSONEncoder 30 | from homeassistant.util import dt as dt_util 31 | 32 | # SQLAlchemy Schema 33 | Base = declarative_base() 34 | 35 | _LOGGER = logging.getLogger(__name__) 36 | 37 | 38 | class Events(Base): # type: ignore[valid-type,misc] 39 | """Event history data.""" 40 | 41 | __tablename__ = "events" 42 | event_id = Column(Integer, primary_key=True) 43 | event_type = Column(String(32), index=True) 44 | event_data = Column(Text) 45 | origin = Column(String(32)) 46 | time_fired = Column(DateTime(timezone=True)) 47 | created = Column(DateTime(timezone=True), default=dt_util.utcnow) 48 | 49 | @staticmethod 50 | def from_event(event): 51 | """Create an event database object from a native event.""" 52 | return Events( 53 | event_type=event.event_type, 54 | event_data=json.dumps(event.data, cls=JSONEncoder), 55 | origin=str(event.origin), 56 | time_fired=event.time_fired, 57 | ) 58 | 59 | def to_native(self): 60 | """Convert to a natve HA Event.""" 61 | try: 62 | return Event( 63 | self.event_type, 64 | json.loads(self.event_data), 65 | EventOrigin(self.origin), 66 | _process_timestamp(self.time_fired), 67 | ) 68 | except ValueError: 69 | # When json.loads fails 70 | _LOGGER.exception("Error converting to event: %s", self) 71 | return None 72 | 73 | 74 | class States(Base): # type: ignore[valid-type,misc] 75 | """State change history.""" 76 | 77 | __tablename__ = "states" 78 | state_id = Column(Integer, primary_key=True) 79 | domain = Column(String(64)) 80 | entity_id = Column(String(255)) 81 | state = Column(String(255)) 82 | attributes 
= Column(Text) 83 | event_id = Column(Integer, ForeignKey("events.event_id")) 84 | last_changed = Column(DateTime(timezone=True), default=dt_util.utcnow) 85 | last_updated = Column(DateTime(timezone=True), default=dt_util.utcnow) 86 | created = Column(DateTime(timezone=True), default=dt_util.utcnow) 87 | 88 | __table_args__ = ( 89 | Index("states__state_changes", "last_changed", "last_updated", "entity_id"), 90 | Index("states__significant_changes", "domain", "last_updated", "entity_id"), 91 | ) 92 | 93 | @staticmethod 94 | def from_event(event): 95 | """Create object from a state_changed event.""" 96 | entity_id = event.data["entity_id"] 97 | state = event.data.get("new_state") 98 | 99 | dbstate = States(entity_id=entity_id) 100 | 101 | # State got deleted 102 | if state is None: 103 | dbstate.state = "" 104 | dbstate.domain = split_entity_id(entity_id)[0] 105 | dbstate.attributes = "{}" 106 | dbstate.last_changed = event.time_fired 107 | dbstate.last_updated = event.time_fired 108 | else: 109 | dbstate.domain = state.domain 110 | dbstate.state = state.state 111 | dbstate.attributes = json.dumps(dict(state.attributes), cls=JSONEncoder) 112 | dbstate.last_changed = state.last_changed 113 | dbstate.last_updated = state.last_updated 114 | 115 | return dbstate 116 | 117 | def to_native(self): 118 | """Convert to an HA state object.""" 119 | try: 120 | return State( 121 | self.entity_id, 122 | self.state, 123 | json.loads(self.attributes), 124 | _process_timestamp(self.last_changed), 125 | _process_timestamp(self.last_updated), 126 | ) 127 | except ValueError: 128 | # When json.loads fails 129 | _LOGGER.exception("Error converting row to state: %s", self) 130 | return None 131 | 132 | 133 | class RecorderRuns(Base): # type: ignore[valid-type,misc] 134 | """Representation of recorder run.""" 135 | 136 | __tablename__ = "recorder_runs" 137 | run_id = Column(Integer, primary_key=True) 138 | start = Column(DateTime(timezone=True), default=dt_util.utcnow) 139 | end = 
Column(DateTime(timezone=True)) 140 | closed_incorrect = Column(Boolean, default=False) 141 | created = Column(DateTime(timezone=True), default=dt_util.utcnow) 142 | 143 | def entity_ids(self, point_in_time=None): 144 | """Return the entity ids that existed in this run. 145 | 146 | Specify point_in_time if you want to know which existed at that point 147 | in time inside the run. 148 | """ 149 | session = Session.object_session(self) 150 | 151 | assert session is not None, "RecorderRuns need to be persisted" 152 | 153 | query = session.query(distinct(States.entity_id)).filter( 154 | States.last_updated >= self.start 155 | ) 156 | 157 | if point_in_time is not None: 158 | query = query.filter(States.last_updated < point_in_time) 159 | elif self.end is not None: 160 | query = query.filter(States.last_updated < self.end) 161 | 162 | return [row[0] for row in query] 163 | 164 | def to_native(self): 165 | """Return self, native format is this model.""" 166 | return self 167 | 168 | 169 | def _process_timestamp(ts): 170 | """Process a timestamp into datetime object.""" 171 | if ts is None: 172 | return None 173 | if ts.tzinfo is None: 174 | return ts.replace(tzinfo=dt_util.UTC) 175 | return dt_util.as_utc(ts) 176 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants used by Home Assistant components. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | from typing import TYPE_CHECKING, Final 7 | MAJOR_VERSION: Final = 2025 8 | MINOR_VERSION: Final = 6 9 | PATCH_VERSION: Final = "0b5" 10 | __short_version__: Final = f"{MAJOR_VERSION}.{MINOR_VERSION}" 11 | __version__: Final = f"{__short_version__}.{PATCH_VERSION}" 12 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/ignore_uncaught_exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of tests that have uncaught exceptions today. Will be shrunk over time. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | IGNORE_UNCAUGHT_EXCEPTIONS = [ 8 | ( 9 | # This test explicitly throws an uncaught exception 10 | # and should not be removed. 11 | ".test_runner", 12 | "test_unhandled_exception_traceback", 13 | ), 14 | ( 15 | # This test explicitly throws an uncaught exception 16 | # and should not be removed. 17 | ".helpers.test_event", 18 | "test_track_point_in_time_repr", 19 | ), 20 | ( 21 | # This test explicitly throws an uncaught exception 22 | # and should not be removed. 23 | ".test_config_entries", 24 | "test_config_entry_unloaded_during_platform_setups", 25 | ), 26 | ( 27 | # This test explicitly throws an uncaught exception 28 | # and should not be removed. 
29 | ".test_config_entries", 30 | "test_config_entry_unloaded_during_platform_setup", 31 | ), 32 | ( 33 | "test_homeassistant_bridge", 34 | "test_homeassistant_bridge_fan_setup", 35 | ), 36 | ( 37 | ".components.owntracks.test_device_tracker", 38 | "test_mobile_multiple_async_enter_exit", 39 | ), 40 | ( 41 | ".components.smartthings.test_init", 42 | "test_event_handler_dispatches_updated_devices", 43 | ), 44 | ( 45 | ".components.unifi.test_controller", 46 | "test_wireless_client_event_calls_update_wireless_devices", 47 | ), 48 | (".components.iaqualink.test_config_flow", "test_with_invalid_credentials"), 49 | (".components.iaqualink.test_config_flow", "test_with_existing_config"), 50 | ] 51 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/patch_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patch JSON related functions. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import functools 10 | from typing import Any 11 | from unittest import mock 12 | 13 | import orjson 14 | 15 | from homeassistant.helpers import json as json_helper 16 | 17 | real_json_encoder_default = json_helper.json_encoder_default 18 | 19 | mock_objects = [] 20 | 21 | 22 | def json_encoder_default(obj: Any) -> Any: 23 | """Convert Home Assistant objects. 24 | 25 | Hand other objects to the original method. 
@contextmanager
def _session_scope_wrapper(*args, **kwargs):
    """Delegate to the real session_scope while remaining patchable.

    Recorder modules import this wrapper instead of the original helper,
    so tests can patch session handling without touching the real
    implementation.
    """
    # Forward all arguments untouched and re-yield the session object.
    with real_session_scope(*args, **kwargs) as session:
        yield session
class HAFakeDatetime(freezegun.api.FakeDatetime):  # type: ignore[name-defined]
    """FakeDatetime variant including https://github.com/spulec/freezegun/pull/424."""

    @classmethod
    def now(cls, tz=None):
        """Return the frozen "now", optionally converted to tz."""
        # Fall back to the real clock when no frozen time is active.
        frozen = cls._time_to_freeze() or freezegun.api.real_datetime.now()
        result = tz.fromutc(frozen.replace(tzinfo=tz)) if tz else frozen

        # Only apply a non-zero offset, so the fold attribute is preserved.
        offset = cls._tz_offset()
        if offset:
            result += offset

        return ha_datetime_to_fakedatetime(result)
homeassistant.util import dt as dt_util # noqa: E402 72 | 73 | dt_util.utcnow = _utcnow # type: ignore[assignment] 74 | util.utcnow = _utcnow # type: ignore[assignment] 75 | 76 | 77 | # Import other Home Assistant functionality which we need to patch 78 | from homeassistant import runner # noqa: E402 79 | from homeassistant.helpers import event as event_helper # noqa: E402 80 | 81 | # Replace partial functions which are not found by freezegun 82 | event_helper.time_tracker_utcnow = _utcnow # type: ignore[assignment] 83 | 84 | # Replace bound methods which are not found by freezegun 85 | runner.monotonic = _monotonic # type: ignore[assignment] 86 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/syrupy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Home Assistant extension for Syrupy. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from contextlib import suppress 10 | import dataclasses 11 | from enum import IntFlag 12 | import json 13 | import os 14 | from pathlib import Path 15 | from typing import Any 16 | 17 | import attr 18 | import attrs 19 | import pytest 20 | from syrupy.constants import EXIT_STATUS_FAIL_UNUSED 21 | from syrupy.data import Snapshot, SnapshotCollection, SnapshotCollections 22 | from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension 23 | from syrupy.location import PyTestLocation 24 | from syrupy.report import SnapshotReport 25 | from syrupy.session import ItemStatus, SnapshotSession 26 | from syrupy.types import PropertyFilter, PropertyMatcher, PropertyPath, SerializableData 27 | from syrupy.utils import is_xdist_controller, is_xdist_worker 28 | import voluptuous as vol 29 | import voluptuous_serialize 30 | 31 | from homeassistant.config_entries import ConfigEntry 32 | from homeassistant.core import State 33 | from homeassistant.data_entry_flow import FlowResult 34 | from homeassistant.helpers import ( 35 | area_registry as ar, 36 | device_registry as dr, 37 | entity_registry as er, 38 | issue_registry as ir, 39 | ) 40 | 41 | 42 | class _ANY: 43 | """Represent any value.""" 44 | 45 | def __repr__(self) -> str: 46 | return "" 47 | 48 | 49 | ANY = _ANY() 50 | 51 | __all__ = ["HomeAssistantSnapshotExtension"] 52 | 53 | 54 | class AreaRegistryEntrySnapshot(dict): 55 | """Tiny wrapper to represent an area registry entry in snapshots.""" 56 | 57 | 58 | class ConfigEntrySnapshot(dict): 59 | """Tiny wrapper to represent a config entry in snapshots.""" 60 | 61 | 62 | class DeviceRegistryEntrySnapshot(dict): 63 | """Tiny wrapper to represent a device registry entry in snapshots.""" 64 | 65 | 66 | class EntityRegistryEntrySnapshot(dict): 67 | """Tiny wrapper to represent an entity registry entry in snapshots.""" 68 | 69 | 70 | class FlowResultSnapshot(dict): 71 | """Tiny wrapper to represent a 
flow result in snapshots.""" 72 | 73 | 74 | class IssueRegistryItemSnapshot(dict): 75 | """Tiny wrapper to represent an entity registry entry in snapshots.""" 76 | 77 | 78 | class StateSnapshot(dict): 79 | """Tiny wrapper to represent an entity state in snapshots.""" 80 | 81 | 82 | class HomeAssistantSnapshotSerializer(AmberDataSerializer): 83 | """Home Assistant snapshot serializer for Syrupy. 84 | 85 | Handles special cases for Home Assistant data structures. 86 | """ 87 | 88 | @classmethod 89 | def _serialize( 90 | cls, 91 | data: SerializableData, 92 | *, 93 | depth: int = 0, 94 | exclude: PropertyFilter | None = None, 95 | include: PropertyFilter | None = None, 96 | matcher: PropertyMatcher | None = None, 97 | path: PropertyPath = (), 98 | visited: set[Any] | None = None, 99 | ) -> str: 100 | """Pre-process data before serializing. 101 | 102 | This allows us to handle specific cases for Home Assistant data structures. 103 | """ 104 | if isinstance(data, State): 105 | serializable_data = cls._serializable_state(data) 106 | elif isinstance(data, ar.AreaEntry): 107 | serializable_data = cls._serializable_area_registry_entry(data) 108 | elif isinstance(data, dr.DeviceEntry): 109 | serializable_data = cls._serializable_device_registry_entry(data) 110 | elif isinstance(data, er.RegistryEntry): 111 | serializable_data = cls._serializable_entity_registry_entry(data) 112 | elif isinstance(data, ir.IssueEntry): 113 | serializable_data = cls._serializable_issue_registry_entry(data) 114 | elif isinstance(data, dict) and "flow_id" in data and "handler" in data: 115 | serializable_data = cls._serializable_flow_result(data) 116 | elif isinstance(data, dict) and set(data) == { 117 | "conversation_id", 118 | "response", 119 | "continue_conversation", 120 | }: 121 | serializable_data = cls._serializable_conversation_result(data) 122 | elif isinstance(data, vol.Schema): 123 | serializable_data = voluptuous_serialize.convert(data) 124 | elif isinstance(data, ConfigEntry): 125 | 
serializable_data = cls._serializable_config_entry(data) 126 | elif dataclasses.is_dataclass(type(data)): 127 | serializable_data = dataclasses.asdict(data) 128 | elif isinstance(data, IntFlag): 129 | # The repr of an enum.IntFlag has changed between Python 3.10 and 3.11 130 | # so we normalize it here. 131 | serializable_data = _IntFlagWrapper(data) 132 | else: 133 | serializable_data = data 134 | with suppress(TypeError): 135 | if attr.has(type(data)): 136 | serializable_data = attrs.asdict(data) 137 | 138 | return super()._serialize( 139 | serializable_data, 140 | depth=depth, 141 | exclude=exclude, 142 | include=include, 143 | matcher=matcher, 144 | path=path, 145 | visited=visited, 146 | ) 147 | 148 | @classmethod 149 | def _serializable_area_registry_entry(cls, data: ar.AreaEntry) -> SerializableData: 150 | """Prepare a Home Assistant area registry entry for serialization.""" 151 | serialized = AreaRegistryEntrySnapshot(dataclasses.asdict(data) | {"id": ANY}) 152 | serialized.pop("_json_repr") 153 | serialized.pop("_cache") 154 | return serialized 155 | 156 | @classmethod 157 | def _serializable_config_entry(cls, data: ConfigEntry) -> SerializableData: 158 | """Prepare a Home Assistant config entry for serialization.""" 159 | entry = ConfigEntrySnapshot(data.as_dict() | {"entry_id": ANY}) 160 | return cls._remove_created_and_modified_at(entry) 161 | 162 | @classmethod 163 | def _serializable_device_registry_entry( 164 | cls, data: dr.DeviceEntry 165 | ) -> SerializableData: 166 | """Prepare a Home Assistant device registry entry for serialization.""" 167 | serialized = DeviceRegistryEntrySnapshot( 168 | attrs.asdict(data) 169 | | { 170 | "config_entries": ANY, 171 | "config_entries_subentries": ANY, 172 | "id": ANY, 173 | } 174 | ) 175 | if serialized["via_device_id"] is not None: 176 | serialized["via_device_id"] = ANY 177 | if serialized["primary_config_entry"] is not None: 178 | serialized["primary_config_entry"] = ANY 179 | serialized.pop("_cache") 180 | 
return cls._remove_created_and_modified_at(serialized) 181 | 182 | @classmethod 183 | def _remove_created_and_modified_at( 184 | cls, data: SerializableData 185 | ) -> SerializableData: 186 | """Remove created_at and modified_at from the data.""" 187 | data.pop("created_at", None) 188 | data.pop("modified_at", None) 189 | return data 190 | 191 | @classmethod 192 | def _serializable_entity_registry_entry( 193 | cls, data: er.RegistryEntry 194 | ) -> SerializableData: 195 | """Prepare a Home Assistant entity registry entry for serialization.""" 196 | serialized = EntityRegistryEntrySnapshot( 197 | attrs.asdict(data) 198 | | { 199 | "config_entry_id": ANY, 200 | "config_subentry_id": ANY, 201 | "device_id": ANY, 202 | "id": ANY, 203 | "options": {k: dict(v) for k, v in data.options.items()}, 204 | } 205 | ) 206 | serialized.pop("categories") 207 | serialized.pop("_cache") 208 | return cls._remove_created_and_modified_at(serialized) 209 | 210 | @classmethod 211 | def _serializable_flow_result(cls, data: FlowResult) -> SerializableData: 212 | """Prepare a Home Assistant flow result for serialization.""" 213 | return FlowResultSnapshot(data | {"flow_id": ANY}) 214 | 215 | @classmethod 216 | def _serializable_conversation_result(cls, data: dict) -> SerializableData: 217 | """Prepare a Home Assistant conversation result for serialization.""" 218 | return data | {"conversation_id": ANY} 219 | 220 | @classmethod 221 | def _serializable_issue_registry_entry( 222 | cls, data: ir.IssueEntry 223 | ) -> SerializableData: 224 | """Prepare a Home Assistant issue registry entry for serialization.""" 225 | return IssueRegistryItemSnapshot(dataclasses.asdict(data) | {"created": ANY}) 226 | 227 | @classmethod 228 | def _serializable_state(cls, data: State) -> SerializableData: 229 | """Prepare a Home Assistant State for serialization.""" 230 | return StateSnapshot( 231 | data.as_dict() 232 | | { 233 | "context": ANY, 234 | "last_changed": ANY, 235 | "last_reported": ANY, 236 | 
"last_updated": ANY, 237 | } 238 | ) 239 | 240 | 241 | class _IntFlagWrapper: 242 | def __init__(self, flag: IntFlag) -> None: 243 | self._flag = flag 244 | 245 | def __repr__(self) -> str: 246 | # 3.10: 247 | # 3.11: 248 | # Syrupy: 249 | return f"<{self._flag.__class__.__name__}: {self._flag.value}>" 250 | 251 | 252 | class HomeAssistantSnapshotExtension(AmberSnapshotExtension): 253 | """Home Assistant extension for Syrupy.""" 254 | 255 | VERSION = "1" 256 | """Current version of serialization format. 257 | 258 | Need to be bumped when we change the HomeAssistantSnapshotSerializer. 259 | """ 260 | 261 | serializer_class: type[AmberDataSerializer] = HomeAssistantSnapshotSerializer 262 | 263 | @classmethod 264 | def dirname(cls, *, test_location: PyTestLocation) -> str: 265 | """Return the directory for the snapshot files. 266 | 267 | Syrupy, by default, uses the `__snapshosts__` directory in the same 268 | folder as the test file. For Home Assistant, this is changed to just 269 | `snapshots` in the same folder as the test file, to match our `fixtures` 270 | folder structure. 271 | """ 272 | test_dir = Path(test_location.filepath).parent 273 | return str(test_dir.joinpath("snapshots")) 274 | 275 | 276 | # Classes and Methods to override default finish behavior in syrupy 277 | # This is needed to handle the xdist plugin in pytest 278 | # The default implementation does not handle the xdist plugin 279 | # and will not work correctly when running tests in parallel 280 | # with pytest-xdist. 
281 | # Temporary workaround until it is finalised inside syrupy 282 | # See https://github.com/syrupy-project/syrupy/pull/901 283 | 284 | 285 | class _FakePytestObject: 286 | """Fake object.""" 287 | 288 | def __init__(self, collected_item: dict[str, str]) -> None: 289 | """Initialise fake object.""" 290 | self.__module__ = collected_item["modulename"] 291 | self.__name__ = collected_item["methodname"] 292 | 293 | 294 | class _FakePytestItem: 295 | """Fake pytest.Item object.""" 296 | 297 | def __init__(self, collected_item: dict[str, str]) -> None: 298 | """Initialise fake pytest.Item object.""" 299 | self.nodeid = collected_item["nodeid"] 300 | self.name = collected_item["name"] 301 | self.path = Path(collected_item["path"]) 302 | self.obj = _FakePytestObject(collected_item) 303 | 304 | 305 | def _serialize_collections(collections: SnapshotCollections) -> dict[str, Any]: 306 | return { 307 | k: [c.name for c in v] for k, v in collections._snapshot_collections.items() 308 | } 309 | 310 | 311 | def _serialize_report( 312 | report: SnapshotReport, 313 | collected_items: set[pytest.Item], 314 | selected_items: dict[str, ItemStatus], 315 | ) -> dict[str, Any]: 316 | return { 317 | "discovered": _serialize_collections(report.discovered), 318 | "created": _serialize_collections(report.created), 319 | "failed": _serialize_collections(report.failed), 320 | "matched": _serialize_collections(report.matched), 321 | "updated": _serialize_collections(report.updated), 322 | "used": _serialize_collections(report.used), 323 | "_collected_items": [ 324 | { 325 | "nodeid": c.nodeid, 326 | "name": c.name, 327 | "path": str(c.path), 328 | "modulename": c.obj.__module__, 329 | "methodname": c.obj.__name__, 330 | } 331 | for c in list(collected_items) 332 | ], 333 | "_selected_items": { 334 | key: status.value for key, status in selected_items.items() 335 | }, 336 | } 337 | 338 | 339 | def _merge_serialized_collections( 340 | collections: SnapshotCollections, json_data: dict[str, 
def _merge_serialized_report(report: SnapshotReport, json_data: dict[str, Any]) -> None:
    """Merge one xdist worker's serialized snapshot report into *report*.

    Args:
        report: The controller's report; mutated in place.
        json_data: A worker's report as produced by _serialize_report.
    """
    _merge_serialized_collections(report.discovered, json_data["discovered"])
    _merge_serialized_collections(report.created, json_data["created"])
    _merge_serialized_collections(report.failed, json_data["failed"])
    _merge_serialized_collections(report.matched, json_data["matched"])
    _merge_serialized_collections(report.updated, json_data["updated"])
    _merge_serialized_collections(report.used, json_data["used"])
    for collected_item in json_data["_collected_items"]:
        custom_item = _FakePytestItem(collected_item)
        # Deduplicate on (nodeid, name). The original compared t.name against
        # custom_item.nodeid — a name never equals a nodeid, so the guard was
        # always False and every worker item was re-added unconditionally.
        if not any(
            t.nodeid == custom_item.nodeid and t.name == custom_item.name
            for t in report.collected_items
        ):
            report.collected_items.add(custom_item)
    for key, selected_item in json_data["_selected_items"].items():
        if key in report.selected_items:
            # Never downgrade an item another worker already ran to NOT_RUN.
            status = ItemStatus(selected_item)
            if status != ItemStatus.NOT_RUN:
                report.selected_items[key] = status
        else:
            report.selected_items[key] = ItemStatus(selected_item)
self.pytest_session.config.option.include_snapshot_details 388 | ) 389 | 390 | if is_xdist_worker(): 391 | if not needs_xdist_merge: 392 | return exitstatus 393 | with open(".pytest_syrupy_worker_count", "w", encoding="utf-8") as f: 394 | f.write(os.getenv("PYTEST_XDIST_WORKER_COUNT")) 395 | with open( 396 | f".pytest_syrupy_{os.getenv('PYTEST_XDIST_WORKER')}_result", 397 | "w", 398 | encoding="utf-8", 399 | ) as f: 400 | json.dump( 401 | _serialize_report( 402 | self.report, self._collected_items, self._selected_items 403 | ), 404 | f, 405 | indent=2, 406 | ) 407 | return exitstatus 408 | if is_xdist_controller(): 409 | return exitstatus 410 | 411 | if needs_xdist_merge: 412 | worker_count = None 413 | try: 414 | with open(".pytest_syrupy_worker_count", encoding="utf-8") as f: 415 | worker_count = f.read() 416 | os.remove(".pytest_syrupy_worker_count") 417 | except FileNotFoundError: 418 | pass 419 | 420 | if worker_count: 421 | for i in range(int(worker_count)): 422 | with open(f".pytest_syrupy_gw{i}_result", encoding="utf-8") as f: 423 | _merge_serialized_report(self.report, json.load(f)) 424 | os.remove(f".pytest_syrupy_gw{i}_result") 425 | 426 | if self.report.num_unused: 427 | if self.update_snapshots: 428 | self.remove_unused_snapshots( 429 | unused_snapshot_collections=self.report.unused, 430 | used_snapshot_collections=self.report.used, 431 | ) 432 | elif not self.warn_unused_snapshots: 433 | exitstatus |= EXIT_STATUS_FAIL_UNUSED 434 | return exitstatus 435 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/test_util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test utilities. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
def mock_real_ip(app: Application) -> Callable[[str], None]:
    """Inject middleware that overrides the remote IP of every request.

    Returns a setter used by tests to choose the IP to report.
    """
    ip_to_mock: str | None = None

    def set_ip_to_mock(value: str):
        nonlocal ip_to_mock
        ip_to_mock = value

    @middleware
    async def mock_real_ip(
        request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
    ) -> StreamResponse:
        """Clone the request with the mocked remote address."""
        return await handler(request.clone(remote=ip_to_mock))

    async def real_ip_startup(app):
        """Install the middleware with highest priority at app startup."""
        app.middlewares.insert(0, mock_real_ip)

    app.on_startup.append(real_ip_startup)

    return set_ip_to_mock
def mock_stream(data):
    """Create an aiohttp StreamReader pre-loaded with *data* and at EOF."""
    # The reader only consults the protocol's _reading_paused attribute,
    # so a Mock with that attribute set is sufficient.
    fake_protocol = mock.Mock(_reading_paused=False)
    reader = StreamReader(fake_protocol, limit=2**16)
    reader.feed_data(data)
    reader.feed_eof()
    return reader
side_effect=side_effect, 89 | closing=closing, 90 | ) 91 | ) 92 | 93 | def get(self, *args, **kwargs): 94 | """Register a mock get request.""" 95 | self.request("get", *args, **kwargs) 96 | 97 | def put(self, *args, **kwargs): 98 | """Register a mock put request.""" 99 | self.request("put", *args, **kwargs) 100 | 101 | def post(self, *args, **kwargs): 102 | """Register a mock post request.""" 103 | self.request("post", *args, **kwargs) 104 | 105 | def delete(self, *args, **kwargs): 106 | """Register a mock delete request.""" 107 | self.request("delete", *args, **kwargs) 108 | 109 | def options(self, *args, **kwargs): 110 | """Register a mock options request.""" 111 | self.request("options", *args, **kwargs) 112 | 113 | def patch(self, *args, **kwargs): 114 | """Register a mock patch request.""" 115 | self.request("patch", *args, **kwargs) 116 | 117 | def head(self, *args, **kwargs): 118 | """Register a mock head request.""" 119 | self.request("head", *args, **kwargs) 120 | 121 | @property 122 | def call_count(self): 123 | """Return the number of requests made.""" 124 | return len(self.mock_calls) 125 | 126 | def clear_requests(self): 127 | """Reset mock calls.""" 128 | self._mocks.clear() 129 | self._cookies.clear() 130 | self.mock_calls.clear() 131 | 132 | def create_session(self, loop): 133 | """Create a ClientSession that is bound to this mocker.""" 134 | session = ClientSession(loop=loop, json_serialize=json_dumps) 135 | # Setting directly on `session` will raise deprecation warning 136 | object.__setattr__(session, "_request", self.match_request) 137 | return session 138 | 139 | async def match_request( 140 | self, 141 | method, 142 | url, 143 | *, 144 | data=None, 145 | auth=None, 146 | params=None, 147 | headers=None, 148 | allow_redirects=None, 149 | timeout=None, 150 | json=None, 151 | cookies=None, 152 | **kwargs, 153 | ): 154 | """Match a request against pre-registered requests.""" 155 | data = data or json 156 | url = URL(url) 157 | if params: 158 | url 
= url.with_query(params) 159 | 160 | for response in self._mocks: 161 | if response.match_request(method, url, params): 162 | self.mock_calls.append((method, url, data, headers)) 163 | if response.side_effect: 164 | response = await response.side_effect(method, url, data) 165 | if response.exc: 166 | raise response.exc 167 | return response 168 | 169 | raise AssertionError(f"No mock registered for {method.upper()} {url} {params}") 170 | 171 | 172 | class AiohttpClientMockResponse: 173 | """Mock Aiohttp client response.""" 174 | 175 | def __init__( 176 | self, 177 | method, 178 | url: URL, 179 | status=HTTPStatus.OK, 180 | response=None, 181 | json=None, 182 | text=None, 183 | cookies=None, 184 | exc=None, 185 | headers=None, 186 | side_effect=None, 187 | closing=None, 188 | ) -> None: 189 | """Initialize a fake response.""" 190 | if json is not None: 191 | text = json_dumps(json) 192 | if text is not None: 193 | response = text.encode("utf-8") 194 | if response is None: 195 | response = b"" 196 | 197 | self.charset = "utf-8" 198 | self.method = method 199 | self._url = url 200 | self.status = status 201 | self._response = response 202 | self.exc = exc 203 | self.side_effect = side_effect 204 | self.closing = closing 205 | self._headers = CIMultiDict(headers or {}) 206 | self._cookies = {} 207 | 208 | if cookies: 209 | for name, data in cookies.items(): 210 | cookie = mock.MagicMock() 211 | cookie.value = data 212 | self._cookies[name] = cookie 213 | 214 | def match_request(self, method, url, params=None): 215 | """Test if response answers request.""" 216 | if method.lower() != self.method.lower(): 217 | return False 218 | 219 | # regular expression matching 220 | if isinstance(self._url, RETYPE): 221 | return self._url.search(str(url)) is not None 222 | 223 | if ( 224 | self._url.scheme != url.scheme 225 | or self._url.host != url.host 226 | or self._url.path != url.path 227 | ): 228 | return False 229 | 230 | # Ensure all query components in matcher are present in 
the request 231 | request_qs = parse_qs(url.query_string) 232 | matcher_qs = parse_qs(self._url.query_string) 233 | for key, vals in matcher_qs.items(): 234 | for val in vals: 235 | try: 236 | request_qs.get(key, []).remove(val) 237 | except ValueError: 238 | return False 239 | 240 | return True 241 | 242 | @property 243 | def headers(self): 244 | """Return content_type.""" 245 | return self._headers 246 | 247 | @property 248 | def cookies(self): 249 | """Return dict of cookies.""" 250 | return self._cookies 251 | 252 | @property 253 | def url(self): 254 | """Return yarl of URL.""" 255 | return self._url 256 | 257 | @property 258 | def content_type(self): 259 | """Return yarl of URL.""" 260 | return self._headers.get("content-type") 261 | 262 | @property 263 | def content(self): 264 | """Return content.""" 265 | return mock_stream(self.response) 266 | 267 | async def read(self): 268 | """Return mock response.""" 269 | return self.response 270 | 271 | async def text(self, encoding="utf-8", errors="strict"): 272 | """Return mock response as a string.""" 273 | return self.response.decode(encoding, errors=errors) 274 | 275 | async def json(self, encoding="utf-8", content_type=None, loads=json_loads): 276 | """Return mock response as a json.""" 277 | return loads(self.response.decode(encoding)) 278 | 279 | def release(self): 280 | """Mock release.""" 281 | 282 | def raise_for_status(self): 283 | """Raise error if status is 400 or higher.""" 284 | if self.status >= 400: 285 | request_info = mock.Mock(real_url="http://example.com") 286 | raise ClientResponseError( 287 | request_info=request_info, 288 | history=None, 289 | status=self.status, 290 | headers=self.headers, 291 | ) 292 | 293 | def close(self): 294 | """Mock close.""" 295 | 296 | async def wait_for_close(self): 297 | """Wait until all requests are done. 298 | 299 | Do nothing as we are mocking. 
@contextmanager
def mock_aiohttp_client() -> Iterator[AiohttpClientMocker]:
    """Patch Home Assistant's client-session factory with a request mocker."""
    request_mocker = AiohttpClientMocker()

    def create_session(hass: HomeAssistant, *args: Any, **kwargs: Any) -> ClientSession:
        # Bind the mocked session to the hass event loop and make sure it
        # gets closed when Home Assistant shuts down.
        session = request_mocker.create_session(hass.loop)

        async def close_session(event):
            """Close session."""
            await session.close()

        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, close_session)

        return session

    with mock.patch(
        "homeassistant.helpers.aiohttp_client._async_create_clientsession",
        side_effect=create_session,
    ):
        yield request_mocker
351 | """ 352 | 353 | def __init__(self) -> None: 354 | """Initialize the queue.""" 355 | self.semaphore = asyncio.Semaphore(0) 356 | self.response_list = [] 357 | self.stopping = False 358 | 359 | async def __call__(self, method, url, data): 360 | """Fetch the next response from the queue or wait until the queue has items.""" 361 | if self.stopping: 362 | raise ClientError 363 | await self.semaphore.acquire() 364 | kwargs = self.response_list.pop(0) 365 | return AiohttpClientMockResponse(method=method, url=url, **kwargs) 366 | 367 | def queue_response(self, **kwargs): 368 | """Add a response to the long_poll queue.""" 369 | self.response_list.append(kwargs) 370 | self.semaphore.release() 371 | 372 | def stop(self): 373 | """Stop the current request and future ones. 374 | 375 | This avoids an exception if there is someone waiting when exiting test. 376 | """ 377 | self.stopping = True 378 | self.queue_response(exc=ClientError()) 379 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/testing_config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration that's used when running tests. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/testing_config/custom_components/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of custom integrations used when running tests. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
def import_deprecated_constant(module: ModuleType, constant_name: str) -> Any:
    """Look up and return the named (deprecated) constant from *module*."""
    constant = getattr(module, constant_name)
    return constant
class MockHAClientWebSocket(ClientWebSocketResponse):
    """Protocol for a wrapped ClientWebSocketResponse."""

    # Test client that owns this websocket connection.
    client: TestClient
    # Sends a JSON payload, filling in the message "id" field automatically.
    send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]]
    # NOTE(review): signature suggests a (device_id, config_entry_id) device
    # removal helper — confirm against the fixture that attaches it.
    remove_device: Callable[[str, str], Coroutine[Any, Any, Any]]


# Factory/alias types used by the test fixtures (PEP 695 `type` statements).
type ClientSessionGenerator = Callable[..., Coroutine[Any, Any, TestClient]]
"""Factory coroutine returning an aiohttp TestClient."""
type MqttMockPahoClient = MagicMock
"""MagicMock for `paho.mqtt.client.Client`"""
type MqttMockHAClient = MagicMock
"""MagicMock for `homeassistant.components.mqtt.MQTT`."""
type MqttMockHAClientGenerator = Callable[..., Coroutine[Any, Any, MqttMockHAClient]]
"""MagicMock generator for `homeassistant.components.mqtt.MQTT`."""
type RecorderInstanceContextManager = Callable[
    ..., AbstractAsyncContextManager[Recorder]
]
"""ContextManager for `homeassistant.components.recorder.Recorder`."""
type RecorderInstanceGenerator = Callable[..., Coroutine[Any, Any, Recorder]]
"""Instance generator for `homeassistant.components.recorder.Recorder`."""
type WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
"""Factory coroutine returning a connected MockHAClientWebSocket."""
@pytest.fixture(autouse=True)
def auto_enable_custom_integrations(enable_custom_integrations):
    """Automatically enable custom integrations for every test (autouse)."""
    yield
def test_load_json_value_fixture():
    """Verify load_json_value_fixture parses a JSON object fixture file."""
    expected = {"test_key": "test_value"}
    assert load_json_value_fixture("test_data.json") == expected


def test_load_json_array_fixture():
    """Verify load_json_array_fixture parses a JSON array fixture file."""
    expected = [
        {"test_key1": "test_value1"},
        {"test_key2": "test_value2"},
    ]
    assert load_json_array_fixture("test_array.json") == expected


def test_load_json_object_fixture():
    """Verify load_json_object_fixture parses a JSON object fixture file."""
    expected = {"test_key": "test_value"}
    assert load_json_object_fixture("test_data.json") == expected
# Fields excluded from the diagnostics snapshot because their values are
# regenerated on every test run and would make snapshots unstable.
TO_EXCLUDE = {
    "id",
    "device_id",
    "via_device_id",
    "last_updated",
    "last_changed",
    "last_reported",
    "created_at",
    "modified_at",
    "entry_id",
}


def limit_diagnostic_attrs(prop, path) -> bool:
    """Return True when *prop* should be excluded from the snapshot."""
    # `path` is unused; the snapshot exclude callback supplies both arguments.
    return prop in TO_EXCLUDE
async def test_sensor(hass):
    """Set up the integration and verify the example temperature sensor."""
    config_entry = MockConfigEntry(domain=DOMAIN, data={"name": "simple config"})
    config_entry.add_to_hass(hass)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    sensor_state = hass.states.get("sensor.example_temperature")
    assert sensor_state
    assert sensor_state.state == "23"