├── ha_version ├── version ├── generate_phacc ├── __init__.py ├── const.py ├── ha.py └── generate_phacc.py ├── custom_components ├── __init__.py └── simple_integration │ ├── const.py │ ├── manifest.json │ ├── strings.json │ ├── translations │ └── en.json │ ├── diagnostics.py │ ├── __init__.py │ ├── sensor.py │ └── config_flow.py ├── requirements_generate.txt ├── tests ├── __init__.py ├── fixtures │ ├── test_data.json │ └── test_array.json ├── conftest.py ├── test_sensor.py ├── snapshots │ └── test_diagnostics.ambr ├── test_common.py ├── test_config_flow.py └── test_diagnostics.py ├── setup.cfg ├── .devcontainer ├── devcontainer.json └── podman │ └── devcontainer.json ├── .gitpod.Dockerfile ├── .github ├── dependabot.yml └── workflows │ ├── publish.yml │ ├── generate_package.yml │ ├── pytest.yml │ ├── make_release.yml │ └── automatic_generation.yml ├── src └── pytest_homeassistant_custom_component │ ├── __init__.py │ ├── testing_config │ ├── __init__.py │ └── custom_components │ │ ├── __init__.py │ │ └── test_constant_deprecation │ │ └── __init__.py │ ├── components │ ├── recorder │ │ ├── __init__.py │ │ ├── db_schema_0.py │ │ └── common.py │ ├── diagnostics │ │ └── __init__.py │ └── __init__.py │ ├── const.py │ ├── patch_recorder.py │ ├── test_util │ ├── __init__.py │ └── aiohttp.py │ ├── patch_json.py │ ├── ignore_uncaught_exceptions.py │ ├── typing.py │ ├── patch_time.py │ ├── asyncio_legacy.py │ └── syrupy.py ├── .gitpod.yml ├── requirements_dev.txt ├── LICENSE ├── requirements_test.txt ├── CHANGELOG.md ├── setup.py ├── .gitignore ├── README.md └── LICENSE_HA_CORE.md /ha_version: -------------------------------------------------------------------------------- 1 | 2025.12.2 -------------------------------------------------------------------------------- /version: -------------------------------------------------------------------------------- 1 | 0.13.300 2 | -------------------------------------------------------------------------------- 
/generate_phacc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /custom_components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements_generate.txt: -------------------------------------------------------------------------------- 1 | click==8.1.3 2 | GitPython==3.1.14 3 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the Simple Integration integration.""" 2 | -------------------------------------------------------------------------------- /tests/fixtures/test_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "test_key": "test_value" 3 | } 4 | -------------------------------------------------------------------------------- /tests/fixtures/test_array.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"test_key1": "test_value1"}, 3 | {"test_key2": "test_value2"} 4 | ] -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = tests 3 | asyncio_mode = auto 4 | asyncio_default_fixture_loop_scope = function -------------------------------------------------------------------------------- /custom_components/simple_integration/const.py: -------------------------------------------------------------------------------- 1 | """Constants for the Simple Integration integration.""" 2 | 3 | DOMAIN = "simple_integration" 4 | -------------------------------------------------------------------------------- 
/.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "image": "mcr.microsoft.com/devcontainers/python:3.13", 3 | "postCreateCommand": "pip3 install -r requirements_generate.txt" 4 | } 5 | -------------------------------------------------------------------------------- /.gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-python-3.11 2 | 3 | USER gitpod 4 | COPY requirements_generate.txt requirements_generate.txt 5 | RUN pip install -r requirements_generate.txt 6 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Set update schedule for GitHub Actions 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "monthly" 8 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for Home Assistant. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/testing_config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration that's used when running tests. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | -------------------------------------------------------------------------------- /custom_components/simple_integration/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "domain": "simple_integration", 3 | "name": "Simple Integration", 4 | "config_flow": true, 5 | "documentation": "NoWhere", 6 | "version": "0.0.0", 7 | "codeowners": [ 8 | "@NoOne" 9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/testing_config/custom_components/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of custom integrations used when running tests. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/recorder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for Recorder component. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | import pytest 8 | 9 | pytest.register_assert_rewrite("tests.components.recorder.common") 10 | -------------------------------------------------------------------------------- /custom_components/simple_integration/strings.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Simple Integration", 3 | "config": { 4 | "step": { 5 | "user": { 6 | "data": { 7 | "name": "Name" 8 | } 9 | } 10 | }, 11 | "error": { 12 | "unknown": "[%key:common::config_flow::error::unknown%]" 13 | }, 14 | "abort": { 15 | "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /custom_components/simple_integration/translations/en.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Simple Integration", 3 | "config": { 4 | "step": { 5 | "user": { 6 | "data": { 7 | "name": "Name" 8 | } 9 | } 10 | }, 11 | "error": { 12 | "unknown": "[%key:common::config_flow::error::unknown%]" 13 | }, 14 | "abort": { 15 | "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /.devcontainer/podman/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "image": "mcr.microsoft.com/devcontainers/python:3.13", 3 | "postCreateCommand": "pip install --upgrade pip && pip3 install -r requirements_generate.txt", 4 | // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 
5 | "runArgs": [ 6 | "--userns=keep-id" 7 | ], 8 | "containerUser": "vscode", 9 | "updateRemoteUserUID": true, 10 | "containerEnv": { 11 | "HOME": "/home/vscode" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants used by Home Assistant components. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | from typing import TYPE_CHECKING, Final 7 | MAJOR_VERSION: Final = 2025 8 | MINOR_VERSION: Final = 12 9 | PATCH_VERSION: Final = "2" 10 | __short_version__: Final = f"{MAJOR_VERSION}.{MINOR_VERSION}" 11 | __version__: Final = f"{__short_version__}.{PATCH_VERSION}" 12 | CONF_API_VERSION: Final = "api_version" 13 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/testing_config/custom_components/test_constant_deprecation/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test deprecated constants custom integration. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from types import ModuleType 8 | from typing import Any 9 | 10 | 11 | def import_deprecated_constant(module: ModuleType, constant_name: str) -> Any: 12 | """Import and return deprecated constant.""" 13 | return getattr(module, constant_name) 14 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures for testing.""" 2 | import pytest 3 | 4 | from pytest_homeassistant_custom_component.syrupy import HomeAssistantSnapshotExtension 5 | from syrupy.assertion import SnapshotAssertion 6 | 7 | 8 | @pytest.fixture 9 | def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: 10 | """Return snapshot assertion fixture with the Home Assistant extension.""" 11 | return snapshot.use_extension(HomeAssistantSnapshotExtension) 12 | 13 | 14 | @pytest.fixture(autouse=True) 15 | def auto_enable_custom_integrations(enable_custom_integrations): 16 | yield 17 | 18 | -------------------------------------------------------------------------------- /tests/test_sensor.py: -------------------------------------------------------------------------------- 1 | """Test sensor for simple integration.""" 2 | from pytest_homeassistant_custom_component.common import MockConfigEntry 3 | 4 | from custom_components.simple_integration.const import DOMAIN 5 | 6 | 7 | async def test_sensor(hass): 8 | """Test sensor.""" 9 | entry = MockConfigEntry(domain=DOMAIN, data={"name": "simple config",}) 10 | entry.add_to_hass(hass) 11 | await hass.config_entries.async_setup(entry.entry_id) 12 | await hass.async_block_till_done() 13 | 14 | state = hass.states.get("sensor.example_temperature") 15 | 16 | assert state 17 | assert state.state == "23" 18 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | # This configuration file 
was automatically generated by Gitpod. 2 | # Please adjust to your needs (see https://www.gitpod.io/docs/config-gitpod-file) 3 | # and commit this file to your remote git repository to share the goodness with others. 4 | image: 5 | file: .gitpod.Dockerfile 6 | tasks: 7 | - before: printf 'export PATH="%s:$PATH"\n' "/workspace/pytest-homeassistant-custom-component" >> $HOME/.bashrc && exit 8 | github: 9 | prebuilds: 10 | master: true 11 | branches: false 12 | pullRequests: false 13 | pullRequestsFromForks: false 14 | addCheck: false 15 | addComment: true 16 | addBadge: false 17 | -------------------------------------------------------------------------------- /tests/snapshots/test_diagnostics.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_entry_diagnostics 3 | dict({ 4 | 'config_entry': dict({ 5 | 'data': dict({ 6 | 'name': 'simple config', 7 | }), 8 | 'disabled_by': None, 9 | 'discovery_keys': dict({ 10 | }), 11 | 'domain': 'simple_integration', 12 | 'minor_version': 1, 13 | 'options': dict({ 14 | }), 15 | 'pref_disable_new_entities': False, 16 | 'pref_disable_polling': False, 17 | 'source': 'user', 18 | 'subentries': list([ 19 | ]), 20 | 'title': 'Mock Title', 21 | 'unique_id': None, 22 | 'version': 1, 23 | }), 24 | }) 25 | # --- 26 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: 5 | jobs: 6 | make_release: 7 | runs-on: "ubuntu-latest" 8 | steps: 9 | - uses: actions/checkout@v5 10 | - name: Set up Python 11 | uses: actions/setup-python@v6 12 | with: 13 | python-version: '3.13' 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install setuptools wheel twine 18 | - name: Build 19 | run: | 20 | python setup.py sdist bdist_wheel 21 | - name: 
Publish distribution 📦 to PyPI 22 | uses: pypa/gh-action-pypi-publish@release/v1 23 | with: 24 | password: ${{ secrets.PYPI_TOKEN }} 25 | -------------------------------------------------------------------------------- /custom_components/simple_integration/diagnostics.py: -------------------------------------------------------------------------------- 1 | """diagnostics for Simple Integration integration.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | from homeassistant.components.diagnostics import async_redact_data 8 | from homeassistant.config_entries import ConfigEntry 9 | 10 | from homeassistant.core import HomeAssistant 11 | 12 | TO_REDACT = {} 13 | 14 | 15 | async def async_get_config_entry_diagnostics( 16 | hass: HomeAssistant, entry: ConfigEntry 17 | ) -> dict[str, Any]: 18 | """Return diagnostics for a config entry.""" 19 | 20 | diagnostic_data: dict[str, Any] = { 21 | "config_entry": async_redact_data(entry.as_dict(), TO_REDACT), 22 | } 23 | 24 | return diagnostic_data 25 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | # This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
2 | astroid==4.0.1 3 | librt==0.2.1 4 | mypy-dev==1.19.0a4 5 | pre-commit==4.2.0 6 | pylint==4.0.1 7 | types-aiofiles==24.1.0.20250822 8 | types-atomicwrites==1.4.5.1 9 | types-croniter==6.0.0.20250809 10 | types-caldav==1.3.0.20250516 11 | types-chardet==0.1.5 12 | types-decorator==5.2.0.20250324 13 | types-pexpect==4.9.0.20250916 14 | types-protobuf==6.30.2.20250914 15 | types-psutil==7.0.0.20251001 16 | types-pyserial==3.5.0.20251001 17 | types-python-dateutil==2.9.0.20250822 18 | types-python-slugify==8.0.2.20240310 19 | types-pytz==2025.2.0.20250809 20 | types-PyYAML==6.0.12.20250915 21 | types-requests==2.32.4.20250913 22 | types-xmltodict==1.0.1.20250920 23 | -------------------------------------------------------------------------------- /custom_components/simple_integration/__init__.py: -------------------------------------------------------------------------------- 1 | """The Simple Integration integration.""" 2 | 3 | from homeassistant.config_entries import ConfigEntry 4 | from homeassistant.core import HomeAssistant 5 | 6 | 7 | PLATFORMS = ["sensor"] 8 | 9 | 10 | async def async_setup(hass: HomeAssistant, config: dict): 11 | """Set up the Simple Integration component.""" 12 | return True 13 | 14 | 15 | async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): 16 | """Set up Simple Integration from a config entry.""" 17 | 18 | await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) 19 | 20 | return True 21 | 22 | 23 | async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): 24 | """Unload a config entry.""" 25 | 26 | unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) 27 | 28 | return unload_ok 29 | -------------------------------------------------------------------------------- /.github/workflows/generate_package.yml: -------------------------------------------------------------------------------- 1 | name: Generate Package 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | 
generate_package: 8 | runs-on: "ubuntu-latest" 9 | steps: 10 | - uses: "actions/checkout@v5" 11 | - name: checkout repo content 12 | uses: actions/checkout@v5 13 | - name: setup python 14 | uses: actions/setup-python@v6 15 | with: 16 | python-version: '3.13' 17 | - name: install dependencies 18 | run: pip install -r requirements_generate.txt 19 | - name: Install phacc for current versions 20 | run: pip install -e . 21 | - name: execute generate package 22 | run: | 23 | export PYTHONPATH=$PYTHONPATH:$(pwd) 24 | python generate_phacc/generate_phacc.py --regen 25 | - name: Create Pull Request 26 | uses: peter-evans/create-pull-request@v7 27 | with: 28 | token: ${{ secrets.REPO_SCOPED_TOKEN }} 29 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/patch_recorder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patch recorder related functions. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from contextlib import contextmanager 10 | import sys 11 | 12 | # Patch recorder util session scope 13 | from homeassistant.helpers import recorder as recorder_helper 14 | 15 | # Make sure homeassistant.components.recorder.util is not already imported 16 | assert "homeassistant.components.recorder.util" not in sys.modules 17 | 18 | real_session_scope = recorder_helper.session_scope 19 | 20 | 21 | @contextmanager 22 | def _session_scope_wrapper(*args, **kwargs): 23 | """Make session_scope patchable. 24 | 25 | This function will be imported by recorder modules. 
26 | """ 27 | with real_session_scope(*args, **kwargs) as ses: 28 | yield ses 29 | 30 | 31 | recorder_helper.session_scope = _session_scope_wrapper 32 | -------------------------------------------------------------------------------- /custom_components/simple_integration/sensor.py: -------------------------------------------------------------------------------- 1 | """Platform for sensor integration.""" 2 | from homeassistant.const import UnitOfTemperature 3 | from homeassistant.helpers.entity import Entity 4 | 5 | 6 | async def async_setup_entry(hass, config_entry, async_add_devices): 7 | """Set up entry.""" 8 | async_add_devices([ExampleSensor(),]) 9 | 10 | 11 | class ExampleSensor(Entity): 12 | """Representation of a Sensor.""" 13 | 14 | def __init__(self): 15 | """Initialize the sensor.""" 16 | self._state = 23 17 | 18 | @property 19 | def should_poll(self): 20 | """Whether entity polls.""" 21 | return False 22 | 23 | @property 24 | def name(self): 25 | """Return the name of the sensor.""" 26 | return 'Example Temperature' 27 | 28 | @property 29 | def state(self): 30 | """Return the state of the sensor.""" 31 | return self._state 32 | 33 | @property 34 | def unit_of_measurement(self): 35 | """Return the unit of measurement.""" 36 | return UnitOfTemperature.CELSIUS 37 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | """Tests changes to common module.""" 2 | import json 3 | 4 | from pytest_homeassistant_custom_component.common import ( 5 | load_fixture, 6 | load_json_value_fixture, 7 | load_json_array_fixture, 8 | load_json_object_fixture 9 | ) 10 | 11 | def test_load_fixture(): 12 | data = json.loads(load_fixture("test_data.json")) 13 | assert data == {"test_key": "test_value"} 14 | 15 | def test_load_json_value_fixture(): 16 | """Test load_json_value_fixture can load fixture file""" 17 | data = 
load_json_value_fixture("test_data.json") 18 | assert data == {"test_key": "test_value"} 19 | 20 | def test_load_json_array_fixture(): 21 | """Test load_json_array_fixture can load fixture file""" 22 | data = load_json_array_fixture("test_array.json") 23 | assert data == [{"test_key1": "test_value1"},{"test_key2": "test_value2"}] 24 | 25 | def test_load_json_object_fixture(): 26 | """Test load_json_object_fixture can load fixture file""" 27 | data = load_json_object_fixture("test_data.json") 28 | assert data == {"test_key": "test_value"} 29 | -------------------------------------------------------------------------------- /generate_phacc/const.py: -------------------------------------------------------------------------------- 1 | """Constants for phacc.""" 2 | TMP_DIR = "tmp_dir" 3 | PACKAGE_DIR = "src/pytest_homeassistant_custom_component" 4 | REQUIREMENTS_FILE = "requirements_test.txt" 5 | CONST_FILE = "const.py" 6 | 7 | REQUIREMENTS_FILE_DEV = "requirements_dev.txt" 8 | 9 | path = "." 
10 | clone = "git clone https://github.com/home-assistant/core.git tmp_dir" 11 | diff = "git diff --exit-code" 12 | 13 | files = [ 14 | "__init__.py", 15 | "common.py", 16 | "conftest.py", 17 | "ignore_uncaught_exceptions.py", 18 | "components/__init__.py", 19 | "components/recorder/common.py", 20 | "patch_time.py", 21 | "syrupy.py", 22 | "typing.py", 23 | "patch_json.py", 24 | "patch_recorder.py", 25 | ] 26 | 27 | # remove requirements for development only, i.e not related to homeassistant tests 28 | requirements_remove = [ 29 | "codecov", 30 | "librt", 31 | "mypy", 32 | "mypy-dev", 33 | "pre-commit", 34 | "pylint", 35 | "astroid", 36 | ] 37 | 38 | LICENSE_FILE_HA = "LICENSE.md" 39 | LICENSE_FILE_NEW = "LICENSE_HA_CORE.md" 40 | 41 | HA_VERSION_FILE = "ha_version" 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Matthew Flamm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /custom_components/simple_integration/config_flow.py: -------------------------------------------------------------------------------- 1 | """Config flow for Simple Integration integration.""" 2 | import logging 3 | 4 | import voluptuous as vol 5 | 6 | from homeassistant import config_entries, core, exceptions 7 | 8 | from .const import DOMAIN # pylint:disable=unused-import 9 | 10 | _LOGGER = logging.getLogger(__name__) 11 | 12 | DATA_SCHEMA = vol.Schema({"name": str,}) 13 | 14 | 15 | class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): 16 | """Handle a config flow for Simple Integration.""" 17 | 18 | VERSION = 1 19 | 20 | CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL 21 | 22 | async def async_step_user(self, user_input=None): 23 | """Handle the initial step.""" 24 | errors = {} 25 | if user_input is not None: 26 | try: 27 | return self.async_create_entry(title=user_input["name"], data=user_input) 28 | except Exception: # pylint: disable=broad-except 29 | _LOGGER.exception("Unexpected exception") 30 | errors["base"] = "unknown" 31 | 32 | return self.async_show_form( 33 | step_id="user", data_schema=DATA_SCHEMA, errors=errors 34 | ) 35 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/test_util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test utilities. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from collections.abc import Awaitable, Callable 8 | 9 | from aiohttp.web import Application, Request, StreamResponse, middleware 10 | 11 | 12 | def mock_real_ip(app: Application) -> Callable[[str], None]: 13 | """Inject middleware to mock real IP. 14 | 15 | Returns a function to set the real IP. 16 | """ 17 | ip_to_mock: str | None = None 18 | 19 | def set_ip_to_mock(value: str): 20 | nonlocal ip_to_mock 21 | ip_to_mock = value 22 | 23 | @middleware 24 | async def mock_real_ip( 25 | request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] 26 | ) -> StreamResponse: 27 | """Mock Real IP middleware.""" 28 | nonlocal ip_to_mock 29 | 30 | request = request.clone(remote=ip_to_mock) 31 | 32 | return await handler(request) 33 | 34 | async def real_ip_startup(app): 35 | """Startup of real ip.""" 36 | app.middlewares.insert(0, mock_real_ip) 37 | 38 | app.on_startup.append(real_ip_startup) 39 | 40 | return set_ip_to_mock 41 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | # This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 2 | # linters such as pylint should be pinned, as new releases 3 | # make new things fail. 
Manually update these pins when pulling in a 4 | # new version 5 | 6 | # types-* that have versions roughly corresponding to the packages they 7 | # contain hints for available should be kept in sync with them 8 | 9 | -c homeassistant/package_constraints.txt 10 | -r requirements_test_pre_commit.txt 11 | coverage==7.10.6 12 | freezegun==1.5.2 13 | go2rtc-client==0.3.0 14 | # librt is an internal mypy dependency 15 | license-expression==30.4.3 16 | mock-open==1.4.0 17 | pydantic==2.12.2 18 | pylint-per-file-ignores==1.4.0 19 | pipdeptree==2.26.1 20 | pytest-asyncio==1.3.0 21 | pytest-aiohttp==1.1.0 22 | pytest-cov==7.0.0 23 | pytest-freezer==0.4.9 24 | pytest-github-actions-annotate-failures==0.3.0 25 | pytest-socket==0.7.0 26 | pytest-sugar==1.0.0 27 | pytest-timeout==2.4.0 28 | pytest-unordered==0.7.0 29 | pytest-picked==0.5.1 30 | pytest-xdist==3.8.0 31 | pytest==9.0.0 32 | requests-mock==1.12.1 33 | respx==0.22.0 34 | syrupy==5.0.0 35 | tqdm==4.67.1 36 | homeassistant==2025.12.2 37 | SQLAlchemy==2.0.41 38 | 39 | paho-mqtt==2.1.0 40 | 41 | numpy==2.3.2 42 | 43 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/patch_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patch JSON related functions. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import functools 10 | from typing import Any 11 | from unittest import mock 12 | 13 | import orjson 14 | 15 | from homeassistant.helpers import json as json_helper 16 | 17 | real_json_encoder_default = json_helper.json_encoder_default 18 | 19 | mock_objects = [] 20 | 21 | 22 | def json_encoder_default(obj: Any) -> Any: 23 | """Convert Home Assistant objects. 24 | 25 | Hand other objects to the original method. 
26 | """ 27 | if isinstance(obj, mock.Base): 28 | mock_objects.append(obj) 29 | raise TypeError(f"Attempting to serialize mock object {obj}") 30 | return real_json_encoder_default(obj) 31 | 32 | 33 | json_helper.json_encoder_default = json_encoder_default 34 | json_helper.json_bytes = functools.partial( 35 | orjson.dumps, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default 36 | ) 37 | json_helper.json_bytes_sorted = functools.partial( 38 | orjson.dumps, 39 | option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SORT_KEYS, 40 | default=json_encoder_default, 41 | ) 42 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | This changelog only includes changes directly related to the structure of this project. Changes in testing behavior may still occur from changes in homeassistant/core. 3 | 4 | Changes to minor version indicate a change structurally in this pacakge. Changes in patch indicate changes solely from homeassistant/core. The latter does not imply no breaking changes are introduced. 
5 | 6 | ## 0.13.0 7 | * bump minimum Python version to Python 3.10 8 | 9 | ## 0.8.0 10 | * recorder dependencies required for tests 11 | 12 | ## 0.7.0 13 | * paho-mqtt now required for tests 14 | 15 | ## 0.6.0 16 | * Python 3.8 dropped with homeassistant requirement 17 | * Minor change to generation of package for new homassistant code 18 | 19 | ## 0.4.0 20 | * `enable_custom_integrations` now required by ha 21 | * sqlalchemy version now pinned to ha version 22 | 23 | ## 0.3.0 24 | * Generate package only on homeassistant release versions 25 | * Use latest homeassistant release version including beta 26 | * homeassistant/core tags are used to determine latest release 27 | * Pin homeassistant version in requirements 28 | 29 | ## 0.2.0 30 | * fix `load_fixture` 31 | 32 | ## 0.1.0 33 | * remove Python 3.7 and add Python 3.9 34 | * remove `async_test` 35 | * move non-testing dependencies to separate `requirements_dev.txt` 36 | -------------------------------------------------------------------------------- /tests/test_config_flow.py: -------------------------------------------------------------------------------- 1 | """Test the Simple Integration config flow.""" 2 | from unittest.mock import patch 3 | 4 | from homeassistant import config_entries, setup 5 | from custom_components.simple_integration.const import DOMAIN 6 | 7 | 8 | async def test_form(hass): 9 | """Test we get the form.""" 10 | await setup.async_setup_component(hass, "persistent_notification", {}) 11 | result = await hass.config_entries.flow.async_init( 12 | DOMAIN, context={"source": config_entries.SOURCE_USER} 13 | ) 14 | assert result["type"] == "form" 15 | assert result["errors"] == {} 16 | 17 | with patch( 18 | "custom_components.simple_integration.async_setup", return_value=True 19 | ) as mock_setup, patch( 20 | "custom_components.simple_integration.async_setup_entry", 21 | return_value=True, 22 | ) as mock_setup_entry: 23 | result2 = await hass.config_entries.flow.async_configure( 24 | 
result["flow_id"], 25 | { 26 | "name": "new_simple_config" 27 | }, 28 | ) 29 | 30 | assert result2["type"] == "create_entry" 31 | assert result2["title"] == "new_simple_config" 32 | assert result2["data"] == { 33 | "name": "new_simple_config", 34 | } 35 | await hass.async_block_till_done() 36 | assert len(mock_setup.mock_calls) == 1 37 | assert len(mock_setup_entry.mock_calls) == 1 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from setuptools import setup, find_packages 4 | 5 | requirements = [ 6 | "sqlalchemy", 7 | ] 8 | with open("requirements_test.txt","r") as f: 9 | for line in f: 10 | if "txt" not in line and "#" not in line: 11 | requirements.append(line) 12 | 13 | with open("version", "r") as f: 14 | __version__ = f.read() 15 | 16 | setup( 17 | author="Matthew Flamm", 18 | name="pytest-homeassistant-custom-component", 19 | version=__version__, 20 | packages=find_packages(where="src"), 21 | package_dir={"": "src"}, 22 | python_requires=">=3.13", 23 | install_requires=requirements, 24 | license="MIT license", 25 | url="https://github.com/MatthewFlamm/pytest-homeassistant-custom-component", 26 | author_email="matthewflamm0@gmail.com", 27 | description="Experimental package to automatically extract test plugins for Home Assistant custom components", 28 | long_description=open('README.md').read(), 29 | long_description_content_type='text/markdown', 30 | classifiers=[ 31 | "Development Status :: 3 - Alpha", 32 | "Framework :: Pytest", 33 | "Intended Audience :: Developers", 34 | "License :: OSI Approved :: MIT License", 35 | "Programming Language :: Python", 36 | "Programming Language :: Python :: 3.13", 37 | "Topic :: Software Development :: Testing", 38 | ], 39 | entry_points={"pytest11": ["homeassistant = pytest_homeassistant_custom_component.plugins"]}, 40 | ) 41 | 
-------------------------------------------------------------------------------- /tests/test_diagnostics.py: -------------------------------------------------------------------------------- 1 | """Test the Simple Integration diagnostics.""" 2 | 3 | from syrupy.assertion import SnapshotAssertion 4 | 5 | from homeassistant.core import HomeAssistant 6 | 7 | from pytest_homeassistant_custom_component.common import MockConfigEntry 8 | from pytest_homeassistant_custom_component.components.diagnostics import get_diagnostics_for_config_entry 9 | from pytest_homeassistant_custom_component.typing import ClientSessionGenerator 10 | 11 | from custom_components.simple_integration.const import DOMAIN 12 | 13 | # Fields to exclude from snapshot as they change each run 14 | TO_EXCLUDE = { 15 | "id", 16 | "device_id", 17 | "via_device_id", 18 | "last_updated", 19 | "last_changed", 20 | "last_reported", 21 | "created_at", 22 | "modified_at", 23 | "entry_id", 24 | } 25 | 26 | 27 | def limit_diagnostic_attrs(prop, path) -> bool: 28 | """Mark attributes to exclude from diagnostic snapshot.""" 29 | return prop in TO_EXCLUDE 30 | 31 | 32 | async def test_entry_diagnostics( 33 | hass: HomeAssistant, 34 | hass_client: ClientSessionGenerator, 35 | snapshot: SnapshotAssertion, 36 | ) -> None: 37 | """Test config entry diagnostics.""" 38 | 39 | entry = MockConfigEntry(domain=DOMAIN, data={"name": "simple config",}) 40 | entry.add_to_hass(hass) 41 | await hass.config_entries.async_setup(entry.entry_id) 42 | await hass.async_block_till_done() 43 | 44 | assert await get_diagnostics_for_config_entry( 45 | hass, hass_client, entry 46 | ) == snapshot(exclude=limit_diagnostic_attrs) 47 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more 
information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Pytest 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | schedule: 13 | - cron: "0 5 * * *" 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | python-version: ["3.13"] 23 | 24 | steps: 25 | - uses: actions/checkout@v5 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v6 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies generate 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements_generate.txt 34 | - name: Install phacc for current versions 35 | run: pip install -e . 36 | - name: execute generate package 37 | run: | 38 | export PYTHONPATH=$PYTHONPATH:$(pwd) 39 | python generate_phacc/generate_phacc.py --regen 40 | - name: list files 41 | run: ls -a 42 | - name: publish artifact 43 | uses: actions/upload-artifact@v4 44 | with: 45 | name: generated-package-${{ matrix.python-version }} 46 | path: | 47 | ./ 48 | !**/*.pyc 49 | !tmp_dir/ 50 | !.git/ 51 | if-no-files-found: error 52 | - name: Install dependencies test 53 | run: | 54 | pip install -e . 55 | - name: Test with pytest 56 | run: | 57 | pytest 58 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/ignore_uncaught_exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of tests that have uncaught exceptions today. Will be shrunk over time. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | IGNORE_UNCAUGHT_EXCEPTIONS = [ 8 | ( 9 | # This test explicitly throws an uncaught exception 10 | # and should not be removed. 
11 | ".test_runner", 12 | "test_unhandled_exception_traceback", 13 | ), 14 | ( 15 | # This test explicitly throws an uncaught exception 16 | # and should not be removed. 17 | ".helpers.test_event", 18 | "test_track_point_in_time_repr", 19 | ), 20 | ( 21 | # This test explicitly throws an uncaught exception 22 | # and should not be removed. 23 | ".test_config_entries", 24 | "test_config_entry_unloaded_during_platform_setups", 25 | ), 26 | ( 27 | # This test explicitly throws an uncaught exception 28 | # and should not be removed. 29 | ".test_config_entries", 30 | "test_config_entry_unloaded_during_platform_setup", 31 | ), 32 | ( 33 | "test_homeassistant_bridge", 34 | "test_homeassistant_bridge_fan_setup", 35 | ), 36 | ( 37 | ".components.owntracks.test_device_tracker", 38 | "test_mobile_multiple_async_enter_exit", 39 | ), 40 | ( 41 | ".components.smartthings.test_init", 42 | "test_event_handler_dispatches_updated_devices", 43 | ), 44 | ( 45 | ".components.unifi.test_controller", 46 | "test_wireless_client_event_calls_update_wireless_devices", 47 | ), 48 | (".components.iaqualink.test_config_flow", "test_with_invalid_credentials"), 49 | (".components.iaqualink.test_config_flow", "test_with_existing_config"), 50 | ] 51 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Typing helpers for Home Assistant . 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from collections.abc import Callable, Coroutine 10 | from contextlib import AbstractAsyncContextManager 11 | from typing import TYPE_CHECKING, Any 12 | from unittest.mock import MagicMock 13 | 14 | from aiohttp import ClientWebSocketResponse 15 | from aiohttp.test_utils import TestClient 16 | 17 | if TYPE_CHECKING: 18 | # Local import to avoid processing recorder module when running a 19 | # testcase which does not use the recorder. 20 | from homeassistant.components.recorder import Recorder 21 | 22 | 23 | class MockHAClientWebSocket(ClientWebSocketResponse): 24 | """Protocol for a wrapped ClientWebSocketResponse.""" 25 | 26 | client: TestClient 27 | send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]] 28 | remove_device: Callable[[str, str], Coroutine[Any, Any, Any]] 29 | 30 | 31 | type ClientSessionGenerator = Callable[..., Coroutine[Any, Any, TestClient]] 32 | type MqttMockPahoClient = MagicMock 33 | """MagicMock for `paho.mqtt.client.Client`""" 34 | type MqttMockHAClient = MagicMock 35 | """MagicMock for `homeassistant.components.mqtt.MQTT`.""" 36 | type MqttMockHAClientGenerator = Callable[..., Coroutine[Any, Any, MqttMockHAClient]] 37 | """MagicMock generator for `homeassistant.components.mqtt.MQTT`.""" 38 | type RecorderInstanceContextManager = Callable[ 39 | ..., AbstractAsyncContextManager[Recorder] 40 | ] 41 | """ContextManager for `homeassistant.components.recorder.Recorder`.""" 42 | type RecorderInstanceGenerator = Callable[..., Coroutine[Any, Any, Recorder]] 43 | """Instance generator for `homeassistant.components.recorder.Recorder`.""" 44 | type WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]] 45 | -------------------------------------------------------------------------------- /.github/workflows/make_release.yml: -------------------------------------------------------------------------------- 1 | name: Make Release 2 | 3 | on: 4 | 
workflow_dispatch: 5 | inputs: 6 | type: 7 | description: 'Type of version to increment. major, minor, or patch.' 8 | required: true 9 | default: 'patch' 10 | jobs: 11 | make_release: 12 | runs-on: "ubuntu-latest" 13 | steps: 14 | - uses: actions/checkout@v5 15 | - name: store current ha version 16 | id: current-ha-version 17 | run: echo "::set-output name=current-ha-version::$(cat ha_version)" 18 | - id: next_version 19 | uses: zwaldowski/semver-release-action@v4 20 | with: 21 | dry_run: true 22 | bump: ${{ github.event.inputs.type }} 23 | github_token: ${{ secrets.REPO_SCOPED_TOKEN }} 24 | - run: echo "${{ steps.next_version.outputs.version }}" > version 25 | - run: echo "${{ steps.next_version.outputs.version }}" 26 | - id: git_commit 27 | run: | 28 | git config user.name 'Matthew Flamm' 29 | git config user.email 'MatthewFlamm@users.noreply.github.com' 30 | git add . 31 | git commit -m "Bump version" 32 | git push 33 | echo "::set-output name=sha::$(git rev-parse HEAD)" 34 | - uses: zwaldowski/semver-release-action@v4 35 | with: 36 | github_token: ${{ secrets.REPO_SCOPED_TOKEN }} 37 | sha: ${{ steps.git_commit.outputs.sha }} 38 | - name: Create Release 39 | id: create_release 40 | uses: actions/create-release@v1 41 | env: 42 | GITHUB_TOKEN: ${{ secrets.REPO_SCOPED_TOKEN }} 43 | with: 44 | tag_name: ${{ steps.next_version.outputs.version }} 45 | release_name: Release ${{ steps.next_version.outputs.version }} 46 | body: | 47 | Automatic release 48 | homeassistant version: ${{ steps.current-ha-version.outputs.current-ha-version }} 49 | draft: false 50 | prerelease: false 51 | - name: Set up Python 52 | uses: actions/setup-python@v6 53 | with: 54 | python-version: '3.13' 55 | - name: Install dependencies 56 | run: | 57 | python -m pip install --upgrade pip 58 | pip install setuptools wheel twine 59 | - name: Build and publish 60 | env: 61 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 62 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 63 | run: | 64 | python 
setup.py sdist bdist_wheel 65 | twine upload dist/* 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # this packages temporary folder 132 | tmp_dir/ -------------------------------------------------------------------------------- /generate_phacc/ha.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import git 4 | 5 | from const import clone, TMP_DIR 6 | 7 | 8 | class HAVersion: 9 | def __init__(self, version): 10 | self._version = version 11 | split_version = version.split(".") 12 | try: 13 | self.major = int(split_version[0]) 14 | except ValueError: 15 | self.major = 0 16 | self.minor = 0 17 | self.patch = 0 18 | self.beta = None 19 | if self.major < 2021: 20 | return 21 | 22 | if len(split_version)>=2: 23 | self.minor = int(split_version[1]) 24 | if len(split_version)>=3: 25 | patch = split_version[2].split("b") 26 | self.patch = int(patch[0]) 27 | if len(patch) == 2: 28 | self.beta = int(patch[1]) 29 | 30 | 31 | def __eq__(self, other): 32 | if ( 33 | self.major==other.major 34 | and self.minor==self.minor 35 | and self.patch==self.patch 36 | and self.beta==self.beta 37 | ): 38 | return True 39 | return False 40 | 41 | 42 | def 
__gt__(self, other): 43 | if self.major > other.major: 44 | return True 45 | elif self.major < other.major: 46 | return False 47 | 48 | 49 | if self.minor > other.minor: 50 | return True 51 | elif self.minor < other.minor: 52 | return False 53 | 54 | 55 | if self.patch > other.patch: 56 | return True 57 | elif self.patch < other.patch: 58 | return False 59 | 60 | 61 | if self.beta is not None and other.beta is None: 62 | return False 63 | elif self.beta is None and other.beta is not None: 64 | return True 65 | elif self.beta is None and other.beta is None: 66 | return False 67 | elif self.beta > other.beta: 68 | return True 69 | return False 70 | 71 | def prepare_homeassistant(ref=None): 72 | if not os.path.isdir(TMP_DIR): 73 | os.system(clone) # Cloning 74 | 75 | if ref is None: 76 | repo = git.Repo(TMP_DIR) 77 | versions = {str(tag): HAVersion(str(tag)) for tag in repo.tags} 78 | latest_version = HAVersion("0.0.0") 79 | for key, version in versions.items(): 80 | if version > latest_version: 81 | latest_version = version 82 | ref = key 83 | 84 | repo.head.reference = repo.refs[ref] 85 | repo.head.reset(index=True, working_tree=True) 86 | return ref 87 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/diagnostics/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the Diagnostics integration. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from http import HTTPStatus 8 | from typing import cast 9 | 10 | from homeassistant.config_entries import ConfigEntry 11 | from homeassistant.core import HomeAssistant 12 | from homeassistant.helpers.device_registry import DeviceEntry 13 | from homeassistant.setup import async_setup_component 14 | from homeassistant.util.json import JsonObjectType 15 | 16 | from pytest_homeassistant_custom_component.typing import ClientSessionGenerator 17 | 18 | 19 | async def _get_diagnostics_for_config_entry( 20 | hass: HomeAssistant, 21 | hass_client: ClientSessionGenerator, 22 | config_entry: ConfigEntry, 23 | ) -> JsonObjectType: 24 | """Return the diagnostics config entry for the specified domain.""" 25 | assert await async_setup_component(hass, "diagnostics", {}) 26 | await hass.async_block_till_done() 27 | 28 | client = await hass_client() 29 | response = await client.get( 30 | f"/api/diagnostics/config_entry/{config_entry.entry_id}" 31 | ) 32 | assert response.status == HTTPStatus.OK 33 | return cast(JsonObjectType, await response.json()) 34 | 35 | 36 | async def get_diagnostics_for_config_entry( 37 | hass: HomeAssistant, 38 | hass_client: ClientSessionGenerator, 39 | config_entry: ConfigEntry, 40 | ) -> JsonObjectType: 41 | """Return the diagnostics config entry for the specified domain.""" 42 | data = await _get_diagnostics_for_config_entry(hass, hass_client, config_entry) 43 | return cast(JsonObjectType, data["data"]) 44 | 45 | 46 | async def _get_diagnostics_for_device( 47 | hass: HomeAssistant, 48 | hass_client: ClientSessionGenerator, 49 | config_entry: ConfigEntry, 50 | device: DeviceEntry, 51 | ) -> JsonObjectType: 52 | """Return the diagnostics for the specified device.""" 53 | assert await async_setup_component(hass, "diagnostics", {}) 54 | 55 | client = await hass_client() 56 | response = await client.get( 57 | f"/api/diagnostics/config_entry/{config_entry.entry_id}/device/{device.id}" 58 | ) 59 | assert response.status == HTTPStatus.OK 60 | 
return cast(JsonObjectType, await response.json()) 61 | 62 | 63 | async def get_diagnostics_for_device( 64 | hass: HomeAssistant, 65 | hass_client: ClientSessionGenerator, 66 | config_entry: ConfigEntry, 67 | device: DeviceEntry, 68 | ) -> JsonObjectType: 69 | """Return the diagnostics for the specified device.""" 70 | data = await _get_diagnostics_for_device(hass, hass_client, config_entry, device) 71 | return cast(JsonObjectType, data["data"]) 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytest-homeassistant-custom-component 2 | 3 | ![HA core version](https://img.shields.io/static/v1?label=HA+core+version&message=2025.12.2&labelColor=blue) 4 | 5 | Package to automatically extract testing plugins from Home Assistant for custom component testing. 6 | The goal is to provide the same functionality as the tests in home-assistant/core. 7 | pytest-homeassistant-custom-component is updated daily according to the latest homeassistant release including beta. 8 | 9 | ## Usage: 10 | * All pytest fixtures can be used as normal, like `hass` 11 | * For helpers: 12 | * home-assistant/core native test: `from tests.common import MockConfigEntry` 13 | * custom component test: `from pytest_homeassistant_custom_component.common import MockConfigEntry` 14 | * If your integration is inside a `custom_components` folder, a `custom_components/__init__.py` file or changes to `sys.path` may be required. 15 | * `enable_custom_integrations` fixture is required (versions >=2021.6.0b0) 16 | * Some fixtures, e.g. `recorder_mock`, need to be initialized before `enable_custom_integrations`. See https://github.com/MatthewFlamm/pytest-homeassistant-custom-component/issues/132. 17 | * pytest-asyncio might now require `asyncio_mode = auto` config, see #129. 18 | * If using `load_fixture`, the files need to be in a `fixtures` folder colocated with the tests. 
For example, a test in `test_sensor.py` can load data from `some_data.json` using `load_fixture` from this structure: 19 | 20 | ``` 21 | tests/ 22 | fixtures/ 23 | some_data.json 24 | test_sensor.py 25 | ``` 26 | 27 | * When using syrupy snapshots, add a `snapshot` fixture to conftest.py to make sure the snapshots are loaded from snapshot folder colocated with the tests. 28 | 29 | ```py 30 | from pytest_homeassistant_custom_component.syrupy import HomeAssistantSnapshotExtension 31 | from syrupy.assertion import SnapshotAssertion 32 | 33 | 34 | @pytest.fixture 35 | def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: 36 | """Return snapshot assertion fixture with the Home Assistant extension.""" 37 | return snapshot.use_extension(HomeAssistantSnapshotExtension) 38 | ``` 39 | 40 | ## Examples: 41 | * See [list of custom components](https://github.com/MatthewFlamm/pytest-homeassistant-custom-component/network/dependents) as examples that use this package. 42 | * Also see tests for `simple_integration` in this repository. 43 | * Use [cookiecutter-homeassistant-custom-component](https://github.com/oncleben31/cookiecutter-homeassistant-custom-component) to create a custom component with tests by using [cookiecutter](https://github.com/cookiecutter/cookiecutter). 44 | * The [github-custom-component-tutorial](https://github.com/boralyl/github-custom-component-tutorial) explaining in detail how to create a custom component with a test suite using this package. 45 | 46 | ## More Info 47 | This repository is set up to be nearly fully automatic. 48 | 49 | * Version of home-assistant/core is given in `ha_version`, `pytest_homeassistant_custom_component.const`, and in the README above. 50 | * This package is generated against published releases of homeassistant and updated daily. 51 | * PRs should not include changes to the `pytest_homeassistant_custom_component` files. CI testing will automatically generate the new files.
52 | 53 | ### Version Strategy 54 | * When changes in extraction are required, there will be a change in the minor version. 55 | * A change in the patch version indicates that it was an automatic update with a homeassistant version. 56 | * This enables tracking back to which versions of pytest-homeassistant-custom-component can be used for 57 | extracting testing utilities from which version of homeassistant. 58 | 59 | This package was inspired by [pytest-homeassistant](https://github.com/boralyl/pytest-homeassistant) by @boralyl, but is intended to more closely and automatically track the home-assistant/core library. 60 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/patch_time.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patch time related functions. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import datetime 10 | import time 11 | 12 | import freezegun 13 | 14 | 15 | def ha_datetime_to_fakedatetime(datetime) -> freezegun.api.FakeDatetime: # type: ignore[name-defined] 16 | """Convert datetime to FakeDatetime. 17 | 18 | Modified to include https://github.com/spulec/freezegun/pull/424. 
19 | """ 20 | return freezegun.api.FakeDatetime( # type: ignore[attr-defined] 21 | datetime.year, 22 | datetime.month, 23 | datetime.day, 24 | datetime.hour, 25 | datetime.minute, 26 | datetime.second, 27 | datetime.microsecond, 28 | datetime.tzinfo, 29 | fold=datetime.fold, 30 | ) 31 | 32 | 33 | class HAFakeDateMeta(freezegun.api.FakeDateMeta): 34 | """Modified to override the string representation.""" 35 | 36 | def __str__(cls) -> str: # noqa: N805 (ruff doesn't know this is a metaclass) 37 | """Return the string representation of the class.""" 38 | return "" 39 | 40 | 41 | class HAFakeDate(freezegun.api.FakeDate, metaclass=HAFakeDateMeta): # type: ignore[name-defined] 42 | """Modified to improve class str.""" 43 | 44 | 45 | class HAFakeDatetimeMeta(freezegun.api.FakeDatetimeMeta): 46 | """Modified to override the string representation.""" 47 | 48 | def __str__(cls) -> str: # noqa: N805 (ruff doesn't know this is a metaclass) 49 | """Return the string representation of the class.""" 50 | return "" 51 | 52 | 53 | class HAFakeDatetime(freezegun.api.FakeDatetime, metaclass=HAFakeDatetimeMeta): # type: ignore[name-defined] 54 | """Modified to include basic fold support and improve class str. 55 | 56 | Fold support submitted to upstream in https://github.com/spulec/freezegun/pull/424. 
57 | """ 58 | 59 | @classmethod 60 | def now(cls, tz=None): 61 | """Return frozen now.""" 62 | now = cls._time_to_freeze() or freezegun.api.real_datetime.now() 63 | if tz: 64 | result = tz.fromutc(now.replace(tzinfo=tz)) 65 | else: 66 | result = now 67 | 68 | # Add the _tz_offset only if it's non-zero to preserve fold 69 | if cls._tz_offset(): 70 | result += cls._tz_offset() 71 | 72 | return ha_datetime_to_fakedatetime(result) 73 | 74 | 75 | # Needed by Mashumaro 76 | datetime.HAFakeDatetime = HAFakeDatetime 77 | 78 | # Do not add any Home Assistant import here 79 | 80 | 81 | def _utcnow() -> datetime.datetime: 82 | """Make utcnow patchable by freezegun.""" 83 | return datetime.datetime.now(datetime.UTC) 84 | 85 | 86 | def _monotonic() -> float: 87 | """Make monotonic patchable by freezegun.""" 88 | return time.monotonic() 89 | 90 | 91 | # Before importing any other Home Assistant functionality, import and replace 92 | # partial dt_util.utcnow with a regular function which can be found by freezegun 93 | from homeassistant import util # noqa: E402 94 | from homeassistant.util import dt as dt_util # noqa: E402 95 | 96 | dt_util.utcnow = _utcnow # type: ignore[assignment] 97 | util.utcnow = _utcnow # type: ignore[assignment] 98 | 99 | 100 | # Import other Home Assistant functionality which we need to patch 101 | from homeassistant import runner # noqa: E402 102 | from homeassistant.helpers import event as event_helper # noqa: E402 103 | 104 | # Replace partial functions which are not found by freezegun 105 | event_helper.time_tracker_utcnow = _utcnow # type: ignore[assignment] 106 | 107 | # Replace bound methods which are not found by freezegun 108 | runner.monotonic = _monotonic # type: ignore[assignment] 109 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/asyncio_legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal legacy 
asyncio.coroutine. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | # flake8: noqa 8 | # stubbing out for integrations that have 9 | # not yet been updated for python 3.11 10 | # but can still run on python 3.10 11 | # 12 | # Remove this once rflink, fido, and blackbird 13 | # have had their libraries updated to remove 14 | # asyncio.coroutine 15 | from asyncio import base_futures, constants, format_helpers 16 | from asyncio.coroutines import _is_coroutine 17 | import collections.abc 18 | import functools 19 | import inspect 20 | import logging 21 | import traceback 22 | import types 23 | import warnings 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class CoroWrapper: 29 | # Wrapper for coroutine object in _DEBUG mode. 30 | 31 | def __init__(self, gen, func=None): 32 | assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen 33 | self.gen = gen 34 | self.func = func # Used to unwrap @coroutine decorator 35 | self._source_traceback = format_helpers.extract_stack(sys._getframe(1)) 36 | self.__name__ = getattr(gen, "__name__", None) 37 | self.__qualname__ = getattr(gen, "__qualname__", None) 38 | 39 | def __iter__(self): 40 | return self 41 | 42 | def __next__(self): 43 | return self.gen.send(None) 44 | 45 | def send(self, value): 46 | return self.gen.send(value) 47 | 48 | def throw(self, type, value=None, traceback=None): 49 | return self.gen.throw(type, value, traceback) 50 | 51 | def close(self): 52 | return self.gen.close() 53 | 54 | @property 55 | def gi_frame(self): 56 | return self.gen.gi_frame 57 | 58 | @property 59 | def gi_running(self): 60 | return self.gen.gi_running 61 | 62 | @property 63 | def gi_code(self): 64 | return self.gen.gi_code 65 | 66 | def __await__(self): 67 | return self 68 | 69 | @property 70 | def gi_yieldfrom(self): 71 | return self.gen.gi_yieldfrom 72 | 73 | def __del__(self): 74 | # Be careful accessing self.gen.frame -- self.gen might 
not exist. 75 | gen = getattr(self, "gen", None) 76 | frame = getattr(gen, "gi_frame", None) 77 | if frame is not None and frame.f_lasti == -1: 78 | msg = f"{self!r} was never yielded from" 79 | tb = getattr(self, "_source_traceback", ()) 80 | if tb: 81 | tb = "".join(traceback.format_list(tb)) 82 | msg += ( 83 | f"\nCoroutine object created at " 84 | f"(most recent call last, truncated to " 85 | f"{constants.DEBUG_STACK_DEPTH} last lines):\n" 86 | ) 87 | msg += tb.rstrip() 88 | logger.error(msg) 89 | 90 | 91 | def legacy_coroutine(func): 92 | """Decorator to mark coroutines. 93 | If the coroutine is not yielded from before it is destroyed, 94 | an error message is logged. 95 | """ 96 | warnings.warn( 97 | '"@coroutine" decorator is deprecated since Python 3.8, use "async def" instead', 98 | DeprecationWarning, 99 | stacklevel=2, 100 | ) 101 | if inspect.iscoroutinefunction(func): 102 | # In Python 3.5 that's all we need to do for coroutines 103 | # defined with "async def". 104 | return func 105 | 106 | if inspect.isgeneratorfunction(func): 107 | coro = func 108 | else: 109 | 110 | @functools.wraps(func) 111 | def coro(*args, **kw): 112 | res = func(*args, **kw) 113 | if ( 114 | base_futures.isfuture(res) 115 | or inspect.isgenerator(res) 116 | or isinstance(res, CoroWrapper) 117 | ): 118 | res = yield from res 119 | else: 120 | # If 'res' is an awaitable, run it. 121 | try: 122 | await_meth = res.__await__ 123 | except AttributeError: 124 | pass 125 | else: 126 | if isinstance(res, collections.abc.Awaitable): 127 | res = yield from await_meth() 128 | return res 129 | 130 | wrapper = types.coroutine(coro) 131 | wrapper._is_coroutine = _is_coroutine # For iscoroutinefunction(). 
# Nightly job: regenerate the package from the latest homeassistant/core,
# test it, and (only when the HA version changed) bump, tag, and publish.
name: Automatic Generate

on:
  schedule:
    - cron: "0 5 * * *"
  workflow_dispatch:

jobs:
  generate_package:
    runs-on: "ubuntu-latest"
    outputs:
      current_ha_version: ${{ steps.current-ha-version.outputs.current-ha-version }}
      new_ha_version: ${{ steps.new-ha-version.outputs.new-ha-version }}
      need_to_release: ${{ steps.need-to-release.outputs.need-to-release }}
    steps:
      - name: checkout repo content
        uses: actions/checkout@v5
      - name: store current ha version
        id: current-ha-version
        # FIX: the `::set-output` workflow command is deprecated and disabled
        # by GitHub Actions; step outputs must be written to $GITHUB_OUTPUT.
        run: echo "current-ha-version=$(cat ha_version)" >> "$GITHUB_OUTPUT"
      - name: setup python
        uses: actions/setup-python@v6
        with:
          python-version: '3.13'
      - name: install dependencies
        run: pip install -r requirements_generate.txt
      - name: Install phacc for current versions
        run: pip install -e .
      - name: execute generate package
        run: |
          export PYTHONPATH=$PYTHONPATH:$(pwd)
          python generate_phacc/generate_phacc.py
      - name: store new ha version
        id: new-ha-version
        run: echo "new-ha-version=$(cat ha_version)" >> "$GITHUB_OUTPUT"
      - name: check need to release
        id: need-to-release
        # A release is only needed when the generated HA version differs from
        # the committed one.
        run: |
          if [[ "${{ steps.current-ha-version.outputs.current-ha-version}}" == "${{ steps.new-ha-version.outputs.new-ha-version }}" ]]; then
            echo "need-to-release=false" >> "$GITHUB_OUTPUT"
          else
            echo "need-to-release=true" >> "$GITHUB_OUTPUT"
          fi
      - name: list files
        run: ls -a
      - name: publish artifact
        uses: actions/upload-artifact@v4
        with:
          name: generated-package
          path: |
            ./
            !**/*.pyc
            !tmp_dir/
            !.git/
          if-no-files-found: error
  test:
    needs: generate_package
    runs-on: "ubuntu-latest"
    if: needs.generate_package.outputs.need_to_release == 'true'
    strategy:
      matrix:
        python-version: ['3.13']
    steps:
      - name: checkout repo content
        uses: actions/checkout@v5
      - name: download artifact
        uses: actions/download-artifact@v5
        with:
          name: generated-package
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v6
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
      - name: Test with pytest
        run: |
          pytest
  make_release:
    needs: [generate_package, test]
    runs-on: "ubuntu-latest"
    if: needs.generate_package.outputs.need_to_release == 'true'
    steps:
      - uses: actions/checkout@v5
      - name: download artifact
        uses: actions/download-artifact@v5
        with:
          name: generated-package
      - name: need_to_release_print
        run: "echo ${{ needs.generate_package.outputs.need_to_release }}"
      - id: next_version
        uses: zwaldowski/semver-release-action@v4
        with:
          dry_run: true
          bump: patch
          github_token: ${{ secrets.REPO_SCOPED_TOKEN }}
      - run: echo "${{ steps.next_version.outputs.version }}" > version
      - run: echo "${{ steps.next_version.outputs.version }}"
      - id: git_commit
        run: |
          git config user.name 'Matthew Flamm'
          git config user.email 'MatthewFlamm@users.noreply.github.com'
          git add .
          git commit -m "Bump version"
          git push
          echo "sha=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT"
      - uses: zwaldowski/semver-release-action@v4
        with:
          github_token: ${{ secrets.REPO_SCOPED_TOKEN }}
          sha: ${{ steps.git_commit.outputs.sha }}
      - name: Create Release
        id: create_release
        # NOTE(review): actions/create-release is archived/unmaintained;
        # consider migrating to softprops/action-gh-release or `gh release`.
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.REPO_SCOPED_TOKEN }}
        with:
          tag_name: ${{ steps.next_version.outputs.version }}
          release_name: Release ${{ steps.next_version.outputs.version }}
          body: |
            Automatic release
            homeassistant version: ${{ needs.generate_package.outputs.new_ha_version }}
          draft: false
          prerelease: false
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: '3.13'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel twine
      - name: Build
        # NOTE(review): `setup.py sdist bdist_wheel` is deprecated; consider
        # `python -m build` (pypa/build) instead.
        run: |
          python setup.py sdist bdist_wheel
      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_TOKEN }}
"""
Models for SQLAlchemy.

This file contains the original models definitions before schema tracking was
implemented. It is used to test the schema migration logic.


This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component.
"""

import json
import logging

from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    String,
    Text,
    distinct,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm.session import Session

from homeassistant.core import Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSONEncoder
from homeassistant.util import dt as dt_util

# SQLAlchemy Schema
Base = declarative_base()

_LOGGER = logging.getLogger(__name__)


class Events(Base):  # type: ignore[valid-type,misc]
    """Event history data."""

    __tablename__ = "events"
    event_id = Column(Integer, primary_key=True)
    event_type = Column(String(32), index=True)
    # Event payload serialized as JSON text.
    event_data = Column(Text)
    origin = Column(String(32))
    time_fired = Column(DateTime(timezone=True))
    # Row creation time, distinct from when the event fired.
    created = Column(DateTime(timezone=True), default=dt_util.utcnow)

    @staticmethod
    def from_event(event):
        """Create an event database object from a native event."""
        return Events(
            event_type=event.event_type,
            event_data=json.dumps(event.data, cls=JSONEncoder),
            origin=str(event.origin),
            time_fired=event.time_fired,
        )

    def to_native(self):
        """Convert to a native HA Event."""
        try:
            return Event(
                self.event_type,
                json.loads(self.event_data),
                EventOrigin(self.origin),
                _process_timestamp(self.time_fired),
            )
        except ValueError:
            # When json.loads fails
            _LOGGER.exception("Error converting to event: %s", self)
            return None


class States(Base):  # type: ignore[valid-type,misc]
    """State change history."""

    __tablename__ = "states"
    state_id = Column(Integer, primary_key=True)
    domain = Column(String(64))
    entity_id = Column(String(255))
    state = Column(String(255))
    # State attributes serialized as JSON text.
    attributes = Column(Text)
    event_id = Column(Integer, ForeignKey("events.event_id"))
    last_changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
    last_updated = Column(DateTime(timezone=True), default=dt_util.utcnow)
    created = Column(DateTime(timezone=True), default=dt_util.utcnow)

    # Composite indexes used by the history queries in this schema version.
    __table_args__ = (
        Index("states__state_changes", "last_changed", "last_updated", "entity_id"),
        Index("states__significant_changes", "domain", "last_updated", "entity_id"),
    )

    @staticmethod
    def from_event(event):
        """Create object from a state_changed event."""
        entity_id = event.data["entity_id"]
        state = event.data.get("new_state")

        dbstate = States(entity_id=entity_id)

        # State got deleted
        if state is None:
            # Store an empty-state tombstone row stamped with the event time.
            dbstate.state = ""
            dbstate.domain = split_entity_id(entity_id)[0]
            dbstate.attributes = "{}"
            dbstate.last_changed = event.time_fired
            dbstate.last_updated = event.time_fired
        else:
            dbstate.domain = state.domain
            dbstate.state = state.state
            dbstate.attributes = json.dumps(dict(state.attributes), cls=JSONEncoder)
            dbstate.last_changed = state.last_changed
            dbstate.last_updated = state.last_updated

        return dbstate

    def to_native(self):
        """Convert to an HA state object."""
        try:
            return State(
                self.entity_id,
                self.state,
                json.loads(self.attributes),
                _process_timestamp(self.last_changed),
                _process_timestamp(self.last_updated),
            )
        except ValueError:
            # When json.loads fails
            _LOGGER.exception("Error converting row to state: %s", self)
            return None


class RecorderRuns(Base):  # type: ignore[valid-type,misc]
    """Representation of recorder run."""

    __tablename__ = "recorder_runs"
    run_id = Column(Integer, primary_key=True)
    start = Column(DateTime(timezone=True), default=dt_util.utcnow)
    end = Column(DateTime(timezone=True))
    # True when the run was not closed cleanly (e.g. crash).
    closed_incorrect = Column(Boolean, default=False)
    created = Column(DateTime(timezone=True), default=dt_util.utcnow)

    def entity_ids(self, point_in_time=None):
        """Return the entity ids that existed in this run.

        Specify point_in_time if you want to know which existed at that point
        in time inside the run.
        """
        # Requires the instance to be attached to a live session, since the
        # query below runs against the same database.
        session = Session.object_session(self)

        assert session is not None, "RecorderRuns need to be persisted"

        query = session.query(distinct(States.entity_id)).filter(
            States.last_updated >= self.start
        )

        # Bound the query by the requested point in time, falling back to the
        # run's end time for closed runs (open runs are unbounded).
        if point_in_time is not None:
            query = query.filter(States.last_updated < point_in_time)
        elif self.end is not None:
            query = query.filter(States.last_updated < self.end)

        return [row[0] for row in query]

    def to_native(self):
        """Return self, native format is this model."""
        return self


def _process_timestamp(ts):
    """Process a timestamp into datetime object."""
    if ts is None:
        return None
    if ts.tzinfo is None:
        # Naive timestamps from the database are treated as UTC.
        return ts.replace(tzinfo=dt_util.UTC)
    return dt_util.as_utc(ts)
#!/usr/bin/env python
"""Regenerate pytest-homeassistant-custom-component from homeassistant/core.

Clones/updates a homeassistant checkout (via ``ha.prepare_homeassistant``),
copies the relevant test helpers into the package directory, rewrites imports
and docstrings, and regenerates the requirements files.  Only regenerates when
the upstream HA version changed, unless ``--regen`` is passed.
"""

import pathlib
import re
import shutil
import os

import click
import git

from ha import prepare_homeassistant
from const import (
    TMP_DIR,
    PACKAGE_DIR,
    REQUIREMENTS_FILE,
    CONST_FILE,
    REQUIREMENTS_FILE_DEV,
    LICENSE_FILE_HA,
    LICENSE_FILE_NEW,
    path,
    files,
    requirements_remove,
    HA_VERSION_FILE,
)

@click.command
@click.option("--regen/--no-regen", default=False, help="Whether to regenerate despite version")
def cli(regen):
    """Regenerate the package if the upstream HA version changed (or --regen)."""
    # Start from a clean slate so nothing stale survives a previous run.
    if os.path.isdir(PACKAGE_DIR):
        shutil.rmtree(PACKAGE_DIR)
    if os.path.isfile(REQUIREMENTS_FILE):
        os.remove(REQUIREMENTS_FILE)

    ha_version = prepare_homeassistant()

    # NOTE(review): assumes the ha_version file has no trailing newline, so a
    # raw read compares equal to the version string — confirm the writer below
    # stays in sync.
    with open(HA_VERSION_FILE, "r") as f:
        current_version = f.read()
    print(f"Current Version: {current_version}")


    def process_files():
        """Copy, rewrite and post-process all generated package files."""
        os.mkdir(PACKAGE_DIR)
        os.mkdir(os.path.join(PACKAGE_DIR, "test_util"))
        os.makedirs(os.path.join(PACKAGE_DIR, "components", "recorder"))
        os.makedirs(os.path.join(PACKAGE_DIR, "components", "diagnostics"))
        os.makedirs(os.path.join(PACKAGE_DIR, "testing_config", "custom_components", "test_constant_deprecation"))
        shutil.copy2(os.path.join(TMP_DIR, REQUIREMENTS_FILE), REQUIREMENTS_FILE)
        shutil.copy2(
            os.path.join(TMP_DIR, "homeassistant", CONST_FILE),
            os.path.join(PACKAGE_DIR, CONST_FILE),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "test_util", "aiohttp.py"),
            os.path.join(PACKAGE_DIR, "test_util", "aiohttp.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "test_util", "__init__.py"),
            os.path.join(PACKAGE_DIR, "test_util", "__init__.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "components", "recorder", "common.py"),
            os.path.join(PACKAGE_DIR, "components", "recorder", "common.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "components", "recorder", "db_schema_0.py"),
            os.path.join(PACKAGE_DIR, "components", "recorder", "db_schema_0.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "components", "recorder", "__init__.py"),
            os.path.join(PACKAGE_DIR, "components", "recorder", "__init__.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "components", "diagnostics", "__init__.py"),
            os.path.join(PACKAGE_DIR, "components", "diagnostics", "__init__.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "testing_config", "__init__.py"),
            os.path.join(PACKAGE_DIR, "testing_config", "__init__.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "testing_config", "custom_components", "__init__.py"),
            os.path.join(PACKAGE_DIR, "testing_config", "custom_components", "__init__.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, "tests", "testing_config", "custom_components", "test_constant_deprecation", "__init__.py"),
            os.path.join(PACKAGE_DIR, "testing_config", "custom_components", "test_constant_deprecation", "__init__.py"),
        )
        shutil.copy2(
            os.path.join(TMP_DIR, LICENSE_FILE_HA),
            LICENSE_FILE_NEW,
        )

        # Copy the main helper modules and convert absolute "tests." imports
        # into relative imports matching each file's depth in the package.
        for f in files:
            shutil.copy2(os.path.join(TMP_DIR, "tests", f), os.path.join(PACKAGE_DIR, f))

            filename = os.path.join(PACKAGE_DIR, f)

            with open(filename, "r") as file:
                filedata = file.read()

            filedata = filedata.replace(
                "tests.", "." * (f.count("/") + 1)
            )  # Add dots depending on depth

            with open(filename, "w") as file:
                file.write(filedata)

        # conftest.py must not be auto-collected by downstream pytest runs.
        os.rename(
            os.path.join(PACKAGE_DIR, "conftest.py"),
            os.path.join(PACKAGE_DIR, "plugins.py"),
        )

        # Strip const.py down to version info (and the typing import it needs).
        with open(os.path.join(PACKAGE_DIR, CONST_FILE), "r") as original_file:
            data = original_file.readlines()
        new_data = [d for d in data[:100] if "version" in d.lower() or "from typing" in d]
        new_data.insert(0, data[0])

        with open(os.path.join(PACKAGE_DIR, CONST_FILE), "w") as new_file:
            new_file.write("".join(new_data))

        added_text = "This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component.\n"
        triple_quote = '"""\n'

        # Append the provenance note to every module docstring.
        # NOTE(review): assumes every copied .py file starts with a docstring.
        for f in pathlib.Path(PACKAGE_DIR).rglob("*.py"):
            with open(f, "r") as original_file:
                data = original_file.readlines()

            multiline_docstring = not data[0].endswith(triple_quote)
            line_after_docstring = 1
            old_docstring = ""
            if not multiline_docstring:
                # Single-line docstring: strip the leading and trailing quotes.
                old_docstring = data[0][3:][:-4]
            else:
                old_docstring = data[0][3:]
                while data[line_after_docstring] != triple_quote:
                    old_docstring += data[line_after_docstring]
                    line_after_docstring += 1
                line_after_docstring += 1  # Skip last triplequote

            new_docstring = f"{triple_quote}{old_docstring}\n\n{added_text}{triple_quote}"
            body = "".join(data[line_after_docstring:])
            with open(f, "w") as new_file:
                new_file.write("".join([new_docstring, body]))

        added_text = "# This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component.\n"

        with open(REQUIREMENTS_FILE, "r") as original_file:
            data = original_file.readlines()

        def is_test_requirement(requirement):
            """Return True when the line should stay in requirements_test.txt."""
            # If "==" is not in the line this is either a comment or an
            # unknown package; keep it.
            if "==" not in requirement:
                return True

            # Type-stub packages are dev-only.
            regex = re.compile("types-.+")
            if re.match(regex, requirement):
                return False

            # BUGFIX: previously this tested the enclosing loop variable `d`
            # instead of the `requirement` argument; it only worked by
            # accident because every call site passed the current `d`.
            if requirement.split("==")[0] in requirements_remove:
                return False

            return True

        # Split upstream test requirements into shipped vs dev-only.
        new_data = []
        removed_data = []
        for d in data:
            if is_test_requirement(d):
                new_data.append(d)
            else:
                removed_data.append(d)
        new_data.append(f"homeassistant=={ha_version}\n")
        new_data.insert(0, added_text)

        def find_dependency(dependency, data):
            """Return the first requirements line containing `dependency`."""
            for d in data:
                if dependency in d.lower():
                    return d
            raise ValueError(f"could not find {dependency}")

        with open(os.path.join(TMP_DIR, "requirements_all.txt"), "r") as f:
            data = f.readlines()

        def add_dependency(dependency, ha_data, new_data):
            """Append the pinned line for `dependency` from ha_data to new_data."""
            # BUGFIX: use the ha_data parameter instead of closing over the
            # outer `data` list (callers pass the same list, so behavior is
            # unchanged, but the parameter is no longer silently ignored).
            dep = find_dependency(dependency, ha_data)
            # BUGFIX: the old guard compared the last TWO characters against
            # "\n" (`dep[-2:]`), which was always true for newline-terminated
            # lines and appended a spurious blank line after each dependency.
            if not dep.endswith("\n"):
                dep = f"{dep}\n"
            new_data.append(dep)

        add_dependency("sqlalchemy", data, new_data)
        add_dependency("paho-mqtt", data, new_data)
        add_dependency("numpy", data, new_data)

        removed_data.insert(0, added_text)

        with open(REQUIREMENTS_FILE, "w") as new_file:
            new_file.writelines(new_data)

        with open(REQUIREMENTS_FILE_DEV, "w") as new_file:
            new_file.writelines(removed_data)

        # Imported late: the freshly generated const.py must exist first.
        from pytest_homeassistant_custom_component.const import __version__

        # Refresh the HA-core-version badge in the README (line 3).
        with open("README.md", "r") as original_file:
            data = original_file.readlines()

        data[
            2
        ] = f"![HA core version](https://img.shields.io/static/v1?label=HA+core+version&message={__version__}&labelColor=blue)\n"

        with open("README.md", "w") as new_file:
            new_file.write("".join(data))

        print(f"New Version: {__version__}")

        # modify load_fixture: make get_fixture_path resolve fixtures relative
        # to the *caller's* file instead of the packaged common.py.
        with open(os.path.join(PACKAGE_DIR, "common.py"), "r") as original_file:
            data = original_file.readlines()

        import_time_lineno = [i for i, line in enumerate(data) if "import time" in line]
        assert len(import_time_lineno) == 1
        data.insert(import_time_lineno[0] + 1, "import traceback\n")

        fixture_path_lineno = [
            i for i, line in enumerate(data) if "def get_fixture_path" in line
        ]
        assert len(fixture_path_lineno) == 1
        data.insert(
            fixture_path_lineno[0] + 2,
            "    start_path = (current_file := traceback.extract_stack()[idx:=-1].filename)\n",
        )
        data.insert(
            fixture_path_lineno[0] + 3,
            "    while start_path == current_file:\n",
        )
        data.insert(
            fixture_path_lineno[0] + 4,
            "        start_path = traceback.extract_stack()[idx:=idx-1].filename\n",
        )
        data[fixture_path_lineno[0] + 9] = data[fixture_path_lineno[0] + 9].replace(
            "__file__", "start_path"
        )
        data[fixture_path_lineno[0] + 11] = data[fixture_path_lineno[0] + 11].replace(
            "__file__", "start_path"
        )

        with open(os.path.join(PACKAGE_DIR, "common.py"), "w") as new_file:
            new_file.writelines(data)

        # modify diagnostics file: point its typing import at the package.
        with open(os.path.join(PACKAGE_DIR, "components", "diagnostics", "__init__.py"), "r") as original_file:
            data = original_file.readlines()

        diagnostics_lineno = [
            i for i, line in enumerate(data) if "from tests.typing" in line
        ]
        assert len(diagnostics_lineno) == 1
        data[diagnostics_lineno[0]] = data[diagnostics_lineno[0]].replace(
            "tests.typing","pytest_homeassistant_custom_component.typing"
        )

        with open(os.path.join(PACKAGE_DIR, "components", "diagnostics", "__init__.py"), "w") as new_file:
            new_file.writelines(data)


    if ha_version != current_version or regen:
        process_files()
        with open(HA_VERSION_FILE, "w") as f:
            f.write(ha_version)
    else:
        print("Already up to date")

if __name__=="__main__":
    cli()
-------------------------------------------------------------------------------- /LICENSE_HA_CORE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/components/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The tests for components. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from enum import StrEnum 8 | import itertools 9 | from typing import TypedDict 10 | 11 | from homeassistant.const import ( 12 | ATTR_AREA_ID, 13 | ATTR_DEVICE_ID, 14 | ATTR_FLOOR_ID, 15 | ATTR_LABEL_ID, 16 | CONF_ENTITY_ID, 17 | CONF_OPTIONS, 18 | CONF_PLATFORM, 19 | CONF_TARGET, 20 | STATE_UNAVAILABLE, 21 | STATE_UNKNOWN, 22 | ) 23 | from homeassistant.core import HomeAssistant 24 | from homeassistant.helpers import ( 25 | area_registry as ar, 26 | device_registry as dr, 27 | entity_registry as er, 28 | floor_registry as fr, 29 | label_registry as lr, 30 | ) 31 | from homeassistant.setup import async_setup_component 32 | 33 | from ..common import MockConfigEntry, mock_device_registry 34 | 35 | 36 | async def target_entities( 37 | hass: HomeAssistant, domain: str 38 | ) -> tuple[list[str], list[str]]: 39 | """Create multiple entities associated with different targets. 40 | 41 | Returns a dict with the following keys: 42 | - included: List of entity_ids meant to be targeted. 43 | - excluded: List of entity_ids not meant to be targeted. 
44 | """ 45 | await async_setup_component(hass, domain, {}) 46 | 47 | config_entry = MockConfigEntry(domain="test") 48 | config_entry.add_to_hass(hass) 49 | 50 | floor_reg = fr.async_get(hass) 51 | floor = floor_reg.async_get_floor_by_name("Test Floor") or floor_reg.async_create( 52 | "Test Floor" 53 | ) 54 | 55 | area_reg = ar.async_get(hass) 56 | area = area_reg.async_get_area_by_name("Test Area") or area_reg.async_create( 57 | "Test Area", floor_id=floor.floor_id 58 | ) 59 | 60 | label_reg = lr.async_get(hass) 61 | label = label_reg.async_get_label_by_name("Test Label") or label_reg.async_create( 62 | "Test Label" 63 | ) 64 | 65 | device = dr.DeviceEntry(id="test_device", area_id=area.id, labels={label.label_id}) 66 | mock_device_registry(hass, {device.id: device}) 67 | 68 | entity_reg = er.async_get(hass) 69 | # Entities associated with area 70 | entity_area = entity_reg.async_get_or_create( 71 | domain=domain, 72 | platform="test", 73 | unique_id=f"{domain}_area", 74 | suggested_object_id=f"area_{domain}", 75 | ) 76 | entity_reg.async_update_entity(entity_area.entity_id, area_id=area.id) 77 | entity_area_excluded = entity_reg.async_get_or_create( 78 | domain=domain, 79 | platform="test", 80 | unique_id=f"{domain}_area_excluded", 81 | suggested_object_id=f"area_{domain}_excluded", 82 | ) 83 | entity_reg.async_update_entity(entity_area_excluded.entity_id, area_id=area.id) 84 | 85 | # Entities associated with device 86 | entity_reg.async_get_or_create( 87 | domain=domain, 88 | platform="test", 89 | unique_id=f"{domain}_device", 90 | suggested_object_id=f"device_{domain}", 91 | device_id=device.id, 92 | ) 93 | entity_reg.async_get_or_create( 94 | domain=domain, 95 | platform="test", 96 | unique_id=f"{domain}_device_excluded", 97 | suggested_object_id=f"device_{domain}_excluded", 98 | device_id=device.id, 99 | ) 100 | 101 | # Entities associated with label 102 | entity_label = entity_reg.async_get_or_create( 103 | domain=domain, 104 | platform="test", 105 | 
unique_id=f"{domain}_label", 106 | suggested_object_id=f"label_{domain}", 107 | ) 108 | entity_reg.async_update_entity(entity_label.entity_id, labels={label.label_id}) 109 | entity_label_excluded = entity_reg.async_get_or_create( 110 | domain=domain, 111 | platform="test", 112 | unique_id=f"{domain}_label_excluded", 113 | suggested_object_id=f"label_{domain}_excluded", 114 | ) 115 | entity_reg.async_update_entity( 116 | entity_label_excluded.entity_id, labels={label.label_id} 117 | ) 118 | 119 | # Return all available entities 120 | return { 121 | "included": [ 122 | f"{domain}.standalone_{domain}", 123 | f"{domain}.label_{domain}", 124 | f"{domain}.area_{domain}", 125 | f"{domain}.device_{domain}", 126 | ], 127 | "excluded": [ 128 | f"{domain}.standalone_{domain}_excluded", 129 | f"{domain}.label_{domain}_excluded", 130 | f"{domain}.area_{domain}_excluded", 131 | f"{domain}.device_{domain}_excluded", 132 | ], 133 | } 134 | 135 | 136 | def parametrize_target_entities(domain: str) -> list[tuple[dict, str, int]]: 137 | """Parametrize target entities for different target types. 138 | 139 | Meant to be used with target_entities. 
140 | """ 141 | return [ 142 | ( 143 | {CONF_ENTITY_ID: f"{domain}.standalone_{domain}"}, 144 | f"{domain}.standalone_{domain}", 145 | 1, 146 | ), 147 | ({ATTR_LABEL_ID: "test_label"}, f"{domain}.label_{domain}", 2), 148 | ({ATTR_AREA_ID: "test_area"}, f"{domain}.area_{domain}", 2), 149 | ({ATTR_FLOOR_ID: "test_floor"}, f"{domain}.area_{domain}", 2), 150 | ({ATTR_LABEL_ID: "test_label"}, f"{domain}.device_{domain}", 2), 151 | ({ATTR_AREA_ID: "test_area"}, f"{domain}.device_{domain}", 2), 152 | ({ATTR_FLOOR_ID: "test_floor"}, f"{domain}.device_{domain}", 2), 153 | ({ATTR_DEVICE_ID: "test_device"}, f"{domain}.device_{domain}", 1), 154 | ] 155 | 156 | 157 | class _StateDescription(TypedDict): 158 | """Test state and expected service call count.""" 159 | 160 | state: str | None 161 | attributes: dict 162 | 163 | 164 | class StateDescription(TypedDict): 165 | """Test state and expected service call count.""" 166 | 167 | included: _StateDescription 168 | excluded: _StateDescription 169 | count: int 170 | 171 | 172 | def parametrize_trigger_states( 173 | *, 174 | trigger: str, 175 | target_states: list[str | None | tuple[str | None, dict]], 176 | other_states: list[str | None | tuple[str | None, dict]], 177 | additional_attributes: dict | None = None, 178 | trigger_from_none: bool = True, 179 | ) -> list[tuple[str, list[StateDescription]]]: 180 | """Parametrize states and expected service call counts. 181 | 182 | The target_states and other_states iterables are either iterables of 183 | states or iterables of (state, attributes) tuples. 184 | 185 | Set `trigger_from_none` to False if the trigger is not expected to fire 186 | when the initial state is None. 187 | 188 | Returns a list of tuples with (trigger, list of states), 189 | where states is a list of StateDescription dicts. 
190 | """ 191 | 192 | additional_attributes = additional_attributes or {} 193 | 194 | def state_with_attributes( 195 | state: str | None | tuple[str | None, dict], count: int 196 | ) -> dict: 197 | """Return (state, attributes) dict.""" 198 | if isinstance(state, str) or state is None: 199 | return { 200 | "included": { 201 | "state": state, 202 | "attributes": additional_attributes, 203 | }, 204 | "excluded": { 205 | "state": state, 206 | "attributes": {}, 207 | }, 208 | "count": count, 209 | } 210 | return { 211 | "included": { 212 | "state": state[0], 213 | "attributes": state[1] | additional_attributes, 214 | }, 215 | "excluded": { 216 | "state": state[0], 217 | "attributes": state[1], 218 | }, 219 | "count": count, 220 | } 221 | 222 | return [ 223 | # Initial state None 224 | ( 225 | trigger, 226 | list( 227 | itertools.chain.from_iterable( 228 | ( 229 | state_with_attributes(None, 0), 230 | state_with_attributes(target_state, 0), 231 | state_with_attributes(other_state, 0), 232 | state_with_attributes( 233 | target_state, 1 if trigger_from_none else 0 234 | ), 235 | ) 236 | for target_state in target_states 237 | for other_state in other_states 238 | ) 239 | ), 240 | ), 241 | # Initial state different from target state 242 | ( 243 | trigger, 244 | # other_state, 245 | list( 246 | itertools.chain.from_iterable( 247 | ( 248 | state_with_attributes(other_state, 0), 249 | state_with_attributes(target_state, 1), 250 | state_with_attributes(other_state, 0), 251 | state_with_attributes(target_state, 1), 252 | ) 253 | for target_state in target_states 254 | for other_state in other_states 255 | ) 256 | ), 257 | ), 258 | # Initial state same as target state 259 | ( 260 | trigger, 261 | list( 262 | itertools.chain.from_iterable( 263 | ( 264 | state_with_attributes(target_state, 0), 265 | state_with_attributes(target_state, 0), 266 | state_with_attributes(other_state, 0), 267 | state_with_attributes(target_state, 1), 268 | ) 269 | for target_state in target_states 270 | 
for other_state in other_states 271 | ) 272 | ), 273 | ), 274 | # Initial state unavailable / unknown 275 | ( 276 | trigger, 277 | list( 278 | itertools.chain.from_iterable( 279 | ( 280 | state_with_attributes(STATE_UNAVAILABLE, 0), 281 | state_with_attributes(target_state, 0), 282 | state_with_attributes(other_state, 0), 283 | state_with_attributes(target_state, 1), 284 | ) 285 | for target_state in target_states 286 | for other_state in other_states 287 | ) 288 | ), 289 | ), 290 | ( 291 | trigger, 292 | list( 293 | itertools.chain.from_iterable( 294 | ( 295 | state_with_attributes(STATE_UNKNOWN, 0), 296 | state_with_attributes(target_state, 0), 297 | state_with_attributes(other_state, 0), 298 | state_with_attributes(target_state, 1), 299 | ) 300 | for target_state in target_states 301 | for other_state in other_states 302 | ) 303 | ), 304 | ), 305 | ] 306 | 307 | 308 | async def arm_trigger( 309 | hass: HomeAssistant, 310 | trigger: str, 311 | trigger_options: dict | None, 312 | trigger_target: dict, 313 | ) -> None: 314 | """Arm the specified trigger, call service test.automation when it triggers.""" 315 | 316 | # Local include to avoid importing the automation component unnecessarily 317 | from homeassistant.components import automation # noqa: PLC0415 318 | 319 | options = {CONF_OPTIONS: {**trigger_options}} if trigger_options is not None else {} 320 | 321 | await async_setup_component( 322 | hass, 323 | automation.DOMAIN, 324 | { 325 | automation.DOMAIN: { 326 | "trigger": { 327 | CONF_PLATFORM: trigger, 328 | CONF_TARGET: {**trigger_target}, 329 | } 330 | | options, 331 | "action": { 332 | "service": "test.automation", 333 | "data_template": {CONF_ENTITY_ID: "{{ trigger.entity_id }}"}, 334 | }, 335 | } 336 | }, 337 | ) 338 | 339 | 340 | def set_or_remove_state( 341 | hass: HomeAssistant, 342 | entity_id: str, 343 | state: StateDescription, 344 | ) -> None: 345 | """Set or remove the state of an entity.""" 346 | if state["state"] is None: 347 | 
hass.states.async_remove(entity_id) 348 | else: 349 | hass.states.async_set( 350 | entity_id, state["state"], state["attributes"], force_update=True 351 | ) 352 | 353 | 354 | def other_states(state: StrEnum) -> list[str]: 355 | """Return a sorted list with all states except the specified one.""" 356 | return sorted({s.value for s in state.__class__} - {state.value}) 357 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/test_util/aiohttp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Aiohttp test utils. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 5 | """ 6 | 7 | import asyncio 8 | from collections.abc import Iterator 9 | from contextlib import contextmanager 10 | from http import HTTPStatus 11 | import re 12 | from types import TracebackType 13 | from typing import Any 14 | from unittest import mock 15 | from urllib.parse import parse_qs 16 | 17 | from aiohttp import ClientSession 18 | from aiohttp.client_exceptions import ( 19 | ClientConnectionError, 20 | ClientError, 21 | ClientResponseError, 22 | ) 23 | from aiohttp.streams import StreamReader 24 | from multidict import CIMultiDict 25 | from yarl import URL 26 | 27 | from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE 28 | from homeassistant.core import HomeAssistant 29 | from homeassistant.helpers.json import json_dumps 30 | from homeassistant.util.json import json_loads 31 | 32 | RETYPE = type(re.compile("")) 33 | 34 | 35 | def mock_stream(data): 36 | """Mock a stream with data.""" 37 | protocol = mock.Mock(_reading_paused=False) 38 | stream = StreamReader(protocol, limit=2**16) 39 | stream.feed_data(data) 40 | stream.feed_eof() 41 | return stream 42 | 43 | 44 | class AiohttpClientMocker: 45 | """Mock Aiohttp client requests.""" 46 | 47 | def __init__(self) -> None: 48 | """Initialize the request mocker.""" 49 | 
self._mocks = [] 50 | self._cookies = {} 51 | self.mock_calls = [] 52 | 53 | def request( 54 | self, 55 | method, 56 | url, 57 | *, 58 | auth=None, 59 | status=HTTPStatus.OK, 60 | text=None, 61 | data=None, 62 | content=None, 63 | json=None, 64 | params=None, 65 | headers=None, 66 | exc=None, 67 | cookies=None, 68 | side_effect=None, 69 | closing=None, 70 | timeout=None, 71 | ): 72 | """Mock a request.""" 73 | if not isinstance(url, RETYPE): 74 | url = URL(url) 75 | if params: 76 | url = url.with_query(params) 77 | 78 | resp = AiohttpClientMockResponse( 79 | method=method, 80 | url=url, 81 | status=status, 82 | response=content, 83 | json=json, 84 | text=text, 85 | cookies=cookies, 86 | exc=exc, 87 | headers=headers, 88 | side_effect=side_effect, 89 | closing=closing, 90 | ) 91 | self._mocks.append(resp) 92 | return resp 93 | 94 | def get(self, *args, **kwargs): 95 | """Register a mock get request.""" 96 | self.request("get", *args, **kwargs) 97 | 98 | def put(self, *args, **kwargs): 99 | """Register a mock put request.""" 100 | self.request("put", *args, **kwargs) 101 | 102 | def post(self, *args, **kwargs): 103 | """Register a mock post request.""" 104 | self.request("post", *args, **kwargs) 105 | 106 | def delete(self, *args, **kwargs): 107 | """Register a mock delete request.""" 108 | self.request("delete", *args, **kwargs) 109 | 110 | def options(self, *args, **kwargs): 111 | """Register a mock options request.""" 112 | self.request("options", *args, **kwargs) 113 | 114 | def patch(self, *args, **kwargs): 115 | """Register a mock patch request.""" 116 | self.request("patch", *args, **kwargs) 117 | 118 | def head(self, *args, **kwargs): 119 | """Register a mock head request.""" 120 | self.request("head", *args, **kwargs) 121 | 122 | @property 123 | def call_count(self): 124 | """Return the number of requests made.""" 125 | return len(self.mock_calls) 126 | 127 | def clear_requests(self): 128 | """Reset mock calls.""" 129 | self._mocks.clear() 130 | 
self._cookies.clear() 131 | self.mock_calls.clear() 132 | 133 | def create_session(self, loop): 134 | """Create a ClientSession that is bound to this mocker.""" 135 | session = ClientSession(loop=loop, json_serialize=json_dumps) 136 | # Setting directly on `session` will raise deprecation warning 137 | object.__setattr__(session, "_request", self.match_request) 138 | return session 139 | 140 | async def match_request( 141 | self, 142 | method, 143 | url, 144 | *, 145 | data=None, 146 | auth=None, 147 | params=None, 148 | headers=None, 149 | allow_redirects=None, 150 | timeout=None, 151 | json=None, 152 | cookies=None, 153 | **kwargs, 154 | ): 155 | """Match a request against pre-registered requests.""" 156 | data = data or json 157 | url = URL(url) 158 | if params: 159 | url = url.with_query(params) 160 | 161 | for response in self._mocks: 162 | if response.match_request(method, url, params): 163 | # If auth is provided, try to encode it to trigger any encoding errors 164 | if auth is not None: 165 | auth.encode() 166 | self.mock_calls.append((method, url, data, headers)) 167 | if response.side_effect: 168 | response = await response.side_effect(method, url, data) 169 | if response.exc: 170 | raise response.exc 171 | return response 172 | 173 | raise AssertionError(f"No mock registered for {method.upper()} {url} {params}") 174 | 175 | 176 | class AiohttpClientMockResponse: 177 | """Mock Aiohttp client response.""" 178 | 179 | def __init__( 180 | self, 181 | method, 182 | url: URL, 183 | status=HTTPStatus.OK, 184 | response=None, 185 | json=None, 186 | text=None, 187 | cookies=None, 188 | exc=None, 189 | headers=None, 190 | side_effect=None, 191 | closing=None, 192 | ) -> None: 193 | """Initialize a fake response.""" 194 | if json is not None: 195 | text = json_dumps(json) 196 | if text is not None: 197 | response = text.encode("utf-8") 198 | if response is None: 199 | response = b"" 200 | 201 | self.method = method 202 | self._url = url 203 | self.status = status 
204 | self._response = response 205 | self.exc = exc 206 | self.side_effect = side_effect 207 | self.closing = closing 208 | self._headers = CIMultiDict(headers or {}) 209 | self._cookies = {} 210 | 211 | if cookies: 212 | for name, data in cookies.items(): 213 | cookie = mock.MagicMock() 214 | cookie.value = data 215 | self._cookies[name] = cookie 216 | 217 | def match_request(self, method, url, params=None): 218 | """Test if response answers request.""" 219 | if method.lower() != self.method.lower(): 220 | return False 221 | 222 | # regular expression matching 223 | if isinstance(self._url, RETYPE): 224 | return self._url.search(str(url)) is not None 225 | 226 | if ( 227 | self._url.scheme != url.scheme 228 | or self._url.raw_host != url.raw_host 229 | or self._url.raw_path != url.raw_path 230 | ): 231 | return False 232 | 233 | # Ensure all query components in matcher are present in the request 234 | request_qs = parse_qs(url.query_string) 235 | matcher_qs = parse_qs(self._url.query_string) 236 | for key, vals in matcher_qs.items(): 237 | for val in vals: 238 | try: 239 | request_qs.get(key, []).remove(val) 240 | except ValueError: 241 | return False 242 | 243 | return True 244 | 245 | @property 246 | def headers(self): 247 | """Return content_type.""" 248 | return self._headers 249 | 250 | @property 251 | def cookies(self): 252 | """Return dict of cookies.""" 253 | return self._cookies 254 | 255 | @property 256 | def url(self): 257 | """Return yarl of URL.""" 258 | return self._url 259 | 260 | @property 261 | def content_type(self): 262 | """Return yarl of URL.""" 263 | return self._headers.get("content-type") 264 | 265 | @property 266 | def content(self): 267 | """Return content.""" 268 | return mock_stream(self.response) 269 | 270 | @property 271 | def charset(self): 272 | """Return charset from Content-Type header.""" 273 | if (content_type := self._headers.get("content-type")) is None: 274 | return None 275 | content_type = content_type.lower() 276 | if 
"charset=" in content_type: 277 | return content_type.split("charset=")[1].split(";")[0].strip() 278 | return None 279 | 280 | async def read(self): 281 | """Return mock response.""" 282 | return self.response 283 | 284 | async def text(self, encoding=None, errors="strict") -> str: 285 | """Return mock response as a string.""" 286 | # Match real aiohttp behavior: encoding=None means auto-detect 287 | if encoding is None: 288 | encoding = self.charset or "utf-8" 289 | return self.response.decode(encoding, errors=errors) 290 | 291 | async def json(self, encoding=None, content_type=None, loads=json_loads) -> Any: 292 | """Return mock response as a json.""" 293 | # Match real aiohttp behavior: encoding=None means auto-detect 294 | if encoding is None: 295 | encoding = self.charset or "utf-8" 296 | return loads(self.response.decode(encoding)) 297 | 298 | def release(self): 299 | """Mock release.""" 300 | 301 | def raise_for_status(self): 302 | """Raise error if status is 400 or higher.""" 303 | if self.status >= 400: 304 | request_info = mock.Mock(real_url="http://example.com") 305 | raise ClientResponseError( 306 | request_info=request_info, 307 | history=None, 308 | status=self.status, 309 | headers=self.headers, 310 | ) 311 | 312 | def close(self): 313 | """Mock close.""" 314 | 315 | async def wait_for_close(self): 316 | """Wait until all requests are done. 317 | 318 | Do nothing as we are mocking. 
319 | """ 320 | 321 | @property 322 | def response(self): 323 | """Property method to expose the response to other read methods.""" 324 | if self.closing: 325 | raise ClientConnectionError("Connection closed") 326 | return self._response 327 | 328 | async def __aenter__(self): 329 | """Enter the context manager.""" 330 | return self 331 | 332 | async def __aexit__( 333 | self, 334 | exc_type: type[BaseException] | None, 335 | exc_val: BaseException | None, 336 | exc_tb: TracebackType | None, 337 | ) -> None: 338 | """Exit the context manager.""" 339 | 340 | 341 | @contextmanager 342 | def mock_aiohttp_client() -> Iterator[AiohttpClientMocker]: 343 | """Context manager to mock aiohttp client.""" 344 | mocker = AiohttpClientMocker() 345 | 346 | def create_session(hass: HomeAssistant, *args: Any, **kwargs: Any) -> ClientSession: 347 | session = mocker.create_session(hass.loop) 348 | 349 | async def close_session(event): 350 | """Close session.""" 351 | await session.close() 352 | 353 | hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, close_session) 354 | 355 | return session 356 | 357 | with mock.patch( 358 | "homeassistant.helpers.aiohttp_client._async_create_clientsession", 359 | side_effect=create_session, 360 | ): 361 | yield mocker 362 | 363 | 364 | class MockLongPollSideEffect: 365 | """Imitate a long_poll request. 366 | 367 | It should be created and used as a side effect for a GET/PUT/etc. request. 368 | Once created, actual responses are queued with queue_response 369 | If queue is empty, will await until done. 
370 | """ 371 | 372 | def __init__(self) -> None: 373 | """Initialize the queue.""" 374 | self.semaphore = asyncio.Semaphore(0) 375 | self.response_list = [] 376 | self.stopping = False 377 | 378 | async def __call__(self, method, url, data): 379 | """Fetch the next response from the queue or wait until the queue has items.""" 380 | if self.stopping: 381 | raise ClientError 382 | await self.semaphore.acquire() 383 | kwargs = self.response_list.pop(0) 384 | return AiohttpClientMockResponse(method=method, url=url, **kwargs) 385 | 386 | def queue_response(self, **kwargs): 387 | """Add a response to the long_poll queue.""" 388 | self.response_list.append(kwargs) 389 | self.semaphore.release() 390 | 391 | def stop(self): 392 | """Stop the current request and future ones. 393 | 394 | This avoids an exception if there is someone waiting when exiting test. 395 | """ 396 | self.stopping = True 397 | self.queue_response(exc=ClientError()) 398 | -------------------------------------------------------------------------------- /src/pytest_homeassistant_custom_component/syrupy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Home Assistant extension for Syrupy. 3 | 4 | This file is originally from homeassistant/core and modified by pytest-homeassistant-custom-component. 
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from contextlib import suppress 10 | import dataclasses 11 | from enum import IntFlag 12 | import json 13 | import os 14 | from pathlib import Path 15 | from typing import Any 16 | 17 | import attr 18 | import attrs 19 | import pytest 20 | from syrupy.constants import EXIT_STATUS_FAIL_UNUSED 21 | from syrupy.data import Snapshot, SnapshotCollection, SnapshotCollections 22 | from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension 23 | from syrupy.location import PyTestLocation 24 | from syrupy.report import SnapshotReport 25 | from syrupy.session import ItemStatus, SnapshotSession 26 | from syrupy.types import PropertyFilter, PropertyMatcher, PropertyPath, SerializableData 27 | from syrupy.utils import is_xdist_controller, is_xdist_worker 28 | import voluptuous as vol 29 | import voluptuous_serialize 30 | 31 | from homeassistant.config_entries import ConfigEntry 32 | from homeassistant.core import State 33 | from homeassistant.data_entry_flow import FlowResult 34 | from homeassistant.helpers import ( 35 | area_registry as ar, 36 | device_registry as dr, 37 | entity_registry as er, 38 | issue_registry as ir, 39 | ) 40 | 41 | 42 | class _ANY: 43 | """Represent any value.""" 44 | 45 | def __repr__(self) -> str: 46 | return "" 47 | 48 | 49 | ANY = _ANY() 50 | 51 | __all__ = ["HomeAssistantSnapshotExtension"] 52 | 53 | 54 | class AreaRegistryEntrySnapshot(dict): 55 | """Tiny wrapper to represent an area registry entry in snapshots.""" 56 | 57 | 58 | class ConfigEntrySnapshot(dict): 59 | """Tiny wrapper to represent a config entry in snapshots.""" 60 | 61 | 62 | class DeviceRegistryEntrySnapshot(dict): 63 | """Tiny wrapper to represent a device registry entry in snapshots.""" 64 | 65 | 66 | class EntityRegistryEntrySnapshot(dict): 67 | """Tiny wrapper to represent an entity registry entry in snapshots.""" 68 | 69 | 70 | class FlowResultSnapshot(dict): 71 | """Tiny wrapper to represent a 
flow result in snapshots.""" 72 | 73 | 74 | class IssueRegistryItemSnapshot(dict): 75 | """Tiny wrapper to represent an entity registry entry in snapshots.""" 76 | 77 | 78 | class StateSnapshot(dict): 79 | """Tiny wrapper to represent an entity state in snapshots.""" 80 | 81 | 82 | class HomeAssistantSnapshotSerializer(AmberDataSerializer): 83 | """Home Assistant snapshot serializer for Syrupy. 84 | 85 | Handles special cases for Home Assistant data structures. 86 | """ 87 | 88 | @classmethod 89 | def _serialize( 90 | cls, 91 | data: SerializableData, 92 | *, 93 | depth: int = 0, 94 | exclude: PropertyFilter | None = None, 95 | include: PropertyFilter | None = None, 96 | matcher: PropertyMatcher | None = None, 97 | path: PropertyPath = (), 98 | visited: set[Any] | None = None, 99 | ) -> str: 100 | """Pre-process data before serializing. 101 | 102 | This allows us to handle specific cases for Home Assistant data structures. 103 | """ 104 | if isinstance(data, State): 105 | serializable_data = cls._serializable_state(data) 106 | elif isinstance(data, ar.AreaEntry): 107 | serializable_data = cls._serializable_area_registry_entry(data) 108 | elif isinstance(data, dr.DeviceEntry): 109 | serializable_data = cls._serializable_device_registry_entry(data) 110 | elif isinstance(data, er.RegistryEntry): 111 | serializable_data = cls._serializable_entity_registry_entry(data) 112 | elif isinstance(data, ir.IssueEntry): 113 | serializable_data = cls._serializable_issue_registry_entry(data) 114 | elif isinstance(data, dict) and "flow_id" in data and "handler" in data: 115 | serializable_data = cls._serializable_flow_result(data) 116 | elif isinstance(data, dict) and set(data) == { 117 | "conversation_id", 118 | "response", 119 | "continue_conversation", 120 | }: 121 | serializable_data = cls._serializable_conversation_result(data) 122 | elif isinstance(data, vol.Schema): 123 | serializable_data = voluptuous_serialize.convert(data) 124 | elif isinstance(data, ConfigEntry): 125 | 
serializable_data = cls._serializable_config_entry(data) 126 | elif dataclasses.is_dataclass(type(data)): 127 | serializable_data = dataclasses.asdict(data) 128 | elif isinstance(data, IntFlag): 129 | # The repr of an enum.IntFlag has changed between Python 3.10 and 3.11 130 | # so we normalize it here. 131 | serializable_data = _IntFlagWrapper(data) 132 | else: 133 | serializable_data = data 134 | with suppress(TypeError): 135 | if attr.has(type(data)): 136 | serializable_data = attrs.asdict(data) 137 | 138 | return super()._serialize( 139 | serializable_data, 140 | depth=depth, 141 | exclude=exclude, 142 | include=include, 143 | matcher=matcher, 144 | path=path, 145 | visited=visited, 146 | ) 147 | 148 | @classmethod 149 | def _serializable_area_registry_entry(cls, data: ar.AreaEntry) -> SerializableData: 150 | """Prepare a Home Assistant area registry entry for serialization.""" 151 | serialized = AreaRegistryEntrySnapshot(dataclasses.asdict(data) | {"id": ANY}) 152 | serialized.pop("_json_repr") 153 | serialized.pop("_cache") 154 | return serialized 155 | 156 | @classmethod 157 | def _serializable_config_entry(cls, data: ConfigEntry) -> SerializableData: 158 | """Prepare a Home Assistant config entry for serialization.""" 159 | entry = ConfigEntrySnapshot(data.as_dict() | {"entry_id": ANY}) 160 | return cls._remove_created_and_modified_at(entry) 161 | 162 | @classmethod 163 | def _serializable_device_registry_entry( 164 | cls, data: dr.DeviceEntry 165 | ) -> SerializableData: 166 | """Prepare a Home Assistant device registry entry for serialization.""" 167 | serialized = DeviceRegistryEntrySnapshot( 168 | attrs.asdict(data) 169 | | { 170 | "config_entries": ANY, 171 | "config_entries_subentries": ANY, 172 | "id": ANY, 173 | } 174 | ) 175 | if serialized["via_device_id"] is not None: 176 | serialized["via_device_id"] = ANY 177 | if serialized["primary_config_entry"] is not None: 178 | serialized["primary_config_entry"] = ANY 179 | serialized.pop("_cache") 180 | 
# This can be removed when suggested_area is removed from DeviceEntry 181 | serialized.pop("_suggested_area") 182 | return cls._remove_created_and_modified_at(serialized) 183 | 184 | @classmethod 185 | def _remove_created_and_modified_at( 186 | cls, data: SerializableData 187 | ) -> SerializableData: 188 | """Remove created_at and modified_at from the data.""" 189 | data.pop("created_at", None) 190 | data.pop("modified_at", None) 191 | return data 192 | 193 | @classmethod 194 | def _serializable_entity_registry_entry( 195 | cls, data: er.RegistryEntry 196 | ) -> SerializableData: 197 | """Prepare a Home Assistant entity registry entry for serialization.""" 198 | serialized = EntityRegistryEntrySnapshot( 199 | attrs.asdict(data) 200 | | { 201 | "config_entry_id": ANY, 202 | "config_subentry_id": ANY, 203 | "device_id": ANY, 204 | "id": ANY, 205 | "options": {k: dict(v) for k, v in data.options.items()}, 206 | } 207 | ) 208 | serialized.pop("categories") 209 | serialized.pop("_cache") 210 | return cls._remove_created_and_modified_at(serialized) 211 | 212 | @classmethod 213 | def _serializable_flow_result(cls, data: FlowResult) -> SerializableData: 214 | """Prepare a Home Assistant flow result for serialization.""" 215 | return FlowResultSnapshot(data | {"flow_id": ANY}) 216 | 217 | @classmethod 218 | def _serializable_conversation_result(cls, data: dict) -> SerializableData: 219 | """Prepare a Home Assistant conversation result for serialization.""" 220 | return data | {"conversation_id": ANY} 221 | 222 | @classmethod 223 | def _serializable_issue_registry_entry( 224 | cls, data: ir.IssueEntry 225 | ) -> SerializableData: 226 | """Prepare a Home Assistant issue registry entry for serialization.""" 227 | return IssueRegistryItemSnapshot(dataclasses.asdict(data) | {"created": ANY}) 228 | 229 | @classmethod 230 | def _serializable_state(cls, data: State) -> SerializableData: 231 | """Prepare a Home Assistant State for serialization.""" 232 | return StateSnapshot( 233 | 
data.as_dict() 234 | | { 235 | "context": ANY, 236 | "last_changed": ANY, 237 | "last_reported": ANY, 238 | "last_updated": ANY, 239 | } 240 | ) 241 | 242 | 243 | class _IntFlagWrapper: 244 | def __init__(self, flag: IntFlag) -> None: 245 | self._flag = flag 246 | 247 | def __repr__(self) -> str: 248 | # 3.10: 249 | # 3.11: 250 | # Syrupy: 251 | return f"<{self._flag.__class__.__name__}: {self._flag.value}>" 252 | 253 | 254 | class HomeAssistantSnapshotExtension(AmberSnapshotExtension): 255 | """Home Assistant extension for Syrupy.""" 256 | 257 | VERSION = "1" 258 | """Current version of serialization format. 259 | 260 | Need to be bumped when we change the HomeAssistantSnapshotSerializer. 261 | """ 262 | 263 | serializer_class: type[AmberDataSerializer] = HomeAssistantSnapshotSerializer 264 | 265 | @classmethod 266 | def dirname(cls, *, test_location: PyTestLocation) -> str: 267 | """Return the directory for the snapshot files. 268 | 269 | Syrupy, by default, uses the `__snapshosts__` directory in the same 270 | folder as the test file. For Home Assistant, this is changed to just 271 | `snapshots` in the same folder as the test file, to match our `fixtures` 272 | folder structure. 273 | """ 274 | test_dir = Path(test_location.filepath).parent 275 | return str(test_dir.joinpath("snapshots")) 276 | 277 | 278 | # Classes and Methods to override default finish behavior in syrupy 279 | # This is needed to handle the xdist plugin in pytest 280 | # The default implementation does not handle the xdist plugin 281 | # and will not work correctly when running tests in parallel 282 | # with pytest-xdist. 
283 | # Temporary workaround until it is finalised inside syrupy 284 | # See https://github.com/syrupy-project/syrupy/pull/901 285 | 286 | 287 | class _FakePytestObject: 288 | """Fake object.""" 289 | 290 | def __init__(self, collected_item: dict[str, str]) -> None: 291 | """Initialise fake object.""" 292 | self.__module__ = collected_item["modulename"] 293 | self.__name__ = collected_item["methodname"] 294 | 295 | 296 | class _FakePytestItem: 297 | """Fake pytest.Item object.""" 298 | 299 | def __init__(self, collected_item: dict[str, str]) -> None: 300 | """Initialise fake pytest.Item object.""" 301 | self.nodeid = collected_item["nodeid"] 302 | self.name = collected_item["name"] 303 | self.path = Path(collected_item["path"]) 304 | self.obj = _FakePytestObject(collected_item) 305 | 306 | 307 | def _serialize_collections(collections: SnapshotCollections) -> dict[str, Any]: 308 | return { 309 | k: [c.name for c in v] for k, v in collections._snapshot_collections.items() 310 | } 311 | 312 | 313 | def _serialize_report( 314 | report: SnapshotReport, 315 | collected_items: set[pytest.Item], 316 | selected_items: dict[str, ItemStatus], 317 | ) -> dict[str, Any]: 318 | return { 319 | "discovered": _serialize_collections(report.discovered), 320 | "created": _serialize_collections(report.created), 321 | "failed": _serialize_collections(report.failed), 322 | "matched": _serialize_collections(report.matched), 323 | "updated": _serialize_collections(report.updated), 324 | "used": _serialize_collections(report.used), 325 | "_collected_items": [ 326 | { 327 | "nodeid": c.nodeid, 328 | "name": c.name, 329 | "path": str(c.path), 330 | "modulename": c.obj.__module__, 331 | "methodname": c.obj.__name__, 332 | } 333 | for c in list(collected_items) 334 | ], 335 | "_selected_items": { 336 | key: status.value for key, status in selected_items.items() 337 | }, 338 | } 339 | 340 | 341 | def _merge_serialized_collections( 342 | collections: SnapshotCollections, json_data: dict[str, 
list[str]] 343 | ) -> None: 344 | if not json_data: 345 | return 346 | for location, names in json_data.items(): 347 | snapshot_collection = SnapshotCollection(location=location) 348 | for name in names: 349 | snapshot_collection.add(Snapshot(name)) 350 | collections.update(snapshot_collection) 351 | 352 | 353 | def _merge_serialized_report(report: SnapshotReport, json_data: dict[str, Any]) -> None: 354 | _merge_serialized_collections(report.discovered, json_data["discovered"]) 355 | _merge_serialized_collections(report.created, json_data["created"]) 356 | _merge_serialized_collections(report.failed, json_data["failed"]) 357 | _merge_serialized_collections(report.matched, json_data["matched"]) 358 | _merge_serialized_collections(report.updated, json_data["updated"]) 359 | _merge_serialized_collections(report.used, json_data["used"]) 360 | for collected_item in json_data["_collected_items"]: 361 | custom_item = _FakePytestItem(collected_item) 362 | if not any( 363 | t.nodeid == custom_item.nodeid and t.name == custom_item.nodeid 364 | for t in report.collected_items 365 | ): 366 | report.collected_items.add(custom_item) 367 | for key, selected_item in json_data["_selected_items"].items(): 368 | if key in report.selected_items: 369 | status = ItemStatus(selected_item) 370 | if status != ItemStatus.NOT_RUN: 371 | report.selected_items[key] = status 372 | else: 373 | report.selected_items[key] = ItemStatus(selected_item) 374 | 375 | 376 | def override_syrupy_finish(self: SnapshotSession) -> int: 377 | """Override the finish method to allow for custom handling.""" 378 | exitstatus = 0 379 | self.flush_snapshot_write_queue() 380 | self.report = SnapshotReport( 381 | base_dir=self.pytest_session.config.rootpath, 382 | collected_items=self._collected_items, 383 | selected_items=self._selected_items, 384 | assertions=self._assertions, 385 | options=self.pytest_session.config.option, 386 | ) 387 | 388 | needs_xdist_merge = self.update_snapshots or bool( 389 | 
def override_syrupy_finish(self: SnapshotSession) -> int:
    """Override the finish method to allow for custom handling.

    Replaces ``SnapshotSession.finish`` so snapshot usage data survives
    pytest-xdist: each worker serializes its report to a file, and a
    non-xdist invocation merges any leftover worker result files before
    evaluating unused snapshots.
    """
    exitstatus = 0
    self.flush_snapshot_write_queue()
    # Build the report from everything this process collected and asserted.
    self.report = SnapshotReport(
        base_dir=self.pytest_session.config.rootpath,
        collected_items=self._collected_items,
        selected_items=self._selected_items,
        assertions=self._assertions,
        options=self.pytest_session.config.option,
    )

    # Merging is only needed when snapshots are being updated or detailed
    # snapshot reporting was requested.
    needs_xdist_merge = self.update_snapshots or bool(
        self.pytest_session.config.option.include_snapshot_details
    )

    if is_xdist_worker():
        if not needs_xdist_merge:
            return exitstatus
        # Record how many workers ran so the merge loop knows how many
        # result files to expect.
        # NOTE(review): os.getenv may return None if the xdist env var is
        # unset; presumably xdist always sets it for workers — confirm.
        with open(".pytest_syrupy_worker_count", "w", encoding="utf-8") as f:
            f.write(os.getenv("PYTEST_XDIST_WORKER_COUNT"))
        # Each worker writes its own serialized report (e.g. gw0, gw1, ...).
        with open(
            f".pytest_syrupy_{os.getenv('PYTEST_XDIST_WORKER')}_result",
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(
                _serialize_report(
                    self.report, self._collected_items, self._selected_items
                ),
                f,
                indent=2,
            )
        return exitstatus
    if is_xdist_controller():
        # The controller itself performs no merge here.
        return exitstatus

    if needs_xdist_merge:
        # Non-xdist run: merge any worker result files left behind by a
        # previous xdist session, then clean them up.
        worker_count = None
        try:
            with open(".pytest_syrupy_worker_count", encoding="utf-8") as f:
                worker_count = f.read()
            os.remove(".pytest_syrupy_worker_count")
        except FileNotFoundError:
            # No worker count file -> nothing to merge.
            pass

        if worker_count:
            for i in range(int(worker_count)):
                with open(f".pytest_syrupy_gw{i}_result", encoding="utf-8") as f:
                    _merge_serialized_report(self.report, json.load(f))
                os.remove(f".pytest_syrupy_gw{i}_result")

    if self.report.num_unused:
        if self.update_snapshots:
            # --snapshot-update: delete snapshots no test used.
            self.remove_unused_snapshots(
                unused_snapshot_collections=self.report.unused,
                used_snapshot_collections=self.report.used,
            )
        elif not self.warn_unused_snapshots:
            # Unused snapshots fail the run unless warnings-only mode is set.
            exitstatus |= EXIT_STATUS_FAIL_UNUSED
    return exitstatus
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import asyncio 10 | from collections.abc import Iterable, Iterator 11 | from contextlib import contextmanager 12 | from dataclasses import dataclass 13 | from datetime import datetime, timedelta 14 | from functools import partial 15 | import importlib 16 | import sys 17 | import time 18 | from types import ModuleType 19 | from typing import Any, Literal, cast 20 | from unittest.mock import MagicMock, patch, sentinel 21 | 22 | from freezegun import freeze_time 23 | import pytest 24 | from sqlalchemy import create_engine, event as sqlalchemy_event 25 | from sqlalchemy.orm.session import Session 26 | 27 | from homeassistant import core as ha 28 | from homeassistant.components import recorder 29 | from homeassistant.components.recorder import ( 30 | Recorder, 31 | core, 32 | get_instance, 33 | migration, 34 | statistics, 35 | ) 36 | from homeassistant.components.recorder.db_schema import ( 37 | EventData, 38 | Events, 39 | EventTypes, 40 | RecorderRuns, 41 | StateAttributes, 42 | States, 43 | StatesMeta, 44 | ) 45 | from homeassistant.components.recorder.models import ( 46 | bytes_to_ulid_or_none, 47 | bytes_to_uuid_hex_or_none, 48 | ) 49 | from homeassistant.components.recorder.tasks import RecorderTask, StatisticsTask 50 | from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass 51 | from homeassistant.const import DEGREE, UnitOfTemperature 52 | from homeassistant.core import Event, HomeAssistant, State 53 | from homeassistant.helpers import recorder as recorder_helper 54 | from homeassistant.util import dt as dt_util 55 | from homeassistant.util.json import json_loads, json_loads_object 56 | 57 | from . 
# Default number of purge wait rounds; see async_wait_purge_done.
DEFAULT_PURGE_TASKS = 3
# Patch target for tests that substitute recorder engine creation.
CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine"


@dataclass
class BlockRecorderTask(RecorderTask):
    """A task to block the recorder for testing only."""

    # Set (thread-safely, via the hass loop) once the block has started.
    event: asyncio.Event
    # How long the recorder task thread sleeps, i.e. stays blocked.
    seconds: float

    def run(self, instance: Recorder) -> None:
        """Block the recorders event loop."""
        # Signal the waiting test that the block is now in effect, then
        # sleep on the recorder's own thread so no further tasks run.
        instance.hass.loop.call_soon_threadsafe(self.event.set)
        time.sleep(self.seconds)


@dataclass
class ForceReturnConnectionToPool(RecorderTask):
    """Force return connection to pool."""

    def run(self, instance: Recorder) -> None:
        """Handle the task."""
        # Commit the event session; per the class name this is expected to
        # return the session's connection to the pool.
        instance.event_session.commit()


async def async_block_recorder(hass: HomeAssistant, seconds: float) -> None:
    """Block the recorders event loop for testing.

    Returns as soon as the recorder has started the block.

    Does not wait for the block to finish.
    """
    event = asyncio.Event()
    get_instance(hass).queue_task(BlockRecorderTask(event, seconds))
    # Wait only until the task signals that the block began.
    await event.wait()
async def async_wait_recorder(hass: HomeAssistant) -> bool:
    """Wait until the recorder has initialized and return connection status."""
    return await hass.data[recorder_helper.DATA_RECORDER].db_connected


def get_start_time(start: datetime) -> datetime:
    """Round *start* down to the previous 5 minute boundary."""
    floored_minute = (start.minute // 5) * 5
    return start.replace(minute=floored_minute, second=0, microsecond=0)


def do_adhoc_statistics(hass: HomeAssistant, **kwargs: Any) -> None:
    """Trigger an adhoc statistics run."""
    start = kwargs.get("start")
    if not start:
        start = statistics.get_start_time()
    elif start.minute % 5 or start.second or start.microsecond:
        # Statistics periods are aligned to 5 minute boundaries.
        raise ValueError(f"Statistics must start on 5 minute boundary got {start}")
    get_instance(hass).queue_task(StatisticsTask(start, False))


def wait_recording_done(hass: HomeAssistant) -> None:
    """Block until the recorder has finished recording."""
    hass.block_till_done()
    trigger_db_commit(hass)
    hass.block_till_done()
    recorder.get_instance(hass).block_till_done()
    hass.block_till_done()


def trigger_db_commit(hass: HomeAssistant) -> None:
    """Force the recorder to commit."""
    recorder.get_instance(hass)._async_commit(dt_util.utcnow())


async def async_wait_recording_done(hass: HomeAssistant) -> None:
    """Async wait until the recorder has finished recording."""
    await hass.async_block_till_done()
    async_trigger_db_commit(hass)
    await hass.async_block_till_done()
    await async_recorder_block_till_done(hass)
    await hass.async_block_till_done()


async def async_wait_purge_done(
    hass: HomeAssistant, max_number: int | None = None
) -> None:
    """Wait for up to *max_number* purge events to finish.

    A purge may insert another PurgeTask into the queue after the
    WaitTask finishes, so wait one extra round beyond the maximum
    number of purge tasks expected.
    """
    rounds = (max_number or DEFAULT_PURGE_TASKS) + 1
    for _ in range(rounds):
        await async_wait_recording_done(hass)
@ha.callback
def async_trigger_db_commit(hass: HomeAssistant) -> None:
    """Force the recorder to commit. Async friendly."""
    recorder.get_instance(hass)._async_commit(dt_util.utcnow())


async def async_recorder_block_till_done(hass: HomeAssistant) -> None:
    """Non blocking version of recorder.block_till_done()."""
    block = recorder.get_instance(hass).block_till_done
    await hass.async_add_executor_job(block)


def corrupt_db_file(test_db_file):
    """Corrupt an sqlite3 database file."""
    garbage = "I am a corrupt db" * 100
    # "w+" truncates the file first; seeking past the start before writing
    # leaves contents that sqlite cannot read as a valid database.
    with open(test_db_file, "w+", encoding="utf8") as handle:
        handle.seek(200)
        handle.write(garbage)


def create_engine_test(*args, **kwargs):
    """Test version of create_engine that initializes with old schema.

    This simulates an existing db with the old schema.
    """
    test_engine = create_engine(*args, **kwargs)
    db_schema_0.Base.metadata.create_all(test_engine)
    return test_engine
def run_information_with_session(
    session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
    """Return information about current run from the database."""
    run_query = session.query(RecorderRuns)
    if point_in_time:
        # Restrict to the run whose window contains point_in_time.
        run_query = run_query.filter(
            (RecorderRuns.start < point_in_time) & (RecorderRuns.end > point_in_time)
        )

    found = run_query.first()
    if found is None:
        return None
    # Detach the row so it stays usable after the session closes.
    session.expunge(found)
    return cast(RecorderRuns, found)


def statistics_during_period(
    hass: HomeAssistant,
    start_time: datetime,
    end_time: datetime | None = None,
    statistic_ids: set[str] | None = None,
    period: Literal["5minute", "day", "hour", "week", "month"] = "hour",
    units: dict[str, str] | None = None,
    types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]]
    | None = None,
) -> dict[str, list[dict[str, Any]]]:
    """Call statistics.statistics_during_period with test-friendly defaults."""
    if types is None:
        # Default to requesting every statistic type.
        types = {"last_reset", "max", "mean", "min", "state", "sum"}
    if statistic_ids is not None and not isinstance(statistic_ids, set):
        # Tolerate callers passing any iterable of statistic ids.
        statistic_ids = set(statistic_ids)
    return statistics.statistics_during_period(
        hass, start_time, end_time, statistic_ids, period, units, types
    )


def assert_states_equal_without_context(state: State, other: State) -> None:
    """Assert that two states are equal, ignoring context."""
    assert_states_equal_without_context_and_last_changed(state, other)
    assert state.last_changed == other.last_changed
    assert state.last_reported == other.last_reported


def assert_states_equal_without_context_and_last_changed(
    state: State, other: State
) -> None:
    """Assert that two states are equal, ignoring context and last_changed."""
    assert state.state == other.state
    assert state.attributes == other.attributes
    assert state.last_updated == other.last_updated
def assert_multiple_states_equal_without_context_and_last_changed(
    states: Iterable[State], others: Iterable[State]
) -> None:
    """Assert that multiple states are equal, ignoring context and last_changed."""
    left = list(states)
    right = list(others)
    assert len(left) == len(right)
    for state, other in zip(left, right):
        assert_states_equal_without_context_and_last_changed(state, other)


def assert_multiple_states_equal_without_context(
    states: Iterable[State], others: Iterable[State]
) -> None:
    """Assert that multiple states are equal, ignoring context."""
    left = list(states)
    right = list(others)
    assert len(left) == len(right)
    for state, other in zip(left, right):
        assert_states_equal_without_context(state, other)


def assert_events_equal_without_context(event: Event, other: Event) -> None:
    """Assert that two events are equal, ignoring context."""
    for attribute in ("data", "event_type", "origin", "time_fired"):
        assert getattr(event, attribute) == getattr(other, attribute)


def assert_dict_of_states_equal_without_context(
    states: dict[str, list[State]], others: dict[str, list[State]]
) -> None:
    """Assert that two dicts of states are equal, ignoring context."""
    assert len(states) == len(others)
    for entity_id, entity_states in states.items():
        assert_multiple_states_equal_without_context(entity_states, others[entity_id])


def assert_dict_of_states_equal_without_context_and_last_changed(
    states: dict[str, list[State]], others: dict[str, list[State]]
) -> None:
    """Assert that two dicts of states are equal, ignoring context and last_changed."""
    assert len(states) == len(others)
    for entity_id, entity_states in states.items():
        assert_multiple_states_equal_without_context_and_last_changed(
            entity_states, others[entity_id]
        )
async def async_record_states(
    hass: HomeAssistant,
) -> tuple[datetime, datetime, dict[str, list[State | None]]]:
    """Record some test states.

    Executor wrapper around the blocking record_states helper.
    """
    return await hass.async_add_executor_job(record_states, hass)


def record_states(
    hass: HomeAssistant,
) -> tuple[datetime, datetime, dict[str, list[State | None]]]:
    """Record some test states.

    We inject a bunch of state updates temperature sensors.

    Returns ``(zero, four, states)``: the start and end of the recorded
    window and the recorded State objects keyed by entity_id.
    """
    mp = "media_player.test"
    sns1 = "sensor.test1"
    sns2 = "sensor.test2"
    sns3 = "sensor.test3"
    sns4 = "sensor.test4"
    sns5 = "sensor.wind_direction"
    # sns1/sns2: measurement sensors with full attribute sets.
    sns1_attr = {
        "device_class": "temperature",
        "state_class": "measurement",
        "unit_of_measurement": UnitOfTemperature.CELSIUS,
    }
    sns2_attr = {
        "device_class": "humidity",
        "state_class": "measurement",
        "unit_of_measurement": "%",
    }
    # sns3: device class only (no state_class); sns4: no attributes at all.
    sns3_attr = {"device_class": "temperature"}
    sns4_attr = {}
    # sns5: angle measurement (wind direction in degrees).
    sns5_attr = {
        "device_class": SensorDeviceClass.WIND_DIRECTION,
        "state_class": SensorStateClass.MEASUREMENT_ANGLE,
        "unit_of_measurement": DEGREE,
    }

    def set_state(entity_id, state, **kwargs):
        """Set the state."""
        hass.states.set(entity_id, state, **kwargs)
        # Block until the recorder persisted the state before returning it.
        wait_recording_done(hass)
        return hass.states.get(entity_id)

    # Timeline, in seconds after the 5 minute boundary `zero`:
    # one = +5s, two = +80s, three = +230s, four = +300s.
    zero = get_start_time(dt_util.utcnow())
    one = zero + timedelta(seconds=1 * 5)
    two = one + timedelta(seconds=15 * 5)
    three = two + timedelta(seconds=30 * 5)
    four = three + timedelta(seconds=14 * 5)

    states = {mp: [], sns1: [], sns2: [], sns3: [], sns4: [], sns5: []}
    with freeze_time(one) as freezer:
        # First sample for every entity at time `one`.
        states[mp].append(
            set_state(mp, "idle", attributes={"media_title": str(sentinel.mt1)})
        )
        states[sns1].append(set_state(sns1, "10", attributes=sns1_attr))
        states[sns2].append(set_state(sns2, "10", attributes=sns2_attr))
        states[sns3].append(set_state(sns3, "10", attributes=sns3_attr))
        states[sns4].append(set_state(sns4, "10", attributes=sns4_attr))
        states[sns5].append(set_state(sns5, "10", attributes=sns5_attr))

        # A second media player state 1 microsecond later.
        freezer.move_to(one + timedelta(microseconds=1))
        states[mp].append(
            set_state(mp, "YouTube", attributes={"media_title": str(sentinel.mt2)})
        )

        freezer.move_to(two)
        states[sns1].append(set_state(sns1, "15", attributes=sns1_attr))
        states[sns2].append(set_state(sns2, "15", attributes=sns2_attr))
        states[sns3].append(set_state(sns3, "15", attributes=sns3_attr))
        states[sns4].append(set_state(sns4, "15", attributes=sns4_attr))
        states[sns5].append(set_state(sns5, "350", attributes=sns5_attr))

        freezer.move_to(three)
        states[sns1].append(set_state(sns1, "20", attributes=sns1_attr))
        states[sns2].append(set_state(sns2, "20", attributes=sns2_attr))
        states[sns3].append(set_state(sns3, "20", attributes=sns3_attr))
        states[sns4].append(set_state(sns4, "20", attributes=sns4_attr))
        # Wind direction crosses the 0° boundary (350 -> 5).
        states[sns5].append(set_state(sns5, "5", attributes=sns5_attr))

    return zero, four, states
def convert_pending_states_to_meta(instance: Recorder, session: Session) -> None:
    """Convert pending states to use states_metadata."""
    pending_entity_ids: set[str] = set()
    pending_states: set[States] = set()
    meta_cache: dict[str, StatesMeta] = {}
    for pending in session:
        if not isinstance(pending, States):
            continue
        pending_entity_ids.add(pending.entity_id)
        pending_states.add(pending)

    known_metadata_ids = instance.states_meta_manager.get_many(
        pending_entity_ids, session, True
    )

    for state in pending_states:
        entity_id = state.entity_id
        # Clear the legacy columns before linking to states_meta.
        state.entity_id = None
        state.attributes = None
        state.event_id = None
        metadata_id = known_metadata_ids.get(entity_id)
        if metadata_id:
            # Metadata row already exists; reference it directly.
            state.metadata_id = metadata_id
            continue
        # Otherwise create (or reuse) a pending StatesMeta object.
        meta = meta_cache.get(entity_id)
        if meta is None:
            meta = StatesMeta(entity_id=entity_id)
            meta_cache[entity_id] = meta
        state.states_meta_rel = meta
def convert_pending_events_to_event_types(instance: Recorder, session: Session) -> None:
    """Convert pending events to use event_type_ids."""
    event_types: set[str] = set()
    events: set[Events] = set()
    event_types_objects: dict[str, EventTypes] = {}
    # Collect all pending Events objects and their event types.
    for session_object in session:
        if isinstance(session_object, Events):
            event_types.add(session_object.event_type)
            events.add(session_object)

    event_type_to_event_type_ids = instance.event_type_manager.get_many(
        event_types, session, True
    )
    manually_added_event_types: list[str] = []

    for event in events:
        event_type = event.event_type
        # Clear the legacy columns before linking to event_types.
        event.event_type = None
        event.event_data = None
        event.origin = None
        if event_type_id := event_type_to_event_type_ids.get(event_type):
            # Event type row already exists; reference it directly.
            event.event_type_id = event_type_id
            continue
        # Otherwise create (or reuse) a pending EventTypes object.
        if event_type not in event_types_objects:
            event_types_objects[event_type] = EventTypes(event_type=event_type)
            manually_added_event_types.append(event_type)
        event.event_type_rel = event_types_objects[event_type]

    # Drop the manually added types from the manager's negative cache,
    # presumably so later lookups do not treat them as non-existent.
    for event_type in manually_added_event_types:
        instance.event_type_manager._non_existent_event_types.pop(event_type, None)


def create_engine_test_for_schema_version_postfix(
    *args, schema_version_postfix: str, **kwargs
):
    """Test version of create_engine that initializes with old schema.

    This simulates an existing db with the old schema.
    """
    # Import the requested old schema module and use it to build the db.
    schema_module = get_schema_module_path(schema_version_postfix)
    importlib.import_module(schema_module)
    old_db_schema = sys.modules[schema_module]
    instance: Recorder | None = None
    if "hass" in kwargs:
        # Pop hass so only real create_engine kwargs are forwarded.
        hass: HomeAssistant = kwargs.pop("hass")
        instance = recorder.get_instance(hass)
    engine = create_engine(*args, **kwargs)
    if instance is not None:
        # NOTE(review): get_instance was already called above; this refetch
        # looks redundant — confirm before simplifying.
        instance = recorder.get_instance(hass)
        instance.engine = engine
        sqlalchemy_event.listen(engine, "connect", instance._setup_recorder_connection)
    old_db_schema.Base.metadata.create_all(engine)
    with Session(engine) as session:
        # Seed the statistics runs and record the old schema version.
        session.add(
            recorder.db_schema.StatisticsRuns(start=statistics.get_start_time())
        )
        session.add(
            recorder.db_schema.SchemaChanges(
                schema_version=old_db_schema.SCHEMA_VERSION
            )
        )
        session.commit()
    return engine
def get_schema_module_path(schema_version_postfix: str) -> str:
    """Return the path to the schema module."""
    # NOTE(review): this is a relative ("...") module path produced by the
    # package generation step — confirm importlib.import_module resolves it
    # correctly in this package context.
    return f"...components.recorder.db_schema_{schema_version_postfix}"


def get_patched_live_version(old_db_schema: ModuleType) -> int:
    """Return the patched live migration version.

    The smaller of LIVE_MIGRATION_MIN_SCHEMA_VERSION and the old schema's
    SCHEMA_VERSION.
    """
    return min(
        migration.LIVE_MIGRATION_MIN_SCHEMA_VERSION, old_db_schema.SCHEMA_VERSION
    )


@contextmanager
def old_db_schema(hass: HomeAssistant, schema_version_postfix: str) -> Iterator[None]:
    """Fixture to initialize the db with the old schema."""
    schema_module = get_schema_module_path(schema_version_postfix)
    importlib.import_module(schema_module)
    old_db_schema = sys.modules[schema_module]

    with (
        # Point the recorder and migration machinery at the old schema module.
        patch.object(recorder, "db_schema", old_db_schema),
        patch.object(migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION),
        patch.object(
            migration,
            "LIVE_MIGRATION_MIN_SCHEMA_VERSION",
            get_patched_live_version(old_db_schema),
        ),
        patch.object(migration, "non_live_data_migration_needed", return_value=False),
        # Swap the ORM classes used by the recorder core for the old versions.
        patch.object(core, "StatesMeta", old_db_schema.StatesMeta),
        patch.object(core, "EventTypes", old_db_schema.EventTypes),
        patch.object(core, "EventData", old_db_schema.EventData),
        patch.object(core, "States", old_db_schema.States),
        patch.object(core, "Events", old_db_schema.Events),
        patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
        # Route engine creation through the helper that seeds the old schema.
        patch(
            CREATE_ENGINE_TARGET,
            new=partial(
                create_engine_test_for_schema_version_postfix,
                hass=hass,
                schema_version_postfix=schema_version_postfix,
            ),
        ),
    ):
        yield
async def async_attach_db_engine(hass: HomeAssistant) -> None:
    """Attach a database engine to the recorder."""
    instance = recorder.get_instance(hass)

    def _attach() -> None:
        # Runs in the executor: open a connection and register it with
        # the recorder's connection setup hook.
        with instance.engine.connect() as connection:
            instance._setup_recorder_connection(
                connection._dbapi_connection, MagicMock()
            )

    await instance.async_add_executor_job(_attach)


EVENT_ORIGIN_ORDER = [ha.EventOrigin.local, ha.EventOrigin.remote]


def db_event_to_native(event: Events, validate_entity_id: bool = True) -> Event | None:
    """Convert a database Events row to a native HA Event."""
    if event.origin:
        origin = ha.EventOrigin(event.origin)
    else:
        # Fall back to the origin index column (default 0 == local).
        origin = EVENT_ORIGIN_ORDER[event.origin_idx or 0]
    event_data = json_loads_object(event.event_data) if event.event_data else {}
    context = ha.Context(
        id=bytes_to_ulid_or_none(event.context_id_bin),
        user_id=bytes_to_uuid_hex_or_none(event.context_user_id_bin),
        parent_id=bytes_to_ulid_or_none(event.context_parent_id_bin),
    )
    return Event(
        event.event_type or "",
        event_data,
        origin,
        event.time_fired_ts or 0,
        context=context,
    )
def db_event_data_to_native(event_data: EventData) -> dict[str, Any]:
    """Convert an EventData row to an event data dictionary."""
    shared = event_data.shared_data
    return {} if shared is None else cast(dict[str, Any], json_loads(shared))


def db_state_to_native(state: States, validate_entity_id: bool = True) -> State | None:
    """Convert a database States row to an HA state object."""

    def _from_ts(timestamp: float | None) -> datetime:
        # Treat a missing timestamp as the epoch, matching the `or 0` fallback.
        return dt_util.utc_from_timestamp(timestamp or 0)

    context = ha.Context(
        id=bytes_to_ulid_or_none(state.context_id_bin),
        user_id=bytes_to_uuid_hex_or_none(state.context_user_id_bin),
        parent_id=bytes_to_ulid_or_none(state.context_parent_id_bin),
    )
    attrs = json_loads_object(state.attributes) if state.attributes else {}
    last_updated = _from_ts(state.last_updated_ts)
    # last_changed and last_reported fall back to last_updated when absent
    # or identical to it.
    if state.last_changed_ts is None or state.last_changed_ts == state.last_updated_ts:
        last_changed = _from_ts(state.last_updated_ts)
    else:
        last_changed = _from_ts(state.last_changed_ts)
    if (
        state.last_reported_ts is None
        or state.last_reported_ts == state.last_updated_ts
    ):
        last_reported = _from_ts(state.last_updated_ts)
    else:
        last_reported = _from_ts(state.last_reported_ts)
    return State(
        state.entity_id or "",
        state.state,  # type: ignore[arg-type]
        # Join the state_attributes table on attributes_id to get the attributes
        # for newer states
        attrs,
        last_changed=last_changed,
        last_reported=last_reported,
        last_updated=last_updated,
        context=context,
        validate_entity_id=validate_entity_id,
    )


def db_state_attributes_to_native(state_attrs: StateAttributes) -> dict[str, Any]:
    """Convert a StateAttributes row to a state attributes dictionary."""
    shared = state_attrs.shared_attrs
    return {} if shared is None else cast(dict[str, Any], json_loads(shared))
async def async_drop_index(
    recorder: Recorder, table: str, index: str, caplog: pytest.LogCaptureFixture
) -> None:
    """Drop an index from the database.

    migration._drop_index does not return or raise, so the outcome is
    verified by counting its success/failure log messages before and
    after the call.
    """

    success_msg = f"Finished dropping index `{index}` from table `{table}`"
    failure_msg = f"Failed to drop index `{index}` from table `{table}`"

    successes_before = caplog.text.count(success_msg)
    failures_before = caplog.text.count(failure_msg)

    await recorder.async_add_executor_job(
        migration._drop_index, recorder.get_session, table, index
    )

    # Exactly one new success message and no new failure messages.
    assert caplog.text.count(success_msg) == successes_before + 1
    assert caplog.text.count(failure_msg) == failures_before