├── static └── .keep ├── .dockerignore ├── app ├── __init__.py ├── db │ ├── __init__.py │ ├── models.py │ └── main.py ├── internal │ ├── __init__.py │ ├── command_endpoints │ │ ├── openhab.py │ │ ├── README.md │ │ ├── __init__.py │ │ ├── main.py │ │ ├── rest.py │ │ ├── mqtt.py │ │ └── ha_ws.py │ ├── client.py │ ├── wake.py │ ├── connmgr.py │ ├── config.py │ ├── notify.py │ └── was.py ├── pytest │ ├── __init__.py │ ├── mock.py │ ├── test_config.py │ ├── test_get_release_url.py │ ├── api │ │ ├── test_release.py │ │ └── test_ota.py │ └── test_construct_wis_tts_url.py ├── settings.py ├── routers │ ├── info.py │ ├── status.py │ ├── asset.py │ ├── ota.py │ ├── release.py │ ├── config.py │ └── client.py ├── const.py └── main.py ├── cache └── .gitkeep ├── storage ├── .gitkeep └── ota │ ├── .gitkeep │ └── local │ └── .gitkeep ├── migrations ├── README ├── script.py.mako ├── versions │ └── 8f14a11346c4_initial_schema.py └── env.py ├── default_nvs.json ├── .editorconfig ├── .github ├── dependabot.yml └── workflows │ └── build-and-publish.yml ├── docker-compose.yml ├── entrypoint.sh ├── Dockerfile ├── requirements.txt ├── misc └── migrate_devices.py ├── uvicorn-log-config.json ├── default_config.json ├── README.md ├── .gitignore ├── utils.sh ├── alembic.ini └── LICENSE /static/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | work -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/db/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cache/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /storage/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /storage/ota/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/internal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /storage/ota/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /migrations/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. 
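The migrations under migrations/ are plain Alembic. As a minimal sketch (not taken from this repository — the actual startup hook lives in app/main.py, which is not part of this dump, and the function name here is hypothetical), they could be applied programmatically like this; migrations/env.py skips fileConfig() whenever a 'logger' attribute is present, so passing one preserves the WAS logger:

import logging

from alembic import command
from alembic.config import Config

from app.const import ALEMBIC_CONFIG  # '/app/alembic.ini'


def run_migrations():
    cfg = Config(ALEMBIC_CONFIG)
    # Setting the 'logger' attribute makes migrations/env.py skip fileConfig(),
    # keeping the WAS logging configuration intact (see migrations/env.py).
    cfg.attributes["logger"] = logging.getLogger("WAS")
    command.upgrade(cfg, "head")  # apply all revisions up to head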
-------------------------------------------------------------------------------- /default_nvs.json: -------------------------------------------------------------------------------- 1 | { 2 | "WAS": { 3 | "URL": "wss://was.local" 4 | }, 5 | "WIFI": { 6 | "PSK": "mypassword", 7 | "SSID": "myssid" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /app/pytest/__init__.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(scope="session", autouse=True) 7 | def pre(): 8 | makedirs("/app/static/admin", exist_ok=True) 9 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.py] 4 | charset = utf-8 5 | end_of_line = lf 6 | indent_size = 4 7 | indent_style = space 8 | trim_trailing_whitespace = true 9 | 10 | [*.{ecrc,y{,a}ml}] 11 | indent_size = 2 12 | -------------------------------------------------------------------------------- /app/pytest/mock.py: -------------------------------------------------------------------------------- 1 | mock_releases_willow = [{ 2 | "name": "0.0.0-mock.0", 3 | "tag_name": "0.0.0-mock.0", 4 | "assets": [ 5 | { 6 | "browser_download_url": "bogus", 7 | "platform": "ESP32-S3-BOX-3", 8 | } 9 | ] 10 | }] 11 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: "pip" 5 | directory: "/" 6 | groups: 7 | pydantic-and-core: 8 | patterns: 9 | - "pydantic" 10 | - "pydantic-core" 11 | schedule: 12 | interval: "weekly" 13 | -------------------------------------------------------------------------------- /app/settings.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from pydantic_settings import BaseSettings 4 | 5 | from app.const import DB_URL 6 | 7 | 8 | class Settings(BaseSettings): 9 | db_url: str = DB_URL 10 | was_version: str = "unknown" 11 | 12 | 13 | @lru_cache 14 | def get_settings(): 15 | return Settings() 16 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | was: 3 | restart: unless-stopped 4 | image: ${IMAGE}:${TAG} 5 | environment: 6 | - LOG_LEVEL 7 | - TZ=${TZ} 8 | - WAS_IP 9 | - WAS_LOG_LEVEL 10 | #build: 11 | # dockerfile: Dockerfile 12 | ports: 13 | - ${LISTEN_IP}:${API_LISTEN_PORT}:8502 14 | volumes: 15 | - ./:/app 16 | - ./storage:/app/storage 17 | -------------------------------------------------------------------------------- /app/pytest/test_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from pathlib import Path 4 | 5 | from app.internal.config import WillowConfig 6 | 7 | 8 | def test_config(): 9 | default_config_path = Path("default_config.json") 10 | default_config = default_config_path.read_text() 11 | default_config_dict = json.loads(default_config) 12 | 13 | config = WillowConfig.model_validate(default_config_dict) 14 | 15 | assert config.model_dump(exclude_none=True) == default_config_dict 16 | 
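test_config() round-trips default_config.json through the WillowConfig model. A companion test for the NVS side could follow the same pattern against default_nvs.json, using the WillowNvsConfig model from app/internal/config.py. A sketch only — this test is not part of the actual suite:

import json

from pathlib import Path

from app.internal.config import WillowNvsConfig


def test_nvs_config():
    default_nvs_dict = json.loads(Path("default_nvs.json").read_text())

    nvs = WillowNvsConfig.model_validate(default_nvs_dict)

    # exclude_none mirrors test_config(): unset optional fields drop out
    assert nvs.model_dump(exclude_none=True) == default_nvs_dict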
-------------------------------------------------------------------------------- /app/internal/command_endpoints/openhab.py: -------------------------------------------------------------------------------- 1 | from .rest import RestAuthType, RestConfig, RestEndpoint 2 | 3 | 4 | class OpenhabEndpoint(RestEndpoint): 5 | name = "WAS openHAB Endpoint" 6 | 7 | def __init__(self, url, token): 8 | self.config = RestConfig(auth_type=RestAuthType.BASIC, auth_user=token) 9 | self.url = f"{url}/rest/voice/interpreters" 10 | 11 | def send(self, jsondata=None, ws=None, client=None): 12 | return super().send(data=jsondata["text"]) 13 | -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | set -e 3 | 4 | # Log level - acceptable values are debug, info, warning, error, critical. Suggest info or debug. 5 | LOG_LEVEL=${LOG_LEVEL:-info} 6 | 7 | FORWARDED_ALLOW_IPS=${FORWARDED_ALLOW_IPS:-127.0.0.1} 8 | 9 | set +a 10 | 11 | python /app/misc/migrate_devices.py 12 | 13 | uvicorn app.main:app --host 0.0.0.0 --port 8502 --log-config uvicorn-log-config.json \ 14 | --log-level "$LOG_LEVEL" --loop uvloop --timeout-graceful-shutdown 5 \ 15 | --no-server-header 16 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG WAS_UI_TAG="main" 2 | 3 | FROM ghcr.io/heywillow/willow-application-server-ui:${WAS_UI_TAG} AS was-ui 4 | 5 | FROM python:3.12.9-alpine3.21 6 | 7 | 8 | WORKDIR /app 9 | 10 | RUN apk add --no-cache alpine-sdk libpq-dev 11 | 12 | COPY requirements.txt . 13 | 14 | RUN --mount=type=cache,target=/root/.cache pip install -r requirements.txt 15 | 16 | COPY . . 
17 | 18 | COPY --from=was-ui /was-ui/out/ /app/static/admin/ 19 | 20 | RUN PYTHONPATH=/app pytest -s 21 | 22 | EXPOSE 8501 23 | EXPOSE 8502 24 | 25 | ARG WAS_VERSION 26 | ENV WAS_VERSION=$WAS_VERSION 27 | 28 | CMD /app/entrypoint.sh 29 | -------------------------------------------------------------------------------- /app/routers/info.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import Annotated 3 | 4 | from fastapi import APIRouter, Depends 5 | from fastapi.responses import JSONResponse 6 | 7 | from app.settings import Settings, get_settings 8 | 9 | 10 | log = getLogger("WAS") 11 | 12 | router = APIRouter( 13 | prefix="/api", 14 | ) 15 | 16 | 17 | @router.get("/info") 18 | async def api_get_info(settings: Annotated[Settings, Depends(get_settings)]): 19 | log.debug('API GET INFO: Request') 20 | 21 | info = { 22 | 'was': { 23 | 'version': settings.was_version, 24 | } 25 | } 26 | 27 | return JSONResponse(info) 28 | -------------------------------------------------------------------------------- /app/const.py: -------------------------------------------------------------------------------- 1 | ALEMBIC_CONFIG = '/app/alembic.ini' 2 | DB_URL = 'sqlite:////app/storage/was.db' 3 | DIR_ASSET = '/app/storage/asset' 4 | DIR_OTA = '/app/storage/ota' 5 | URL_WILLOW_RELEASES = 'https://worker.heywillow.org/api/release?format=was' 6 | URL_WILLOW_CONFIG = 'https://worker.heywillow.org/api/config' 7 | URL_WILLOW_TZ = 'https://worker.heywillow.org/api/asset?type=tz' 8 | 9 | STORAGE_USER_CLIENT_CONFIG = 'storage/user_client_config.json' 10 | STORAGE_USER_CONFIG = 'storage/user_config.json' 11 | STORAGE_USER_MULTINET = 'storage/user_multinet.json' 12 | STORAGE_USER_NVS = 'storage/user_nvs.json' 13 | STORAGE_USER_WAS = 'storage/user_was.json' 14 | STORAGE_TZ = 'storage/tz.json' 15 | -------------------------------------------------------------------------------- /app/internal/client.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Client(BaseModel): 5 | hostname: str = "unknown" 6 | platform: str = "unknown" 7 | mac_addr: str = "unknown" 8 | notification_active: int = 0 9 | ua: str = None 10 | 11 | def set_hostname(self, hostname): 12 | self.hostname = hostname 13 | 14 | def set_platform(self, platform): 15 | self.platform = platform 16 | 17 | def set_mac_addr(self, mac_addr): 18 | self.mac_addr = mac_addr 19 | 20 | def is_notification_active(self): 21 | return self.notification_active != 0 22 | 23 | def set_notification_active(self, id): 24 | self.notification_active = id 25 | -------------------------------------------------------------------------------- /app/internal/command_endpoints/README.md: -------------------------------------------------------------------------------- 1 | # Willow Application Server Command Endpoints 2 | 3 | With WAS mode enabled, Willow no longer communicates with command endpoints 4 | directly. The message flow on wake is as follows: 5 | 6 | 1. Willow HTTP POST audio --> WIS 7 | 2. WIS responds JSON to Willow 8 | 3. Willow sends unmodified JSON from WIS to WAS 9 | 4. WAS forwards to configured endpoint 10 | 5. Endpoint responds to WAS 11 | 6. WAS responds JSON to Willow 12 | 13 | The message format is fixed in steps 1, 2, 3 and 6. 14 | The message format in steps 4 and 5 depends on the configured command endpoint.
15 | 16 | Step 6 example message: 17 | 18 | { 19 | "ok": true, 20 | "speech": "turned on light" 21 | } 22 | -------------------------------------------------------------------------------- /migrations/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | import sqlmodel 13 | ${imports if imports else ""} 14 | 15 | # revision identifiers, used by Alembic. 16 | revision: str = ${repr(up_revision)} 17 | down_revision: Union[str, None] = ${repr(down_revision)} 18 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} 19 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} 20 | 21 | 22 | def upgrade() -> None: 23 | ${upgrades if upgrades else "pass"} 24 | 25 | 26 | def downgrade() -> None: 27 | ${downgrades if downgrades else "pass"} 28 | -------------------------------------------------------------------------------- /app/pytest/test_get_release_url.py: -------------------------------------------------------------------------------- 1 | from app.internal.was import get_release_url 2 | 3 | 4 | def test_get_release_url(): 5 | expect = "http://was.local/api/ota?version=local&platform=ESP32-S3-BOX-3" 6 | assert expect == get_release_url("ws://was.local/ws", "local", "ESP32-S3-BOX-3") 7 | 8 | expect = "http://was.local:8502/api/ota?version=local&platform=ESP32-S3-BOX-3" 9 | assert expect == get_release_url("ws://was.local:8502/ws", "local", "ESP32-S3-BOX-3") 10 | 11 | expect = "https://was.local/api/ota?version=local&platform=ESP32-S3-BOX-3" 12 | assert expect == get_release_url("wss://was.local/ws", "local", "ESP32-S3-BOX-3") 13 | 14 | expect = "https://was.local:8503/api/ota?version=local&platform=ESP32-S3-BOX-3" 15 | assert expect == get_release_url("wss://was.local:8503/ws", "local", "ESP32-S3-BOX-3") 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alembic==1.16.5 2 | annotated-types==0.7.0 3 | anyio==4.12.0 4 | certifi==2025.11.12 5 | charset-normalizer==3.4.4 6 | click==8.3.1 7 | dnspython==2.8.0 8 | docopt==0.6.2 9 | email-validator==2.3.0 10 | fastapi==0.122.0 11 | h11==0.16.0 12 | httpcore==1.0.9 13 | httptools==0.7.1 14 | httpx==0.28.1 15 | idna==3.11 16 | itsdangerous==2.2.0 17 | Jinja2==3.1.6 18 | MarkupSafe==3.0.3 19 | num2words==0.5.14 20 | paho-mqtt==2.1.0 21 | pydantic==2.12.5 22 | pydantic-extra-types==2.10.6 23 | pydantic-settings==2.12.0 24 | pydantic_core==2.41.5 25 | pytest==9.0.2 26 | python-dotenv==1.2.1 27 | python-magic==0.4.27 28 | python-multipart==0.0.20 29 | PyYAML==6.0.3 30 | psycopg2==2.9.11 31 | requests==2.32.5 32 | sniffio==1.3.1 33 | sqlmodel==0.0.27 34 | starlette==0.50.0 35 | typing_extensions==4.15.0 36 | ujson==5.11.0 37 | urllib3==2.6.2 38 | uvicorn==0.38.0 39 | uvloop==0.22.1 40 | watchfiles==1.1.1 41 | websockets==15.0.1 42 | -------------------------------------------------------------------------------- /app/routers/status.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from logging import getLogger 4 | from typing import Literal 5 | 6 | from fastapi import APIRouter, Depends, Query, Request 7 | from fastapi.responses import JSONResponse 8 
| from pydantic import BaseModel, Field 9 | 10 | 11 | log = getLogger("WAS") 12 | 13 | router = APIRouter( 14 | prefix="/api", 15 | ) 16 | 17 | 18 | class GetStatus(BaseModel): 19 | type: Literal['asyncio_tasks', 'connmgr', 'notify_queue'] = Field(Query(..., description='Status type')) 20 | 21 | 22 | @router.get("/status") 23 | async def api_get_status(request: Request, status: GetStatus = Depends()): 24 | log.debug('API GET STATUS: Request') 25 | res = [] 26 | 27 | if status.type == "asyncio_tasks": 28 | tasks = asyncio.all_tasks() 29 | for task in tasks: 30 | res.append(f"{task.get_name()}: {task.get_coro()}") 31 | 32 | elif status.type == "connmgr": 33 | return JSONResponse(request.app.connmgr.model_dump(exclude={})) 34 | 35 | elif status.type == "notify_queue": 36 | return JSONResponse(request.app.notify_queue.model_dump(exclude={'connmgr', 'task'})) 37 | 38 | return JSONResponse(res) 39 | -------------------------------------------------------------------------------- /misc/migrate_devices.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | STORAGE_USER_CLIENT_CONFIG = '/app/storage/user_client_config.json' 6 | STORAGE_DEVICES_CONFIG = '/app/storage/devices.json' 7 | 8 | def hex_mac(mac): 9 | if isinstance(mac, list): 10 | mac = '%02x:%02x:%02x:%02x:%02x:%02x' % (mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]) 11 | return mac 12 | 13 | def save_json_to_file(path, content): 14 | with open(path, "w") as config_file: 15 | config_file.write(content) 16 | 17 | 18 | if os.path.isfile(STORAGE_USER_CLIENT_CONFIG): 19 | sys.exit() 20 | 21 | if os.path.isfile(STORAGE_DEVICES_CONFIG): 22 | print('Migrating legacy WAS client configuration...') 23 | devices_file = open(STORAGE_DEVICES_CONFIG, "r") 24 | devices = json.load(devices_file) 25 | devices_file.close() 26 | 27 | new_devices = [] 28 | for device in devices: 29 | mac_addr = hex_mac(device["mac_addr"]) 30 | label = device["label"] 31 | user_config = {"mac_addr": mac_addr, "label": label} 32 | new_devices.append(user_config) 33 | 34 | save_json_to_file(STORAGE_USER_CLIENT_CONFIG, json.dumps(new_devices)) -------------------------------------------------------------------------------- /uvicorn-log-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "disable_existing_loggers": false, 4 | "formatters": { 5 | "default": { 6 | "()": "uvicorn.logging.DefaultFormatter", 7 | "datefmt": "%Y-%m-%dT%H:%M:%S", 8 | "fmt": "[%(asctime)s] %(levelprefix)s %(message)s", 9 | "use_colors": true 10 | }, 11 | "access": { 12 | "()": "uvicorn.logging.AccessFormatter", 13 | "datefmt": "%Y-%m-%dT%H:%M:%S", 14 | "fmt": "[%(asctime)s] %(levelprefix)s %(client_addr)s - '%(request_line)s' %(status_code)s" 15 | } 16 | }, 17 | "handlers": { 18 | "default": { 19 | "formatter": "default", 20 | "class": "logging.StreamHandler", 21 | "stream": "ext://sys.stderr" 22 | }, 23 | "access": { 24 | "formatter": "access", 25 | "class": "logging.StreamHandler", 26 | "stream": "ext://sys.stdout" 27 | } 28 | }, 29 | "loggers": { 30 | "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": false}, 31 | "uvicorn.error": {"level": "INFO"}, 32 | "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": false} 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /app/internal/command_endpoints/__init__.py:
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pydantic import BaseModel 4 | 5 | class CommandEndpointConfigException(Exception): 6 | """Raised when the command endpoint configuration is invalid 7 | 8 | Attributes: 9 | msg -- error message 10 | """ 11 | 12 | def __init__(self, msg="Command Endpoint configuration is invalid"): 13 | self.msg = msg 14 | super().__init__(self.msg) 15 | 16 | 17 | class CommandEndpointRuntimeException(Exception): 18 | """Raised when an exception occurs while contacting the command endpoint 19 | 20 | Attributes: 21 | msg -- error message 22 | """ 23 | 24 | def __init__(self, msg="Runtime exception occurred in Command Endpoint"): 25 | self.msg = msg 26 | super().__init__(self.msg) 27 | 28 | 29 | class CommandEndpointResult(BaseModel): 30 | ok: bool = False 31 | speech: str = "Error!" 32 | 33 | def sanitize(self): 34 | self.speech = self.speech.replace("\n", " ").replace("\r", " ").lstrip() 35 | 36 | 37 | class CommandEndpointResponse(BaseModel): 38 | result: CommandEndpointResult = None 39 | 40 | def __init__(self, **kwargs): 41 | super().__init__(**kwargs) 42 | self.result.sanitize() 43 | 44 | 45 | class CommandEndpoint(): 46 | name = "WAS CommandEndpoint" 47 | log = logging.getLogger("WAS") 48 | -------------------------------------------------------------------------------- /app/db/models.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Optional 3 | 4 | from sqlmodel import Field, SQLModel, UniqueConstraint 5 | 6 | 7 | class WillowConfigNamespaceType(str, Enum): 8 | WAS = "WAS" 9 | WIFI = "WIFI" 10 | 11 | 12 | class WillowConfigType(str, Enum): 13 | config = "config" 14 | multinet = "multinet" 15 | nvs = "nvs" 16 | 17 | 18 | class WillowConfigTable(SQLModel, table=True): 19 | # work around a probable SQLModel bug during select 20 | # AttributeError: 'ConfigTable' object has no attribute '__pydantic_extra__'. Did you mean: '__pydantic_private__'? 21 | __pydantic_extra__ = None 22 | __table_args__ = (UniqueConstraint("config_type", "config_name"), ) 23 | __tablename__ = "willow_config" 24 | 25 | id: Optional[int] = Field(default=None, primary_key=True) 26 | config_type: WillowConfigType 27 | config_name: str 28 | config_namespace: Optional[WillowConfigNamespaceType] = None 29 | config_value: Optional[str] = None 30 | 31 | 32 | class WillowClientTable(SQLModel, table=True): 33 | # work around a probable SQLModel bug during select 34 | # AttributeError: 'ConfigTable' object has no attribute '__pydantic_extra__'. Did you mean: '__pydantic_private__'?
35 | __pydantic_extra__ = None 36 | __tablename__ = "willow_clients" 37 | 38 | id: Optional[int] = Field(default=None, primary_key=True) 39 | mac_addr: str = Field(unique=True) 40 | label: str 41 | -------------------------------------------------------------------------------- /app/internal/wake.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import time 4 | 5 | from logging import getLogger 6 | from uuid import uuid4 7 | 8 | 9 | log = getLogger("WAS") 10 | 11 | 12 | class WakeEvent: 13 | def __init__(self, client, volume): 14 | self.client = client 15 | self.volume = volume 16 | 17 | 18 | class WakeSession: 19 | def __init__(self): 20 | self.done = False 21 | self.events = [] 22 | self.id = uuid4() 23 | self.ts = time.time() 24 | log.debug(f"WakeSession with ID {self.id} created") 25 | 26 | def add_event(self, event): 27 | log.debug(f"WakeSession {self.id} adding event {event}") 28 | self.events.append(event) 29 | 30 | async def cleanup(self, timeout=400): 31 | await asyncio.sleep(timeout / 1000) 32 | max_volume = -1000.0 33 | winner = None 34 | for event in self.events: 35 | if event.volume > max_volume: 36 | max_volume = event.volume 37 | winner = event.client 38 | 39 | # notify winner first 40 | await winner.send_text(json.dumps({'wake_result': {'won': True}})) 41 | 42 | for event in self.events: 43 | if event.client != winner: 44 | await event.client.send_text(json.dumps({'wake_result': {'won': False}})) 45 | 46 | log.debug(f"Marking WakeSession with ID {self.id} done. Winner: {winner}") 47 | self.done = True 48 | -------------------------------------------------------------------------------- /default_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "aec": true, 3 | "audio_codec": "PCM", 4 | "audio_response_type": "TTS", 5 | "bss": false, 6 | "command_endpoint": "Home Assistant", 7 | "display_timeout": 10, 8 | "hass_host": "homeassistant.local", 9 | "hass_port": 8123, 10 | "hass_tls": false, 11 | "hass_token": "your_ha_token", 12 | "lcd_brightness": 500, 13 | "mic_gain": 14, 14 | "mqtt_auth_type": "userpw", 15 | "mqtt_host": "your.mqtt.host", 16 | "mqtt_password": "your_mqtt_password", 17 | "mqtt_port": 1883, 18 | "mqtt_tls": false, 19 | "mqtt_topic": "your_mqtt_topic", 20 | "mqtt_username": "your_mqtt_username", 21 | "multiwake": false, 22 | "ntp_config": "Host", 23 | "ntp_host": "pool.ntp.org", 24 | "openhab_token": "your_openhab_token", 25 | "openhab_url": "your_openhab_url", 26 | "record_buffer": 12, 27 | "rest_auth_header": "your_header", 28 | "rest_auth_pass": "your_password", 29 | "rest_auth_type": "None", 30 | "rest_auth_user": "your_username", 31 | "rest_url": "http://your_rest_url", 32 | "show_prereleases": false, 33 | "speaker_volume": 60, 34 | "speech_rec_mode": "WIS", 35 | "stream_timeout": 5, 36 | "timezone": "UTC+5", 37 | "timezone_name": "America/Menominee", 38 | "vad_mode": 2, 39 | "vad_timeout": 300, 40 | "wake_confirmation": false, 41 | "wake_mode": "2CH_95", 42 | "wake_word": "alexa", 43 | "wis_tts_url": "https://infer.tovera.io/api/tts", 44 | "wis_url": "https://infer.tovera.io/api/willow" 45 | } 46 | -------------------------------------------------------------------------------- /app/pytest/api/test_release.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest 3 | 4 | from unittest.mock import patch 5 | 6 | from fastapi.testclient import TestClient 7 | 8 | from 
app.main import app 9 | from app.pytest.mock import mock_releases_willow 10 | 11 | 12 | client = TestClient(app) 13 | 14 | 15 | mock_releases_was = [{ 16 | "name": "0.0.0-mock.0", 17 | "tag_name": "0.0.0-mock.0", 18 | "assets": [ 19 | { 20 | "browser_download_url": "bogus", 21 | "platform": "ESP32-S3-BOX-3", 22 | "was_url": "http://was.local/api/ota?version=0.0.0-mock.0&platform=ESP32-S3-BOX-3", 23 | "cached": True 24 | } 25 | ] 26 | }] 27 | 28 | 29 | class TestRelease(unittest.TestCase): 30 | def test_get_release(self): 31 | with patch("app.routers.release.get_was_url", return_value=None): 32 | response = client.get("/api/release?type=was") 33 | 34 | assert response.status_code == 500 35 | 36 | with patch("app.routers.release.get_releases_willow", return_value=mock_releases_willow): 37 | response = client.get("/api/release?type=willow") 38 | 39 | assert response.status_code == 200 40 | assert json.loads(response.content) == mock_releases_willow 41 | 42 | with patch("app.routers.release.get_was_url", return_value="ws://was.local/ws"): 43 | response = client.get("/api/release?type=was") 44 | 45 | assert response.status_code == 200 46 | assert json.loads(response.content) == mock_releases_was 47 | -------------------------------------------------------------------------------- /app/routers/asset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fastapi import APIRouter, Depends, HTTPException, Query 4 | from fastapi.responses import FileResponse 5 | from logging import getLogger 6 | from typing import Literal 7 | from pydantic import BaseModel, Field 8 | 9 | from ..const import DIR_ASSET 10 | from ..internal.was import get_mime_type, get_safe_path 11 | 12 | 13 | log = getLogger("WAS") 14 | router = APIRouter( 15 | prefix="/api", 16 | ) 17 | 18 | 19 | class GetAsset(BaseModel): 20 | asset: str = Field(Query(..., description="Asset")) 21 | type: Literal["audio", "image", "other"] = Field( 22 | Query(..., description="Asset type") 23 | ) 24 | 25 | 26 | @router.get("/asset") 27 | async def api_get_asset(asset: GetAsset = Depends()): 28 | log.debug("API GET ASSET: Request") 29 | asset_file = os.path.join(DIR_ASSET, asset.type, asset.asset) 30 | asset_file = get_safe_path(DIR_ASSET, asset_file) 31 | log.debug(f"asset file: {asset_file}") 32 | 33 | # If we don't have the asset file return 404 34 | if not os.path.isfile(asset_file): 35 | raise HTTPException(status_code=404, detail="Asset File Not Found") 36 | 37 | # Use libmagic to determine MIME type to be really sure 38 | magic_mime_type = get_mime_type(asset_file) 39 | 40 | # Return image and other types 41 | if asset.type == "image" or asset.type == "other": 42 | return FileResponse(asset_file, media_type=magic_mime_type) 43 | 44 | # Only support audio formats supported by Willow 45 | if magic_mime_type == "audio/flac" or magic_mime_type == "audio/x-wav": 46 | return FileResponse(asset_file, media_type=magic_mime_type) 47 | else: 48 | raise HTTPException(status_code=400, detail="unsupported Audio Asset file format") 49 | -------------------------------------------------------------------------------- /app/routers/ota.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from logging import getLogger 4 | from pathlib import Path 5 | 6 | from fastapi import APIRouter, Depends, HTTPException, Query 7 | from fastapi.responses import FileResponse 8 | from pydantic import BaseModel, Field 9 | from requests import get 10 | 11 | from 
..const import DIR_OTA 12 | from ..internal.was import get_releases_willow, get_safe_path 13 | 14 | 15 | log = getLogger("WAS") 16 | router = APIRouter(prefix="/api") 17 | 18 | 19 | class GetOta(BaseModel): 20 | version: str = Field(Query(..., description='OTA Version')) 21 | platform: str = Field(Query(..., description='OTA Platform')) 22 | 23 | 24 | @router.get("/ota") 25 | async def api_get_ota(ota: GetOta = Depends()): 26 | log.debug('API GET OTA: Request') 27 | ota_dir = get_safe_path(DIR_OTA, os.path.join(DIR_OTA, ota.version)) 28 | ota_file = os.path.join(ota_dir, f"{ota.platform}.bin") 29 | ota_file = get_safe_path(ota_dir, ota_file) 30 | if not ota_file: 31 | return 32 | if not os.path.isfile(ota_file): 33 | releases = get_releases_willow() 34 | for release in releases: 35 | if release["name"] == ota.version: 36 | assets = release["assets"] 37 | for asset in assets: 38 | if asset["platform"] == ota.platform: 39 | Path(ota_dir).mkdir(parents=True, exist_ok=True) 40 | r = get(asset["browser_download_url"]) 41 | Path(ota_file).write_bytes(r.content) 42 | 43 | # If we still don't have the file return 404 - the platform and/or version doesn't exist 44 | if not os.path.isfile(ota_file): 45 | raise HTTPException(status_code=404, detail="OTA File Not Found") 46 | 47 | return FileResponse(ota_file) 48 | -------------------------------------------------------------------------------- /app/pytest/api/test_ota.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest 3 | 4 | from unittest.mock import MagicMock, patch 5 | 6 | from fastapi.testclient import TestClient 7 | 8 | from app.main import app 9 | from app.pytest.mock import mock_releases_willow 10 | 11 | client = TestClient(app) 12 | 13 | 14 | class TestOta(unittest.TestCase): 15 | def test_get_ota(self): 16 | mock_uri_bad = "/api/ota?platform=ESP32-S3-BOX-3&version=0.0.0-mock.0/../../.."
17 | mock_uri_good = "/api/ota?platform=ESP32-S3-BOX-3&version=0.0.0-mock.0" 18 | 19 | mock_response = MagicMock() 20 | mock_response.content = b"mocked data" 21 | 22 | with patch("app.routers.ota.get_releases_willow", return_value=mock_releases_willow): 23 | # patch os.path.isfile so we call get_releases_willow() 24 | # this will write mock_response to storage/ota/0.0.0-mock.0/ESP32-S3-BOX-3.bin 25 | with patch("os.path.isfile", return_value=False): 26 | with patch("app.routers.ota.get", return_value=mock_response): 27 | response = client.get(mock_uri_good) 28 | 29 | assert response.status_code == 404 30 | 31 | # os.path.isfile is not patched here so we serve the file content 32 | response = client.get(mock_uri_good) 33 | 34 | assert response.status_code == 200 35 | assert response.content == b"mocked data" 36 | 37 | # use bad URL that tries to read data outside OTA_DIR 38 | response = client.get(mock_uri_bad) 39 | 40 | json_response = json.loads(response.content) 41 | print(json_response) 42 | 43 | assert response.status_code == 400 44 | assert json_response['detail'].startswith("invalid asset path") 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() -------------------------------------------------------------------------------- /migrations/versions/8f14a11346c4_initial_schema.py: -------------------------------------------------------------------------------- 1 | """initial schema 2 | 3 | Revision ID: 8f14a11346c4 4 | Revises: 5 | Create Date: 2023-12-08 19:03:54.320856 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | import sqlmodel 13 | 14 | 15 | # revision identifiers, used by Alembic. 16 | revision: str = '8f14a11346c4' 17 | down_revision: Union[str, None] = None 18 | branch_labels: Union[str, Sequence[str], None] = None 19 | depends_on: Union[str, Sequence[str], None] = None 20 | 21 | 22 | def upgrade() -> None: 23 | # ### commands auto generated by Alembic - please adjust! ### 24 | op.create_table('willow_clients', 25 | sa.Column('id', sa.Integer(), nullable=False), 26 | sa.Column('mac_addr', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 27 | sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 28 | sa.PrimaryKeyConstraint('id'), 29 | sa.UniqueConstraint('mac_addr') 30 | ) 31 | op.create_table('willow_config', 32 | sa.Column('id', sa.Integer(), nullable=False), 33 | sa.Column('config_type', sa.Enum('config', 'multinet', 'nvs', name='willowconfigtype'), nullable=False), 34 | sa.Column('config_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 35 | sa.Column('config_namespace', sa.Enum('WAS', 'WIFI', name='willowconfignamespacetype'), nullable=True), 36 | sa.Column('config_value', sqlmodel.sql.sqltypes.AutoString(), nullable=True), 37 | sa.PrimaryKeyConstraint('id'), 38 | sa.UniqueConstraint('config_type', 'config_name') 39 | ) 40 | # ### end Alembic commands ### 41 | 42 | 43 | def downgrade() -> None: 44 | # ### commands auto generated by Alembic - please adjust! 
### 45 | op.drop_table('willow_config') 46 | op.drop_table('willow_clients') 47 | # ### end Alembic commands ### 48 | -------------------------------------------------------------------------------- /app/routers/release.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from logging import getLogger 4 | from pathlib import Path 5 | from typing import Literal 6 | 7 | from fastapi import APIRouter, Depends, HTTPException, Query, Request 8 | from fastapi.responses import JSONResponse 9 | from pydantic import BaseModel, Field 10 | from requests import get 11 | 12 | from ..const import DIR_OTA 13 | from ..internal.was import get_release_url, get_releases_willow, get_safe_path, get_was_url 14 | 15 | 16 | log = getLogger("WAS") 17 | router = APIRouter(prefix="/api") 18 | 19 | 20 | class GetRelease(BaseModel): 21 | type: Literal['was', 'willow'] = Field(Query(..., description='Release type')) 22 | 23 | 24 | @router.get("/release") 25 | async def api_get_release(release: GetRelease = Depends()): 26 | log.debug('API GET RELEASE: Request') 27 | releases = get_releases_willow() 28 | if release.type == "willow": 29 | return releases 30 | elif release.type == "was": 31 | was_url = get_was_url() 32 | if not was_url: 33 | raise HTTPException(status_code=500, detail="WAS URL not set") 34 | 35 | try: 36 | for release in releases: 37 | tag_name = release["tag_name"] 38 | assets = release["assets"] 39 | for asset in assets: 40 | platform = asset["platform"] 41 | asset["was_url"] = get_release_url(was_url, tag_name, platform) 42 | ota_path = os.path.join(DIR_OTA, tag_name, f"{platform}.bin") 43 | if os.path.isfile(ota_path): 44 | asset["cached"] = True 45 | else: 46 | asset["cached"] = False 47 | except Exception as e: 48 | log.error(e) 49 | pass 50 | 51 | return JSONResponse(content=releases) 52 | 53 | 54 | class PostRelease(BaseModel): 55 | action: Literal['delete'] = Field(Query(..., description='Release Cache Control')) 56 | 57 | 58 | @router.post("/release") 59 | async def api_post_release(request: Request, release: PostRelease = Depends()): 60 | log.debug('API POST RELEASE: Request') 61 | 62 | if release.action == "delete": 63 | data = await request.json() 64 | path = data['path'] 65 | path = get_safe_path(DIR_OTA, path) 66 | if path: 67 | os.remove(path) 68 | -------------------------------------------------------------------------------- /.github/workflows/build-and-publish.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: build_container 3 | 4 | env: 5 | GITHUB_TOKEN: ${{ github.token }} 6 | REGISTRY: ghcr.io 7 | IMAGE_NAME: ${{ github.repository }} 8 | 9 | on: 10 | pull_request: 11 | push: 12 | branches: 13 | - 'main' 14 | - 'release/**' 15 | tags: 16 | - '*' 17 | 18 | jobs: 19 | build_container: 20 | runs-on: ubuntu-22.04 21 | permissions: 22 | packages: write 23 | steps: 24 | - name: checkout 25 | uses: actions/checkout@v3 26 | with: 27 | fetch-depth: 0 28 | fetch-tags: true 29 | 30 | - name: configure QEMU 31 | uses: docker/setup-qemu-action@v2 32 | with: 33 | platforms: 'arm64' 34 | 35 | - name: configure buildx 36 | uses: docker/setup-buildx-action@v2 37 | 38 | - name: login to ghcr.io 39 | uses: docker/login-action@v2 40 | if: ${{ github.event_name != 'pull_request' }} 41 | with: 42 | registry: ${{ env.REGISTRY }} 43 | username: ${{ github.actor }} 44 | password: ${{ secrets.GITHUB_TOKEN }} 45 | 46 | - name: extract metadata 47 | id: metadata 48 | uses: docker/metadata-action@v4 49 |
with: 50 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 51 | tags: | 52 | type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/0.1.') }} 53 | type=raw,value=rc,enable=${{ contains(github.ref, '-rc.') }} 54 | type=semver,pattern={{version}} 55 | type=ref,event=branch 56 | type=raw,value=main,enable=true 57 | 58 | - name: extract tag name 59 | id: tag 60 | run: | 61 | tag=${{ fromJSON(steps.metadata.outputs.json).tags[0] }} 62 | echo "WAS_UI_TAG=${tag##*:}" >> $GITHUB_OUTPUT 63 | 64 | - name: git describe 65 | run: echo "WAS_VERSION=$(git describe --always --dirty --tags)" >> $GITHUB_OUTPUT 66 | id: gd 67 | 68 | - name: build container 69 | uses: docker/build-push-action@v4 70 | with: 71 | build-args: | 72 | WAS_UI_TAG=${{ steps.tag.outputs.WAS_UI_TAG }} 73 | WAS_VERSION=${{ steps.gd.outputs.WAS_VERSION }} 74 | context: . 75 | file: ./Dockerfile 76 | platforms: linux/amd64,linux/arm64 77 | push: ${{ github.event_name == 'push' && github.actor != 'dependabot[bot]' }} 78 | labels: ${{ steps.metadata.outputs.labels }} 79 | tags: ${{ steps.metadata.outputs.tags }} 80 | -------------------------------------------------------------------------------- /app/pytest/test_construct_wis_tts_url.py: -------------------------------------------------------------------------------- 1 | from app.internal.was import construct_wis_tts_url 2 | 3 | 4 | def test_construct_wis_tts_url(): 5 | expect = "http://wis.local/api/tts?text=" 6 | assert expect == construct_wis_tts_url("http://wis.local/api/tts") 7 | assert expect == construct_wis_tts_url("http://wis.local/api/tts?text") 8 | assert expect == construct_wis_tts_url("http://wis.local/api/tts?text=") 9 | 10 | expect = "http://wis.local/api/tts?bar=baz&text=" 11 | assert expect == construct_wis_tts_url("http://wis.local/api/tts?text&bar=baz") 12 | assert expect == construct_wis_tts_url("http://wis.local/api/tts?text=&bar=baz") 13 | assert expect == construct_wis_tts_url("http://wis.local/api/tts?text=foo&bar=baz") 14 | 15 | expect = "http://wis.local:19000/api/tts?text=" 16 | assert expect == construct_wis_tts_url("http://wis.local:19000/api/tts") 17 | assert expect == construct_wis_tts_url("http://wis.local:19000/api/tts?text") 18 | assert expect == construct_wis_tts_url("http://wis.local:19000/api/tts?text=") 19 | 20 | expect = "http://wis.local:19000/api/tts?bar=baz&text=" 21 | assert expect == construct_wis_tts_url("http://wis.local:19000/api/tts?text&bar=baz") 22 | assert expect == construct_wis_tts_url("http://wis.local:19000/api/tts?text=&bar=baz") 23 | assert expect == construct_wis_tts_url("http://wis.local:19000/api/tts?text=foo&bar=baz") 24 | 25 | expect = "http://user:pass@wis.local/api/tts?text=" 26 | assert expect == construct_wis_tts_url("http://user:pass@wis.local/api/tts") 27 | assert expect == construct_wis_tts_url("http://user:pass@wis.local/api/tts?text") 28 | assert expect == construct_wis_tts_url("http://user:pass@wis.local/api/tts?text=") 29 | 30 | expect = "http://user:pass@wis.local/api/tts?bar=baz&text=" 31 | assert expect == construct_wis_tts_url("http://user:pass@wis.local/api/tts?text&bar=baz") 32 | assert expect == construct_wis_tts_url("http://user:pass@wis.local/api/tts?text=&bar=baz") 33 | assert expect == construct_wis_tts_url("http://user:pass@wis.local/api/tts?text=foo&bar=baz") 34 | 35 | expect = "http://user:pass@wis.local:19000/api/tts?text=" 36 | assert expect == construct_wis_tts_url("http://user:pass@wis.local:19000/api/tts") 37 | assert expect == 
construct_wis_tts_url("http://user:pass@wis.local:19000/api/tts?text") 38 | assert expect == construct_wis_tts_url("http://user:pass@wis.local:19000/api/tts?text=") 39 | 40 | expect = "http://user:pass@wis.local:19000/api/tts?bar=baz&text=" 41 | assert expect == construct_wis_tts_url("http://user:pass@wis.local:19000/api/tts?text&bar=baz") 42 | assert expect == construct_wis_tts_url("http://user:pass@wis.local:19000/api/tts?text=&bar=baz") 43 | assert expect == construct_wis_tts_url("http://user:pass@wis.local:19000/api/tts?text=foo&bar=baz") 44 | -------------------------------------------------------------------------------- /app/internal/command_endpoints/main.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | from app.db.main import get_config_db 4 | from app.internal.command_endpoints.ha_ws import ( 5 | HomeAssistantWebSocketEndpoint, 6 | ) 7 | from app.internal.command_endpoints.mqtt import MqttConfig, MqttEndpoint 8 | from app.internal.command_endpoints.openhab import OpenhabEndpoint 9 | from app.internal.command_endpoints.rest import RestEndpoint 10 | 11 | 12 | log = getLogger("WAS") 13 | 14 | 15 | def init_command_endpoint(app): 16 | # call command_endpoint.stop() to avoid leaking asyncio task 17 | try: 18 | app.command_endpoint.stop() 19 | except Exception: 20 | pass 21 | 22 | user_config = get_config_db() 23 | 24 | if "was_mode" in user_config and user_config["was_mode"]: 25 | log.info("WAS Endpoint mode enabled") 26 | 27 | if user_config["command_endpoint"] == "Home Assistant": 28 | 29 | host = user_config["hass_host"] 30 | port = user_config["hass_port"] 31 | tls = user_config["hass_tls"] 32 | token = user_config["hass_token"] 33 | 34 | app.command_endpoint = HomeAssistantWebSocketEndpoint(app, host, port, tls, token) 35 | 36 | elif user_config["command_endpoint"] == "MQTT": 37 | mqtt_config = MqttConfig() 38 | mqtt_config.set_auth_type(user_config["mqtt_auth_type"]) 39 | mqtt_config.set_hostname(user_config["mqtt_host"]) 40 | mqtt_config.set_port(user_config["mqtt_port"]) 41 | mqtt_config.set_tls(user_config["mqtt_tls"]) 42 | mqtt_config.set_topic(user_config["mqtt_topic"]) 43 | 44 | if 'mqtt_password' in user_config: 45 | mqtt_config.set_password(user_config['mqtt_password']) 46 | 47 | if 'mqtt_username' in user_config: 48 | mqtt_config.set_username(user_config['mqtt_username']) 49 | 50 | app.command_endpoint = MqttEndpoint(mqtt_config) 51 | 52 | elif user_config["command_endpoint"] == "openHAB": 53 | app.command_endpoint = OpenhabEndpoint(user_config["openhab_url"], user_config["openhab_token"]) 54 | 55 | elif user_config["command_endpoint"] == "REST": 56 | app.command_endpoint = RestEndpoint(user_config["rest_url"]) 57 | 58 | if hasattr(user_config, "rest_auth_type"): 59 | app.command_endpoint.config.set_auth_type(user_config["rest_auth_type"]) 60 | 61 | if "rest_auth_header" in user_config: 62 | app.command_endpoint.config.set_auth_header(user_config["rest_auth_header"]) 63 | 64 | if "rest_auth_pass" in user_config: 65 | app.command_endpoint.config.set_auth_pass(user_config["rest_auth_pass"]) 66 | 67 | if "rest_auth_user" in user_config: 68 | app.command_endpoint.config.set_auth_user(user_config["rest_auth_user"]) 69 | -------------------------------------------------------------------------------- /migrations/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config 4 | from 
sqlalchemy import pool 5 | from sqlmodel import SQLModel 6 | 7 | from alembic import context 8 | 9 | from app.const import DB_URL 10 | from app.db.models import WillowClientTable 11 | from app.db.models import WillowConfigTable 12 | from app.settings import get_settings 13 | 14 | 15 | # this is the Alembic Config object, which provides 16 | # access to the values within the .ini file in use. 17 | config = context.config 18 | 19 | settings = get_settings() 20 | 21 | config.set_main_option('sqlalchemy.url', settings.db_url) 22 | 23 | # Interpret the config file for Python logging. 24 | # This line sets up loggers basically. 25 | # We must skip this when config has a logger attribute to prevent the WAS logger config from being overwritten 26 | if config.config_file_name is not None and config.attributes.get('logger', None) is None: 27 | fileConfig(config.config_file_name) 28 | 29 | # add your model's MetaData object here 30 | # for 'autogenerate' support 31 | # from myapp import mymodel 32 | # target_metadata = mymodel.Base.metadata 33 | target_metadata = SQLModel.metadata 34 | 35 | # other values from the config, defined by the needs of env.py, 36 | # can be acquired: 37 | # my_important_option = config.get_main_option("my_important_option") 38 | # ... etc. 39 | 40 | 41 | def run_migrations_offline() -> None: 42 | """Run migrations in 'offline' mode. 43 | 44 | This configures the context with just a URL 45 | and not an Engine, though an Engine is acceptable 46 | here as well. By skipping the Engine creation 47 | we don't even need a DBAPI to be available. 48 | 49 | Calls to context.execute() here emit the given string to the 50 | script output. 51 | 52 | """ 53 | url = config.get_main_option("sqlalchemy.url") 54 | context.configure( 55 | url=url, 56 | target_metadata=target_metadata, 57 | literal_binds=True, 58 | dialect_opts={"paramstyle": "named"}, 59 | ) 60 | 61 | with context.begin_transaction(): 62 | context.run_migrations() 63 | 64 | 65 | def run_migrations_online() -> None: 66 | """Run migrations in 'online' mode. 67 | 68 | In this scenario we need to create an Engine 69 | and associate a connection with the context. 70 | 71 | """ 72 | connectable = engine_from_config( 73 | config.get_section(config.config_ini_section, {}), 74 | prefix="sqlalchemy.", 75 | poolclass=pool.NullPool, 76 | ) 77 | 78 | with connectable.connect() as connection: 79 | context.configure( 80 | connection=connection, target_metadata=target_metadata 81 | ) 82 | 83 | with context.begin_transaction(): 84 | context.run_migrations() 85 | 86 | 87 | if context.is_offline_mode(): 88 | run_migrations_offline() 89 | else: 90 | run_migrations_online() 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Willow Application Server 2 | 3 | ## Get Started 4 | 5 | We have tried to simplify the onboarding process as much as possible. It is no longer required to build Willow yourself. 6 | All you have to do is run Willow Application Server and connect to it. From there, you will be guided to the Willow Web Flasher, which will download a Willow release image from Github, inject your Wi-Fi credentials and WAS URL into the NVS partition, and flash it to your device. 
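For reference, the NVS values involved are the ones captured in `default_nvs.json` in this repository — the WAS WebSocket URL plus your Wi-Fi credentials:

```
{
    "WAS": {
        "URL": "wss://was.local"
    },
    "WIFI": {
        "PSK": "mypassword",
        "SSID": "myssid"
    }
}
```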
7 | 8 | ### Running WAS 9 | 10 | ``` 11 | docker run --detach --name=willow-application-server --env TZ="Europe/Sofia" --pull=always --network=host --restart=unless-stopped --volume=was-storage:/app/storage ghcr.io/heywillow/willow-application-server 12 | ``` 13 | 14 | ### Building WAS 15 | ``` 16 | git clone https://github.com/HeyWillow/willow-application-server.git && cd willow-application-server 17 | 18 | ./utils.sh build 19 | ``` 20 | 21 | ### Start 22 | ```./utils.sh run``` 23 | 24 | ## Configure and Upgrade Willow Devices 25 | Visit ```http://my_was_host:8502``` in your browser. 26 | 27 | ## Upgrading "Over the Air" (OTA) 28 | 29 | OTA upgrades allow you to update Willow devices without having to re-connect them to your computer to flash. It's a very safe process with a great deal of verification, automatic rollbacks on upgrade failure, etc. 30 | 31 | We list published releases with OTA assets. Select the desired release and click the "Upgrade" button. If the release is not already cached in WAS, WAS will download the binary from Github and cache it, then instruct the target Willow device to start the upgrade from your running WAS instance. Alternatively, you can just upgrade the device from the clients page and WAS will cache it automatically on first request. This makes it possible to run Willow in an isolated VLAN without Internet access. 32 | 33 | ### Upgrade with your own Willow builds 34 | 35 | After building with Willow you can provide your build to WAS to upgrade your local devices using OTA. 36 | 37 | Make sure you select the appropriate target hardware from the "Audio HAL" section during build. Then run `./utils.sh build` as you normally would. 38 | 39 | To use your custom binary for OTA place `build/willow.bin` from your Willow build directory in the `ota/local` directory of the was-storage volume using the following filenames: 40 | 41 | * ESP32-S3-BOX-3.bin 42 | * ESP32-S3-BOX.bin 43 | * ESP32-S3-BOX-LITE.bin 44 | 45 | To copy the file to your WAS instance: 46 | ``` 47 | docker cp build/willow.bin willow-application-server:/app/storage/ota/local/ESP32-S3-BOX-3.bin 48 | ``` 49 | 50 | Your provided build will now be available as the "local" release under the various upgrade options available in the Willow Web UI. You can copy new builds and upgrade however you see fit as you do development and create new Willow builds. If you run into a boot loop, bad flash, etc you can always recover your device from the Willow Web Flasher or Willow build system and try again. 
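As a quick sanity check that WAS is serving your local build, you can request it through the OTA API (the URL shape matches what `app/pytest/test_get_release_url.py` asserts; substitute your own WAS host for `my_was_host`):

```
curl -I "http://my_was_host:8502/api/ota?version=local&platform=ESP32-S3-BOX-3"
```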
51 | -------------------------------------------------------------------------------- /app/internal/connmgr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from fastapi import ( 4 | WebSocket, 5 | WebSocketException, 6 | ) 7 | from pydantic import BaseModel, ConfigDict, FieldSerializationInfo, SerializerFunctionWrapHandler, field_serializer 8 | from typing import Dict 9 | 10 | from .client import Client 11 | 12 | 13 | log = logging.getLogger("WAS") 14 | 15 | 16 | class ConnMgr(BaseModel): 17 | model_config = ConfigDict(arbitrary_types_allowed=True) 18 | 19 | connected_clients: Dict[WebSocket, Client] = {} 20 | 21 | @field_serializer('connected_clients', mode='wrap') 22 | def serialize_connected_clients(self, value: Dict[WebSocket, Client], nxt: SerializerFunctionWrapHandler, info: FieldSerializationInfo) -> Dict[str, Client]: 23 | serialized_dict = {} 24 | for ws, client in value.items(): 25 | remote = f"{ws.client.host}:{ws.client.port}" 26 | serialized_dict[remote] = client 27 | return nxt(serialized_dict) 28 | 29 | async def accept(self, ws: WebSocket, client: Client): 30 | try: 31 | await ws.accept() 32 | self.connected_clients[ws] = client 33 | except WebSocketException as e: 34 | log.error(f"Failed to accept websocket connection: {e}") 35 | 36 | async def broadcast(self, msg: str): 37 | for client in self.connected_clients: 38 | try: 39 | await client.send_text(msg) 40 | except Exception as e: 41 | log.error(f"Failed to broadcast message: {e}") 42 | 43 | def disconnect(self, ws: WebSocket): 44 | if ws in self.connected_clients: 45 | self.connected_clients.pop(ws) 46 | 47 | def get_client_by_hostname(self, hostname): 48 | for k, v in self.connected_clients.items(): 49 | if v.hostname == hostname: 50 | return k 51 | 52 | def get_client_by_ws(self, ws): 53 | return self.connected_clients[ws] 54 | 55 | def get_mac_by_hostname(self, hostname): 56 | for k, v in self.connected_clients.items(): 57 | if v.hostname == hostname: 58 | return v.mac_addr 59 | 60 | return None 61 | 62 | def get_ws_by_mac(self, mac): 63 | for k, v in self.connected_clients.items(): 64 | # log.debug(f"get_ws_by_mac: {k} {v.mac_addr}") 65 | if v.mac_addr == mac: 66 | return k 67 | 68 | log.debug("get_ws_by_mac: returning None") 69 | return None 70 | 71 | def is_notification_active(self, ws): 72 | return self.connected_clients[ws].is_notification_active() 73 | 74 | def set_notification_active(self, ws, id): 75 | self.connected_clients[ws].set_notification_active(id) 76 | 77 | def update_client(self, ws, key, value): 78 | if key == "hostname": 79 | self.connected_clients[ws].set_hostname(value) 80 | elif key == "platform": 81 | self.connected_clients[ws].set_platform(value) 82 | elif key == "mac_addr": 83 | self.connected_clients[ws].set_mac_addr(value) 84 | -------------------------------------------------------------------------------- /app/internal/command_endpoints/rest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from . 
import ( 4 | CommandEndpoint, 5 | CommandEndpointConfigException, 6 | CommandEndpointResponse, 7 | CommandEndpointResult, 8 | CommandEndpointRuntimeException 9 | ) 10 | from enum import Enum 11 | from requests import request 12 | from requests.auth import HTTPBasicAuth 13 | 14 | 15 | class RestAuthType(Enum): 16 | NONE = 1 17 | BASIC = 2 18 | HEADER = 3 19 | 20 | 21 | class RestConfig(): 22 | auth_header: str = "" 23 | auth_pass: str = "" 24 | auth_type: RestAuthType = RestAuthType.NONE 25 | auth_user: str = "" 26 | 27 | log = logging.getLogger("WAS") 28 | 29 | def __init__(self, auth_type=RestAuthType.NONE, auth_header="", auth_pass="", auth_user=""): 30 | self.auth_header = auth_header 31 | self.auth_pass = auth_pass 32 | self.auth_type = auth_type 33 | self.auth_user = auth_user 34 | 35 | def set_auth_header(self, auth_header=""): 36 | self.log.debug(f"setting auth header: {auth_header}") 37 | self.auth_header = auth_header 38 | 39 | def set_auth_pass(self, auth_pass=""): 40 | self.log.debug(f"setting auth password: {auth_pass}") 41 | self.auth_pass = auth_pass 42 | 43 | def set_auth_type(self, auth_type=RestAuthType.NONE): 44 | self.log.debug(f"setting auth type: {auth_type}") 45 | self.auth_type = RestAuthType[auth_type.upper()] 46 | 47 | def set_auth_user(self, auth_user=""): 48 | self.log.debug(f"setting auth username: {auth_user}") 49 | self.auth_user = auth_user 50 | 51 | 52 | class RestEndpoint(CommandEndpoint): 53 | name = "REST" 54 | 55 | def __init__(self, url): 56 | self.config = RestConfig() 57 | self.url = url 58 | 59 | def parse_response(self, response): 60 | res = CommandEndpointResult() 61 | if response.ok: 62 | res.ok = True 63 | if len(response.text) > 0: 64 | res.speech = response.text 65 | else: 66 | res.speech = "Success!" 67 | 68 | command_endpoint_response = CommandEndpointResponse(result=res) 69 | return command_endpoint_response.model_dump_json() 70 | 71 | def send(self, data=None, jsondata=None, ws=None, client=None): 72 | try: 73 | basic = None 74 | headers = {} 75 | 76 | if jsondata is not None: 77 | headers['Content-Type'] = 'application/json' 78 | else: 79 | headers['Content-Type'] = 'text/plain' 80 | 81 | if self.config.auth_type == RestAuthType.BASIC: 82 | basic = HTTPBasicAuth(self.config.auth_user, self.config.auth_pass) 83 | elif self.config.auth_type == RestAuthType.HEADER: 84 | headers['Authorization'] = self.config.auth_header 85 | elif self.config.auth_type == RestAuthType.NONE: 86 | pass 87 | else: 88 | raise CommandEndpointConfigException("invalid REST auth type") 89 | 90 | return request("POST", self.url, auth=basic, data=data, headers=headers, json=jsondata, timeout=(1, 30)) 91 | 92 | except Exception as e: 93 | raise CommandEndpointRuntimeException(e) 94 | -------------------------------------------------------------------------------- /app/routers/config.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from re import sub 3 | from typing import Literal, Optional 4 | 5 | from fastapi import APIRouter, Depends, HTTPException, Query, Request 6 | from fastapi.responses import JSONResponse, PlainTextResponse 7 | from pydantic import BaseModel, Field 8 | from requests import get 9 | 10 | from app.db.main import get_config_db, get_nvs_db 11 | 12 | from ..const import URL_WILLOW_CONFIG 13 | from ..internal.command_endpoints.main import init_command_endpoint 14 | from ..internal.was import ( 15 | construct_url, 16 | get_multinet, 17 | get_tz_config, 18 | get_was_config, 19 | post_config, 20
| post_nvs, 21 | post_was, 22 | ) 23 | 24 | 25 | log = getLogger("WAS") 26 | router = APIRouter(prefix="/api") 27 | 28 | 29 | class GetConfig(BaseModel): 30 | type: Literal['config', 'nvs', 'ha_url', 'ha_token', 'multinet', 'was', 'tz'] = Field( 31 | Query(..., description='Configuration type') 32 | ) 33 | default: Optional[bool] = False 34 | 35 | 36 | @router.get("/config") 37 | async def api_get_config(config: GetConfig = Depends()): 38 | log.debug('API GET CONFIG: Request') 39 | # TZ is special 40 | if config.type == "tz": 41 | config = get_tz_config(refresh=config.default) 42 | return JSONResponse(content=config) 43 | 44 | # Otherwise handle other config types 45 | if config.default: 46 | default_config = get(f"{URL_WILLOW_CONFIG}?type={config.type}").json() 47 | if isinstance(default_config, dict): 48 | return default_config 49 | else: 50 | raise HTTPException(status_code=400, detail="Invalid default config") 51 | 52 | if config.type == "nvs": 53 | nvs = get_nvs_db() 54 | return JSONResponse(content=nvs) 55 | elif config.type == "config": 56 | config = get_config_db() 57 | if "wis_tts_url_v2" in config: 58 | config["wis_tts_url"] = sub("[&?]text=", "", config["wis_tts_url_v2"]) 59 | del config["wis_tts_url_v2"] 60 | return JSONResponse(content=config) 61 | elif config.type == "ha_token": 62 | config = get_config_db() 63 | return PlainTextResponse(config["hass_token"]) 64 | elif config.type == "ha_url": 65 | config = get_config_db() 66 | url = construct_url(config["hass_host"], config["hass_port"], config["hass_tls"]) 67 | return PlainTextResponse(url) 68 | elif config.type == "multinet": 69 | config = get_multinet() 70 | return JSONResponse(content=config) 71 | elif config.type == "was": 72 | config = get_was_config() 73 | return JSONResponse(content=config) 74 | 75 | 76 | class PostConfig(BaseModel): 77 | type: Literal['config', 'nvs', 'was'] = Field(Query(..., description='Configuration type')) 78 | apply: bool = Field(Query(..., description='Apply configuration to device')) 79 | 80 | 81 | @router.post("/config") 82 | async def api_post_config(request: Request, config: PostConfig = Depends()): 83 | log.debug('API POST CONFIG: Request') 84 | if config.type == "config": 85 | await post_config(request, config.apply) 86 | init_command_endpoint(request.app) 87 | elif config.type == "nvs": 88 | await post_nvs(request, config.apply) 89 | elif config.type == "was": 90 | await post_was(request, config.apply) 91 | -------------------------------------------------------------------------------- /app/internal/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Optional 3 | 4 | from pydantic import BaseModel, ConfigDict 5 | 6 | 7 | class WillowAudioCodec(str, Enum): 8 | AMR_WB = 'AMR-WB' 9 | PCM = 'PCM' 10 | 11 | 12 | class WillowAudioResponseType(str, Enum): 13 | Chimes = 'Chimes' 14 | none = 'None' 15 | TTS = 'TTS' 16 | 17 | 18 | class WillowCommandEndpoint(str, Enum): 19 | HomeAssistant = 'Home Assistant' 20 | openHAB = 'openHAB' 21 | MQTT = 'MQTT' 22 | REST = 'REST' 23 | 24 | 25 | class WillowMqttAuthType(str, Enum): 26 | none = 'none' 27 | userpw = 'userpw' 28 | 29 | 30 | class WillowNtpConfig(str, Enum): 31 | Host = 'Host' 32 | DHCP = 'DHCP' 33 | 34 | 35 | class WillowRestAuthType(str, Enum): 36 | none = 'None' 37 | Basic = 'Basic' 38 | Header = 'Header' 39 | 40 | 41 | class WillowSpeechRecMode(str, Enum): 42 | WIS = 'WIS' 43 | 44 | 45 | class WillowWakeMode(str, Enum): 46 | _1CH_90 = '1CH_90' 47 | 
_1CH_95 = '1CH_95' 48 | _2CH_90 = '2CH_90' 49 | _2CH_95 = '2CH_95' 50 | _3CH_90 = '3CH_90' 51 | _3CH_95 = '3CH_95' 52 | 53 | 54 | class WillowWakeWord(str, Enum): 55 | alexa = 'alexa' 56 | hiesp = 'hiesp' 57 | hilexin = 'hilexin' 58 | 59 | 60 | class WillowConfig(BaseModel, validate_assignment=True): 61 | aec: bool = None 62 | audio_codec: WillowAudioCodec = None 63 | audio_response_type: WillowAudioResponseType = None 64 | bss: bool = None 65 | command_endpoint: WillowCommandEndpoint = None 66 | display_timeout: int = None 67 | hass_host: Optional[str] = None 68 | hass_port: Optional[int] = None 69 | hass_tls: Optional[bool] = None 70 | hass_token: Optional[str] = None 71 | lcd_brightness: int = None 72 | mic_gain: int = None 73 | mqtt_auth_type: Optional[WillowMqttAuthType] = None 74 | mqtt_host: Optional[str] = None 75 | mqtt_password: Optional[str] = None 76 | mqtt_port: Optional[int] = None 77 | mqtt_tls: Optional[bool] = None 78 | mqtt_topic: Optional[str] = None 79 | mqtt_username: Optional[str] = None 80 | multiwake: bool = None 81 | ntp_config: WillowNtpConfig = None 82 | ntp_host: Optional[str] = None 83 | openhab_token: Optional[str] = None 84 | openhab_url: Optional[str] = None 85 | record_buffer: int = None 86 | rest_auth_header: Optional[str] = None 87 | rest_auth_pass: Optional[str] = None 88 | rest_auth_type: Optional[WillowRestAuthType] = None 89 | rest_auth_user: Optional[str] = None 90 | rest_url: Optional[str] = None 91 | show_prereleases: bool = None 92 | speaker_volume: int = None 93 | speech_rec_mode: Optional[WillowSpeechRecMode] = None 94 | stream_timeout: int = None 95 | timezone: str = None 96 | timezone_name: str = None 97 | vad_mode: int = None 98 | vad_timeout: int = None 99 | wake_confirmation: bool = None 100 | wake_mode: WillowWakeMode = None 101 | wake_word: WillowWakeWord = None 102 | was_mode: bool = None 103 | wis_tts_url: Optional[str] = None 104 | wis_tts_url_v2: Optional[str] = None 105 | wis_url: str = None 106 | 107 | # use Enum strings instead of e.g. 
WillowAudioCodec.PCM 108 | model_config = ConfigDict(use_enum_values=True) 109 | 110 | 111 | class WillowNvsWas(BaseModel): 112 | URL: str = None 113 | 114 | 115 | class WillowNvsWifi(BaseModel): 116 | PSK: str = None 117 | SSID: str = None 118 | 119 | 120 | class WillowNvsConfig(BaseModel): 121 | WAS: WillowNvsWas = None 122 | WIFI: WillowNvsWifi = None 123 | -------------------------------------------------------------------------------- /app/routers/client.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from logging import getLogger 4 | from typing import Literal 5 | 6 | from fastapi import APIRouter, Depends, Query, Request 7 | from fastapi.responses import JSONResponse 8 | from pydantic import BaseModel, Field 9 | 10 | from app.db.main import get_devices_db, save_client_config_to_db 11 | 12 | from ..internal.was import device_command, warm_tts 13 | 14 | 15 | log = getLogger("WAS") 16 | router = APIRouter( 17 | prefix="/api", 18 | ) 19 | 20 | 21 | @router.get("/client") 22 | async def api_get_client(request: Request): 23 | log.debug('API GET CLIENT: Request') 24 | devices = get_devices_db() 25 | clients = [] 26 | macs = [] 27 | labels = {} 28 | 29 | # This is ugly but it provides a combined response 30 | for ws, client in request.app.connmgr.connected_clients.items(): 31 | if client.mac_addr not in macs: 32 | labels.update({client.mac_addr: None}) 33 | for device in devices: 34 | if device["mac_addr"] == client.mac_addr: 35 | if device["label"]: 36 | labels.update({client.mac_addr: device["label"]}) 37 | version = client.ua.replace("Willow/", "") 38 | clients.append({ 39 | 'hostname': client.hostname, 40 | 'platform': client.platform, 41 | 'mac_addr': client.mac_addr, 42 | 'ip': ws.client.host, 43 | 'port': ws.client.port, 44 | 'version': version, 45 | 'label': labels[client.mac_addr] 46 | }) 47 | macs.append(client.mac_addr) 48 | 49 | # Sort connected clients by label if we have it 50 | # If all devices don't have labels we fall back to sorting by hostname 51 | try: 52 | sorted_clients = sorted(clients, key=lambda x: x['label']) 53 | except Exception: 54 | sorted_clients = sorted(clients, key=lambda x: x['hostname']) 55 | 56 | return JSONResponse(content=sorted_clients) 57 | 58 | 59 | class PostClient(BaseModel): 60 | action: Literal['restart', 'update', 'config', 'identify', 'notify'] = Field( 61 | Query(..., description='Client action') 62 | ) 63 | 64 | 65 | @router.post("/client") 66 | async def api_post_client(request: Request, device: PostClient = Depends()): 67 | log.debug('API POST CLIENT: Request') 68 | data = await request.json() 69 | 70 | if device.action == "update": 71 | msg = json.dumps({'cmd': 'ota_start', 'ota_url': data["ota_url"]}) 72 | try: 73 | ws = request.app.connmgr.get_client_by_hostname(data["hostname"]) 74 | await ws.send_text(msg) 75 | except Exception as e: 76 | log.error(f"Failed to trigger OTA ({e})") 77 | finally: 78 | return 79 | elif device.action == "config": 80 | devices = get_devices_db() 81 | new = True 82 | 83 | for i, device in enumerate(devices): 84 | if device.get("mac_addr") == data['mac_addr']: 85 | new = False 86 | devices[i] = data 87 | break 88 | 89 | if new and len(data['mac_addr']) > 0: 90 | devices.append(data) 91 | 92 | save_client_config_to_db(devices) 93 | elif device.action == 'notify': 94 | log.debug(f"received notify command on API: {data}") 95 | warm_tts(data["data"]) 96 | request.app.notify_queue.add(data) 97 | else: 98 | # Catch all assuming anything else is a device 
command 99 | return await device_command(request.app.connmgr, data, device.action) 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # WAS 163 | static/ 164 | storage 165 | 166 | # Web UI working directory and others 167 | work 168 | -------------------------------------------------------------------------------- /utils.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | WAS_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 4 | cd "$WAS_DIR" 5 | 6 | # Test for local environment file and use any overrides 7 | if [ -r .env ]; then 8 | echo "Using configuration overrides from .env file" 9 | . .env 10 | else 11 | echo "Using default configuration values" 12 | touch .env 13 | fi 14 | 15 | # Import (source) the .env file with auto-export enabled 16 | set -a 17 | source .env 18 | 19 | # Which docker image to run 20 | IMAGE=${IMAGE:-willow-application-server} 21 | 22 | # UI Listen port 23 | UI_LISTEN_PORT=${UI_LISTEN_PORT:-8501} 24 | 25 | # API Listen Port 26 | API_LISTEN_PORT=${API_LISTEN_PORT:-8502} 27 | 28 | # Log level - acceptable values are debug, info, warning, error, critical. Suggest info or debug. 29 | LOG_LEVEL=${LOG_LEVEL:-info} 30 | 31 | # Listen IP 32 | LISTEN_IP=${LISTEN_IP:-0.0.0.0} 33 | 34 | TAG=${TAG:-latest} 35 | 36 | # Torture delay 37 | TORTURE_DELAY=${TORTURE_DELAY:-300} 38 | 39 | # Web ui branch 40 | WEB_UI_BRANCH="main" 41 | 42 | # Local working directory for web ui 43 | WEB_UI_DIR="willow-application-server-ui" 44 | 45 | # Web ui URL 46 | WEB_UI_URL="https://github.com/HeyWillow/willow-application-server-ui.git" 47 | 48 | # Reachable WAS IP for the "default" interface 49 | WAS_IP=$(ip route get 1.1.1.1 | grep -oP 'src \K\S+') 50 | 51 | # Get WAS version 52 | export WAS_VERSION=$(git describe --always --dirty --tags) 53 | 54 | set +a 55 | 56 | if [ -z "$WAS_IP" ]; then 57 | echo "Could not determine WAS IP address - you will need to add it to .env" 58 | exit 1 59 | else 60 | echo "WAS Web UI URL is http://$WAS_IP:$API_LISTEN_PORT" 61 | fi 62 | 63 | 64 | dep_check() { 65 | return 66 | } 67 | 68 | 69 | freeze_requirements() { 70 | if [ ! -f /.dockerenv ]; then 71 | echo "This script is meant to be run inside the container - exiting" 72 | exit 1 73 | fi 74 | 75 | # Freeze 76 | pip freeze > requirements.txt 77 | } 78 | 79 | build-docker() { 80 | docker build --build-arg "WAS_VERSION=$WAS_VERSION" -t "$IMAGE":"$TAG" .
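    # Usage sketch (hypothetical values, not defaults beyond those set above):
    #   IMAGE=my-was TAG=dev ./utils.sh build-docker
    # IMAGE and TAG fall back to willow-application-server:latest when unset.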
81 | } 82 | 83 | build-web-ui() { 84 | mkdir -p "$WAS_DIR"/work 85 | cd "$WAS_DIR"/work 86 | if [ -d "$WEB_UI_DIR/node_modules" ]; then 87 | echo "Existing web ui working dir found, we need sudo to remove it because of docker" 88 | sudo rm -rf willow-application-server-ui 89 | fi 90 | git clone "$WEB_UI_URL" 91 | cd willow-application-server-ui 92 | git checkout "$WEB_UI_BRANCH" 93 | ./utils.sh build-docker 94 | ./utils.sh install 95 | # WAS_DIR is already set 96 | export WAS_DIR 97 | ./utils.sh build 98 | } 99 | 100 | shell() { 101 | docker run -it -v "$WAS_DIR":/app -v "$WAS_DIR"/cache:/root/.cache -v willow-application-server_was-storage:/app/storage "$IMAGE":"$TAG" \ 102 | /usr/bin/env bash 103 | } 104 | 105 | case $1 in 106 | 107 | build-docker|build) 108 | build-docker 109 | ;; 110 | 111 | build-web-ui) 112 | build-web-ui 113 | ;; 114 | 115 | freeze-requirements) 116 | freeze_requirements 117 | ;; 118 | 119 | start|run|up) 120 | dep_check 121 | shift 122 | docker compose up --remove-orphans "$@" 123 | ;; 124 | 125 | stop|down) 126 | dep_check 127 | shift 128 | docker compose down "$@" 129 | ;; 130 | 131 | shell|docker) 132 | shell 133 | ;; 134 | 135 | test) 136 | dep_check 137 | docker run --rm -it --env PYTHONPATH=/app --volume="${WAS_DIR}:/app" "$IMAGE":"$TAG" pytest 138 | ;; 139 | 140 | torture) 141 | echo "Starting WAS device torture test" 142 | docker compose down 143 | while true; do 144 | docker compose up -d 145 | echo "Sleeping for $TORTURE_DELAY" 146 | sleep $TORTURE_DELAY 147 | docker compose down 148 | echo "Sleeping for $TORTURE_DELAY" 149 | sleep $TORTURE_DELAY 150 | done 151 | ;; 152 | 153 | *) 154 | dep_check 155 | echo "Passing unknown argument directly to docker compose" 156 | docker compose "$@" 157 | ;; 158 | 159 | esac 160 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = migrations/ 6 | 7 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s 8 | # Uncomment the line below if you want the files to be prepended with date and time 9 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file 10 | # for all available tokens 11 | # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s 12 | 13 | # sys.path path, will be prepended to sys.path if present. 14 | # defaults to the current working directory. 15 | prepend_sys_path = . 16 | 17 | # timezone to use when rendering the date within the migration file 18 | # as well as the filename. 19 | # If specified, requires the python-dateutil library that can be 20 | # installed by adding `alembic[tz]` to the pip requirements 21 | # string value is passed to dateutil.tz.gettz() 22 | # leave blank for localtime 23 | # timezone = 24 | 25 | # max length of characters to apply to the 26 | # "slug" field 27 | # truncate_slug_length = 40 28 | 29 | # set to 'true' to run the environment during 30 | # the 'revision' command, regardless of autogenerate 31 | # revision_environment = false 32 | 33 | # set to 'true' to allow .pyc and .pyo files without 34 | # a source .py file to be detected as revisions in the 35 | # versions/ directory 36 | # sourceless = false 37 | 38 | # version location specification; This defaults 39 | # to migrations//versions.
When using multiple version 40 | # directories, initial revisions must be specified with --version-path. 41 | # The path separator used here should be the separator specified by "version_path_separator" below. 42 | # version_locations = %(here)s/bar:%(here)s/bat:migrations//versions 43 | 44 | # version path separator; As mentioned above, this is the character used to split 45 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 46 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 47 | # Valid values for version_path_separator are: 48 | # 49 | # version_path_separator = : 50 | # version_path_separator = ; 51 | # version_path_separator = space 52 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 53 | 54 | # set to 'true' to search source files recursively 55 | # in each "version_locations" directory 56 | # new in Alembic version 1.10 57 | # recursive_version_locations = false 58 | 59 | # the output encoding used when revision files 60 | # are written from script.py.mako 61 | # output_encoding = utf-8 62 | 63 | # sqlalchemy.url = sqlite:///storage/was.db 64 | 65 | 66 | [post_write_hooks] 67 | # post_write_hooks defines scripts or Python functions that are run 68 | # on newly generated revision scripts. See the documentation for further 69 | # detail and examples 70 | 71 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 72 | # hooks = black 73 | # black.type = console_scripts 74 | # black.entrypoint = black 75 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 76 | 77 | # lint with attempts to fix using "ruff" - use the exec runner, execute a binary 78 | # hooks = ruff 79 | # ruff.type = exec 80 | # ruff.executable = %(here)s/.venv/bin/ruff 81 | # ruff.options = --fix REVISION_SCRIPT_FILENAME 82 | 83 | # Logging configuration 84 | [loggers] 85 | keys = root,sqlalchemy,alembic 86 | 87 | [handlers] 88 | keys = console 89 | 90 | [formatters] 91 | keys = generic 92 | 93 | [logger_root] 94 | level = WARN 95 | handlers = console 96 | qualname = 97 | 98 | [logger_sqlalchemy] 99 | level = WARN 100 | handlers = 101 | qualname = sqlalchemy.engine 102 | 103 | [logger_alembic] 104 | level = INFO 105 | handlers = 106 | qualname = alembic 107 | 108 | [handler_console] 109 | class = StreamHandler 110 | args = (sys.stderr,) 111 | level = NOTSET 112 | formatter = generic 113 | 114 | [formatter_generic] 115 | format = %(levelname)-5.5s [%(name)s] %(message)s 116 | datefmt = %H:%M:%S 117 | -------------------------------------------------------------------------------- /app/internal/command_endpoints/mqtt.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import paho.mqtt.client as mqtt 5 | from . 
import ( 6 | CommandEndpoint, 7 | CommandEndpointConfigException, 8 | CommandEndpointResponse, 9 | CommandEndpointResult, 10 | CommandEndpointRuntimeException 11 | ) 12 | from enum import Enum 13 | 14 | 15 | class MqttAuthType(Enum): 16 | NONE = 1 17 | USERPW = 2 18 | 19 | 20 | class MqttConfig: 21 | auth_type: Enum = MqttAuthType.NONE 22 | hostname: str = None 23 | password: str = None 24 | port: int = 8883 25 | tls: bool = True 26 | topic: str = None 27 | username: str = None 28 | 29 | log = logging.getLogger("WAS") 30 | 31 | def set_auth_type(self, auth_type=MqttAuthType.NONE): 32 | self.log.debug(f"setting auth type: {auth_type}") 33 | self.auth_type = MqttAuthType[auth_type.upper()] 34 | 35 | def set_hostname(self, hostname=None): 36 | self.hostname = hostname 37 | 38 | def set_password(self, password=None): 39 | self.password = password 40 | 41 | def set_port(self, port=8883): 42 | self.port = port 43 | 44 | def set_tls(self, tls=True): 45 | self.tls = tls 46 | 47 | def set_topic(self, topic=None): 48 | self.topic = topic 49 | 50 | def set_username(self, username=None): 51 | self.username = username 52 | 53 | def validate(self): 54 | if self.auth_type == MqttAuthType.USERPW: 55 | if self.password is None: 56 | raise CommandEndpointConfigException("User/Password auth enabled without password") 57 | if self.username is None: 58 | raise CommandEndpointConfigException("User/Password auth enabled without username") 59 | 60 | 61 | class MqttEndpoint(CommandEndpoint): 62 | name = "MQTT" 63 | 64 | def __init__(self, config): 65 | self.config = config 66 | self.config.validate() 67 | self.connected = False 68 | self.mqtt_client = None 69 | 70 | loop = asyncio.get_event_loop() 71 | self.task = loop.create_task(self.connect()) 72 | 73 | async def connect(self): 74 | try: 75 | self.mqtt_client = mqtt.Client() 76 | self.mqtt_client.on_connect = self.cb_connect 77 | self.mqtt_client.on_disconnect = self.cb_disconnect 78 | self.mqtt_client.on_message = self.cb_msg 79 | if self.config.username is not None and self.config.password is not None: 80 | self.mqtt_client.username_pw_set(self.config.username, self.config.password) 81 | if self.config.tls: 82 | self.mqtt_client.tls_set() 83 | self.mqtt_client.connect_async(self.config.hostname, self.config.port, 60) 84 | self.mqtt_client.loop_start() 85 | except Exception as e: 86 | self.log.info(f"{self.name}: exception occurred: {e}") 87 | await asyncio.sleep(1) 88 | 89 | def cb_connect(self, client, userdata, flags, rc): 90 | self.connected = True 91 | self.log.info("MQTT connected") 92 | client.subscribe(self.config.topic) 93 | 94 | def cb_disconnect(self, client, userdata, rc): 95 | self.connected = False 96 | self.log.info("MQTT disconnected") 97 | 98 | def cb_msg(self, client, userdata, msg): 99 | self.log.info(f"cb_msg: topic={msg.topic} payload={msg.payload}") 100 | 101 | def parse_response(self, response): 102 | res = CommandEndpointResult() 103 | if response.ok: 104 | res.ok = True 105 | if len(response.text) > 0: 106 | res.speech = response.text 107 | else: 108 | res.speech = "Success!"
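        # If the response body is empty we still return some speech; the
        # "Success!" fallback mirrors the REST endpoint's parse_response().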
109 | 110 | command_endpoint_response = CommandEndpointResponse(result=res) 111 | return command_endpoint_response.model_dump_json() 112 | 113 | def send(self, data=None, jsondata=None, ws=None, client=None): 114 | if not self.connected: 115 | raise CommandEndpointRuntimeException(f"{self.name} not connected") 116 | try: 117 | if jsondata is not None: 118 | self.mqtt_client.publish(self.config.topic, payload=json.dumps(jsondata)) 119 | else: 120 | self.mqtt_client.publish(self.config.topic, payload=data) 121 | except Exception as e: 122 | raise CommandEndpointRuntimeException(e) 123 | -------------------------------------------------------------------------------- /app/internal/notify.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import time 4 | 5 | from logging import getLogger 6 | from pydantic import BaseModel, ConfigDict, Field 7 | from typing import Annotated, Dict, List, Optional 8 | 9 | from .connmgr import ConnMgr 10 | 11 | 12 | log = getLogger("WAS") 13 | 14 | 15 | class NotifyData(BaseModel): 16 | model_config = ConfigDict(extra="forbid") 17 | audio_url: Optional[str] = None 18 | backlight: bool = False 19 | backlight_max: bool = False 20 | cancel: bool = False 21 | id: int = -1 22 | repeat: int = 1 23 | strobe_period_ms: Optional[int] = 0 24 | text: Optional[str] = None 25 | volume: Optional[Annotated[int, Field(ge=0, le=100)]] = None 26 | 27 | 28 | class NotifyMsg(BaseModel): 29 | model_config = ConfigDict(extra="forbid") 30 | 31 | cmd: str = "notify" 32 | data: NotifyData 33 | hostname: Optional[str] = None 34 | 35 | 36 | class NotifyQueue(BaseModel): 37 | model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") 38 | 39 | connmgr: ConnMgr = None 40 | notifications: Dict[str, List[NotifyData]] = {} 41 | task: asyncio.Task = None 42 | 43 | def start(self): 44 | loop = asyncio.get_event_loop() 45 | self.task = loop.create_task(self.dequeue()) 46 | 47 | def add(self, msg): 48 | msg = NotifyMsg.model_validate_json(json.dumps(msg)) 49 | if not hasattr(msg.data, "id") or msg.data.id < 0: 50 | msg.data.id = int(time.time() * 1000) 51 | 52 | log.debug(msg) 53 | 54 | if msg.hostname is not None: 55 | mac_addr = self.connmgr.get_mac_by_hostname(msg.hostname) 56 | if mac_addr == "unknown": 57 | log.warning(f"no MAC address found for {msg.hostname}, skipping notification") 58 | return 59 | if mac_addr in self.notifications: 60 | self.notifications[mac_addr].append(msg.data) 61 | else: 62 | self.notifications.update({mac_addr: [msg.data]}) 63 | 64 | else: 65 | for _, client in self.connmgr.connected_clients.items(): 66 | if client.mac_addr == "unknown": 67 | log.warning(f"no MAC address found for {client.hostname}, skipping") 68 | continue 69 | if client.mac_addr in self.notifications: 70 | self.notifications[client.mac_addr].append(msg.data) 71 | else: 72 | self.notifications.update({client.mac_addr: [msg.data]}) 73 | 74 | def done(self, ws, id): 75 | client = self.connmgr.get_client_by_ws(ws) 76 | for i, notification in enumerate(self.notifications[client.mac_addr]): 77 | if notification.id == id: 78 | self.connmgr.set_notification_active(ws, 0) 79 | self.notifications[client.mac_addr].pop(i) 80 | break 81 | 82 | data = NotifyData(id=id, cancel=True) 83 | # explicitly set cmd so we can use exclude_unset 84 | msg_cancel = NotifyMsg(cmd="notify", data=data) 85 | log.info(msg_cancel) 86 | asyncio.ensure_future(self.connmgr.broadcast(msg_cancel.model_dump_json(exclude_unset=True))) 87 | 88 | async
def dequeue(self): 89 | while True: 90 | try: 91 | for mac_addr, notifications in self.notifications.items(): 92 | # log.debug(f"dequeueing notifications for {mac_addr}: {notifications} (len={len(notifications)})") 93 | if len(notifications) > 0: 94 | ws = self.connmgr.get_ws_by_mac(mac_addr) 95 | if ws is None: 96 | continue 97 | if self.connmgr.is_notification_active(ws): 98 | log.debug(f"{mac_addr} has active notification") 99 | continue 100 | 101 | for i, notification in enumerate(notifications): 102 | if notification.id > int(time.time() * 1000): 103 | continue 104 | elif notification.id < int((time.time() - 3600) * 1000): 105 | # TODO should we make this configurable ? 106 | # or at least use a constant and reject notifications with old ID in the API 107 | log.warning("expiring notification older than 1h") 108 | notifications.pop(i) 109 | continue  # don't send a notification we just expired 110 | 111 | self.connmgr.set_notification_active(ws, notification.id) 112 | log.debug(f"dequeueing notification for {mac_addr}: {notification}") 113 | msg = NotifyMsg(data=notification) 114 | asyncio.ensure_future(ws.send_text( 115 | msg.model_dump_json(exclude={'hostname'}, exclude_none=True)) 116 | ) 117 | # don't send more than one notification at once 118 | break 119 | except Exception as e: 120 | log.debug(f"exception during dequeue: {e}") 121 | 122 | await asyncio.sleep(1) 123 | -------------------------------------------------------------------------------- /app/internal/command_endpoints/ha_ws.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import requests 4 | import time 5 | import websockets 6 | from . import ( 7 | CommandEndpoint, 8 | CommandEndpointResponse, 9 | CommandEndpointResult, 10 | CommandEndpointRuntimeException, 11 | ) 12 | 13 | 14 | class HomeAssistantWebSocketEndpoint(CommandEndpoint): 15 | name = "WAS Home Assistant WebSocket Endpoint" 16 | 17 | connmap = {} 18 | 19 | def __init__(self, app, host, port, tls, token): 20 | 21 | self.app = app 22 | self.host = host 23 | self.port = port 24 | self.token = token 25 | self.tls = tls 26 | self.url = self.construct_url(ws=True) 27 | 28 | self.ha_willow_devices = {} 29 | self.ha_willow_devices_request_id = None 30 | self.haws = None 31 | 32 | loop = asyncio.get_event_loop() 33 | self.task = loop.create_task(self.connect()) 34 | 35 | def construct_url(self, ws): 36 | ha_url_scheme = "" 37 | if ws: 38 | ha_url_scheme = "wss://" if self.tls else "ws://" 39 | else: 40 | ha_url_scheme = "https://" if self.tls else "http://" 41 | 42 | return f"{ha_url_scheme}{self.host}:{self.port}" 43 | 44 | async def connect(self): 45 | while True: 46 | try: 47 | # deflate compression is enabled by default, making tcpdump difficult 48 | async with websockets.connect(f"{self.url}/api/websocket", compression=None) as self.haws: 49 | while True: 50 | msg = await self.haws.recv() 51 | await self.cb_msg(msg) 52 | except Exception as e: 53 | self.log.info(f"{self.name}: exception occurred: {e}") 54 | await asyncio.sleep(1) 55 | 56 | async def cb_msg(self, msg): 57 | self.log.debug(f"haws_cb: {self.app} {msg}") 58 | msg = json.loads(msg) 59 | if "type" in msg: 60 | if msg["type"] == "event": 61 | if msg["event"]["type"] == "intent-end": 62 | id = int(msg["id"]) 63 | ws = self.connmap[id] 64 | out = CommandEndpointResult() 65 | response_type = msg["event"]["data"]["intent_output"]["response"]["response_type"] 66 | if response_type == "action_done": 67 | out.ok = True 68 | response = msg["event"]["data"]["intent_output"]["response"] 69 | # Not
all intents return speech (e.g. HassNeverMind) 70 | if 'plain' in response["speech"]: 71 | out.speech = response["speech"]["plain"]["speech"] 72 | else: 73 | out.speech = "" 74 | command_endpoint_response = CommandEndpointResponse(result=out) 75 | self.log.debug(f"sending {command_endpoint_response} to {ws}") 76 | asyncio.ensure_future(ws.send_text(command_endpoint_response.model_dump_json())) 77 | self.connmap.pop(id) 78 | elif msg["type"] == "auth_required": 79 | auth_msg = { 80 | "type": "auth", 81 | "access_token": self.token, 82 | } 83 | self.log.debug(f"authenticating HA WebSocket connection: {auth_msg}") 84 | await self.haws.send(json.dumps(auth_msg)) 85 | elif msg["type"] == "auth_ok": 86 | self.ha_willow_devices_request_id = self.next_id() 87 | msg = { 88 | "type": "config/device_registry/list", 89 | "id": self.ha_willow_devices_request_id 90 | } 91 | self.log.debug(f"fetching devices: {msg}") 92 | await self.haws.send(json.dumps(msg)) 93 | elif msg["type"] == "result" and msg["success"]: 94 | if msg["id"] == self.ha_willow_devices_request_id: 95 | devices = msg["result"] 96 | self.ha_willow_devices = { 97 | ident[1]: item["id"] 98 | for item in devices 99 | for ident in item.get("identifiers", []) 100 | if ident[0] == "willow" 101 | } 102 | self.log.debug(f"received willow devices: {self.ha_willow_devices}") 103 | 104 | def parse_response(self, response): 105 | return None 106 | 107 | def next_id(self): 108 | return int(time.monotonic_ns()) 109 | 110 | def send(self, jsondata, ws, client=None): 111 | id = self.next_id() 112 | 113 | if id not in self.connmap: 114 | self.connmap[id] = ws 115 | 116 | if "language" in jsondata: 117 | jsondata.pop("language") 118 | 119 | out = { 120 | 'end_stage': 'intent', 121 | 'id': id, 122 | 'input': jsondata, 123 | 'start_stage': 'intent', 124 | 'type': 'assist_pipeline/run', 125 | } 126 | 127 | if client.mac_addr in self.ha_willow_devices: 128 | self.log.info("HA has a registered device for this willow satellite") 129 | out["device_id"] = self.ha_willow_devices[client.mac_addr] 130 | 131 | self.log.debug(f"sending to HA WS: {out}") 132 | asyncio.ensure_future(self.haws.send(json.dumps(out))) 133 | 134 | def stop(self): 135 | self.log.info(f"stopping {self.name}") 136 | self.task.cancel() 137 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | 5 | import alembic 6 | import alembic.config 7 | from fastapi import ( 8 | FastAPI, 9 | Header, 10 | WebSocket, 11 | WebSocketDisconnect, 12 | ) 13 | from contextlib import asynccontextmanager 14 | from fastapi.responses import RedirectResponse 15 | from fastapi.staticfiles import StaticFiles 16 | import logging 17 | from pathlib import Path 18 | from shutil import move 19 | from typing import Annotated 20 | from websockets.exceptions import ConnectionClosed 21 | from fastapi.middleware.cors import CORSMiddleware 22 | 23 | from app.const import ( 24 | ALEMBIC_CONFIG, 25 | DIR_OTA, 26 | STORAGE_USER_CLIENT_CONFIG, 27 | STORAGE_USER_CONFIG, 28 | STORAGE_USER_NVS, 29 | ) 30 | 31 | from app.db.main import get_config_db, migrate_user_client_config, migrate_user_config, migrate_user_nvs 32 | from app.internal.command_endpoints import ( 33 | CommandEndpointResponse, 34 | CommandEndpointResult, 35 | CommandEndpointRuntimeException 36 | ) 37 | from app.internal.command_endpoints.main import init_command_endpoint 38 | from app.internal.was
import ( 39 | build_msg, 40 | get_config, 41 | get_devices, 42 | get_nvs, 43 | get_tz_config, 44 | ) 45 | from app.settings import get_settings 46 | 47 | from .internal.client import Client 48 | from .internal.connmgr import ConnMgr 49 | from .internal.notify import NotifyQueue 50 | from .internal.wake import WakeEvent, WakeSession 51 | from .routers import asset 52 | from .routers import client 53 | from .routers import config 54 | from .routers import info 55 | from .routers import ota 56 | from .routers import release 57 | from .routers import status 58 | 59 | 60 | logging.basicConfig( 61 | format='%(asctime)s %(levelname)-8s %(message)s', 62 | level=logging.INFO, 63 | datefmt='%Y-%m-%d %H:%M:%S') 64 | 65 | log = logging.getLogger("WAS") 66 | try: 67 | log.setLevel(os.environ.get("WAS_LOG_LEVEL").upper()) 68 | except Exception: 69 | pass 70 | 71 | settings = get_settings() 72 | 73 | def db_migrations(): 74 | cfg = alembic.config.Config(ALEMBIC_CONFIG) 75 | cfg.attributes['logger'] = log 76 | alembic.command.upgrade(cfg, "head") 77 | 78 | 79 | @asynccontextmanager 80 | async def lifespan(app: FastAPI): 81 | # database schema migrations 82 | db_migrations() 83 | 84 | migrate_user_files() 85 | get_tz_config(refresh=True) 86 | 87 | user_config = get_config() 88 | # skip migration if user_config is empty 89 | if user_config: 90 | try: 91 | migrate_user_config(user_config) 92 | os.remove(STORAGE_USER_CONFIG) 93 | except Exception as e: 94 | log.error(f"failed to migrate user config to database: {e}") 95 | 96 | user_nvs = get_nvs() 97 | # skip migration if user_nvs is empty 98 | if user_nvs: 99 | try: 100 | migrate_user_nvs(user_nvs) 101 | os.remove(STORAGE_USER_NVS) 102 | except Exception as e: 103 | log.error(f"failed to migrate user nvs to database: {e}") 104 | 105 | devices = get_devices() 106 | # skip migration if devices is empty 107 | if devices: 108 | try: 109 | migrate_user_client_config(devices) 110 | os.remove(STORAGE_USER_CLIENT_CONFIG) 111 | except Exception as e: 112 | log.error(f"failed to migrate user client config to database: {e}") 113 | app.connmgr = ConnMgr() 114 | 115 | app.command_endpoint = None 116 | try: 117 | init_command_endpoint(app) 118 | except Exception as e: 119 | log.error(f"failed to initialize command endpoint ({e})") 120 | 121 | app.notify_queue = NotifyQueue(connmgr=app.connmgr) 122 | app.notify_queue.start() 123 | 124 | yield 125 | log.info("shutting down") 126 | 127 | app = FastAPI(title="Willow Application Server", 128 | description="Willow Management API", 129 | openapi_url="/openapi.json", 130 | docs_url="/docs", 131 | lifespan=lifespan, 132 | redoc_url="/redoc", 133 | version=settings.was_version) 134 | 135 | wake_session = None 136 | 137 | app.add_middleware( 138 | CORSMiddleware, 139 | allow_origins=["*"], 140 | allow_credentials=True, 141 | allow_methods=["*"], 142 | allow_headers=["*"], 143 | ) 144 | 145 | 146 | def migrate_user_files(): 147 | for user_file in ['user_config.json', 'user_multinet.json', 'user_nvs.json']: 148 | if os.path.isfile(user_file): 149 | dest = f"storage/{user_file}" 150 | if not os.path.isfile(dest): 151 | move(user_file, dest) 152 | 153 | 154 | def hex_mac(mac): 155 | if isinstance(mac, list): 156 | mac = '%02x:%02x:%02x:%02x:%02x:%02x' % (mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]) 157 | return mac 158 | 159 | 160 | # Make sure we always have DIR_OTA 161 | Path(DIR_OTA).mkdir(parents=True, exist_ok=True) 162 | 163 | 164 | app.mount("/admin", StaticFiles(directory="static/admin", html=True), name="admin") 165 | 166 | 
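# Sketch of the JSON text frames handled by the /ws route below. The shapes are
# inferred from the handlers in this file; the field values are illustrative only:
#   {"hello": {"hostname": "willow-node-1", "hw_type": "esp32-s3-box-3", "mac_addr": [0, 24, 52, 86, 120, 144]}}
#   {"wake_start": {"wake_volume": 0.42}}
#   {"wake_end": {}}
#   {"notify_done": 1700000000000}
#   {"cmd": "endpoint", "data": {"text": "turn on the kitchen lights", "language": "en"}}
#   {"cmd": "get_config"}
#   {"goodbye": true}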
167 | @app.get("/", response_class=RedirectResponse) 168 | def api_redirect_admin(): 169 | log.debug('API GET ROOT: Request') 170 | return "/admin" 171 | 172 | 173 | app.include_router(asset.router) 174 | app.include_router(client.router) 175 | app.include_router(config.router) 176 | app.include_router(info.router) 177 | app.include_router(ota.router) 178 | app.include_router(release.router) 179 | app.include_router(status.router) 180 | 181 | 182 | # WebSockets with params return 403 when done with APIRouter 183 | # https://github.com/tiangolo/fastapi/issues/98#issuecomment-1688632239 184 | @app.websocket("/ws") 185 | async def websocket_endpoint( 186 | websocket: WebSocket, 187 | user_agent: Annotated[str | None, Header(convert_underscores=True)] = None): 188 | client = Client(ua=user_agent) 189 | 190 | await app.connmgr.accept(websocket, client) 191 | try: 192 | while True: 193 | data = await websocket.receive_text() 194 | log.debug(str(data)) 195 | msg = json.loads(data) 196 | 197 | # latency sensitive so handle first 198 | if "wake_start" in msg: 199 | global wake_session 200 | if wake_session is not None: 201 | if wake_session.done: 202 | del wake_session 203 | wake_session = WakeSession() 204 | asyncio.create_task(wake_session.cleanup()) 205 | else: 206 | wake_session = WakeSession() 207 | asyncio.create_task(wake_session.cleanup()) 208 | 209 | if "wake_volume" in msg["wake_start"]: 210 | wake_event = WakeEvent(websocket, msg["wake_start"]["wake_volume"]) 211 | wake_session.add_event(wake_event) 212 | 213 | elif "wake_end" in msg: 214 | pass 215 | 216 | elif "notify_done" in msg: 217 | app.notify_queue.done(websocket, msg["notify_done"]) 218 | 219 | elif "cmd" in msg: 220 | if msg["cmd"] == "endpoint": 221 | if app.command_endpoint is not None: 222 | log.debug(f"Sending {msg['data']} to {app.command_endpoint.name}") 223 | try: 224 | resp = app.command_endpoint.send(jsondata=msg["data"], ws=websocket, client=client) 225 | if resp is not None: 226 | resp = app.command_endpoint.parse_response(resp) 227 | log.debug(f"Got response {resp} from endpoint") 228 | # HomeAssistantWebSocketEndpoint sends message via callback 229 | if resp is not None: 230 | asyncio.ensure_future(websocket.send_text(resp)) 231 | except CommandEndpointRuntimeException as e: 232 | command_endpoint_result = CommandEndpointResult(speech="WAS Command Endpoint unreachable") 233 | command_endpoint_response = CommandEndpointResponse(result=command_endpoint_result) 234 | asyncio.ensure_future(websocket.send_text(command_endpoint_response.model_dump_json())) 235 | log.error(f"WAS Command Endpoint unreachable: {e}") 236 | 237 | else: 238 | command_endpoint_result = CommandEndpointResult(speech="WAS Command Endpoint not active") 239 | command_endpoint_response = CommandEndpointResponse(result=command_endpoint_result) 240 | asyncio.ensure_future(websocket.send_text(command_endpoint_response.model_dump_json())) 241 | log.error("WAS Command Endpoint not active") 242 | 243 | elif msg["cmd"] == "get_config": 244 | asyncio.ensure_future(websocket.send_text(build_msg(get_config_db(), "config"))) 245 | 246 | elif "goodbye" in msg: 247 | app.connmgr.disconnect(websocket) 248 | 249 | elif "hello" in msg: 250 | if "hostname" in msg["hello"]: 251 | app.connmgr.update_client(websocket, "hostname", msg["hello"]["hostname"]) 252 | if "hw_type" in msg["hello"]: 253 | platform = msg["hello"]["hw_type"].upper() 254 | app.connmgr.update_client(websocket, "platform", platform) 255 | if "mac_addr" in msg["hello"]: 256 | mac_addr = 
hex_mac(msg["hello"]["mac_addr"]) 257 | app.connmgr.update_client(websocket, "mac_addr", mac_addr) 258 | 259 | except WebSocketDisconnect: 260 | app.connmgr.disconnect(websocket) 261 | except ConnectionClosed: 262 | app.connmgr.disconnect(websocket) 263 | except Exception as e: 264 | log.error(f"unhandled exception in WebSocket route: {e}") 265 | -------------------------------------------------------------------------------- /app/db/main.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | from sqlalchemy.exc import IntegrityError 4 | from sqlmodel import Session, create_engine, select 5 | 6 | from app.db.models import WillowClientTable, WillowConfigNamespaceType, WillowConfigTable, WillowConfigType 7 | from app.internal.config import WillowConfig, WillowNvsConfig, WillowNvsWas, WillowNvsWifi 8 | from app.settings import get_settings 9 | 10 | 11 | log = getLogger("WAS") 12 | 13 | settings = get_settings() 14 | 15 | connect_args = {} 16 | if "sqlite://" in settings.db_url: 17 | connect_args["check_same_thread"] = False 18 | 19 | engine = create_engine(settings.db_url, echo=False, connect_args=connect_args) 20 | 21 | 22 | def convert_str_or_none(input): 23 | """ Convert value to str or None 24 | 25 | We're saving values of different types in the config_value field of WillowConfigTable. 26 | We therefore need to cast the value to str. 27 | Casting None to str would result in "None" being saved in the database. 28 | This is a helper function to avoid that. 29 | """ 30 | if input is None: 31 | return None 32 | else: 33 | return str(input) 34 | 35 | 36 | def get_config_db(): 37 | config = WillowConfig() 38 | 39 | with Session(engine) as session: 40 | stmt = select(WillowConfigTable).where( 41 | WillowConfigTable.config_type == WillowConfigType.config, 42 | WillowConfigTable.config_value.is_not(None)  # SQL-level NULL check; a Python "is not None" here would always be True 43 | ) 44 | records = session.exec(stmt) 45 | 46 | for record in records: 47 | setattr(config, record.config_name, record.config_value) 48 | 49 | return config.model_dump(exclude_none=True) 50 | 51 | 52 | def get_devices_db(): 53 | devices = [] 54 | with Session(engine) as session: 55 | stmt = select(WillowClientTable) 56 | records = session.exec(stmt) 57 | 58 | for record in records: 59 | devices.append(record.model_dump()) 60 | 61 | return devices 62 | 63 | 64 | def get_nvs_db(): 65 | config = WillowNvsConfig() 66 | config_was = WillowNvsWas() 67 | config_wifi = WillowNvsWifi() 68 | 69 | with Session(engine) as session: 70 | stmt = select(WillowConfigTable).where(WillowConfigTable.config_type == WillowConfigType.nvs) 71 | records = session.exec(stmt) 72 | 73 | for record in records: 74 | if record.config_namespace == WillowConfigNamespaceType.WAS: 75 | setattr(config_was, record.config_name, record.config_value) 76 | 77 | elif record.config_namespace == WillowConfigNamespaceType.WIFI: 78 | setattr(config_wifi, record.config_name, record.config_value) 79 | 80 | config.WAS = config_was 81 | config.WIFI = config_wifi 82 | 83 | return config.model_dump(exclude_none=True) 84 | 85 | 86 | def migrate_user_config(config): 87 | config = WillowConfig.parse_obj(config) 88 | log.debug(f"config: {config}") 89 | 90 | with Session(engine) as session: 91 | for k, v in iter(config): 92 | db_config = WillowConfigTable( 93 | config_type=WillowConfigType.config, 94 | config_name=k, 95 | config_value=convert_str_or_none(v) 96 | ) 97 | 98 | # config_value is already normalized to str or None by convert_str_or_none() 99 | 100 | session.add(db_config) 101 | 102 | try:
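            # One commit covers the whole loop; an IntegrityError here usually means
            # the rows already exist from an earlier run, hence the warn-and-rollback.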
103 | session.commit() 104 | except IntegrityError as e: 105 | # TODO avoid users thinking something is wrong here 106 | log.warning(e) 107 | session.rollback() 108 | 109 | 110 | def migrate_user_client_config(clients): 111 | log.debug(f"clients: {clients}") 112 | 113 | with Session(engine) as session: 114 | for client in iter(clients): 115 | db_client = WillowClientTable( 116 | label=client["label"], 117 | mac_addr=client["mac_addr"] 118 | ) 119 | session.add(db_client) 120 | 121 | try: 122 | session.commit() 123 | except IntegrityError as e: 124 | # TODO avoid users thinking something is wrong here 125 | log.warning(e) 126 | session.rollback() 127 | 128 | 129 | def migrate_user_nvs(config): 130 | config = WillowNvsConfig.parse_obj(config) 131 | log.debug(f"config: {config}") 132 | log.debug(f"WAS: {config.WAS}") 133 | log.debug(f"WIFI: {config.WIFI}") 134 | 135 | with Session(engine) as session: 136 | for k, v in iter(config.WAS): 137 | log.debug(f"WAS: k={k} v={v}") 138 | db_config = WillowConfigTable( 139 | config_type=WillowConfigType.nvs, 140 | config_name=k, 141 | config_namespace=WillowConfigNamespaceType.WAS, 142 | config_value=str(v), 143 | ) 144 | session.add(db_config) 145 | 146 | for k, v in iter(config.WIFI): 147 | log.debug(f"WIFI: k={k} v={v}") 148 | db_config = WillowConfigTable( 149 | config_type=WillowConfigType.nvs, 150 | config_name=k, 151 | config_namespace=WillowConfigNamespaceType.WIFI, 152 | config_value=str(v), 153 | ) 154 | session.add(db_config) 155 | 156 | try: 157 | session.commit() 158 | except IntegrityError as e: 159 | # TODO avoid users thinking something is wrong here 160 | log.warning(e) 161 | session.rollback() 162 | 163 | 164 | def save_client_config_to_db(clients): 165 | log.debug(f"save_client_config_to_db: {clients}") 166 | 167 | with Session(engine) as session: 168 | for client in iter(clients): 169 | stmt = select(WillowClientTable).where(WillowClientTable.mac_addr == client["mac_addr"]) 170 | record = session.exec(stmt).first() 171 | 172 | if record is None: 173 | record = WillowClientTable( 174 | label=client["label"], 175 | mac_addr=client["mac_addr"] 176 | ) 177 | 178 | else: 179 | if record.label == client["label"] and record.mac_addr == client["mac_addr"]: 180 | continue 181 | record.label = client["label"] 182 | record.mac_addr = client["mac_addr"] 183 | 184 | session.add(record) 185 | 186 | try: 187 | session.commit() 188 | except IntegrityError as e: 189 | log.warning(e) 190 | session.rollback() 191 | 192 | 193 | def save_config_to_db(config): 194 | config = WillowConfig.parse_obj(config) 195 | log.debug(f"save_config_to_db: {config}") 196 | 197 | with Session(engine) as session: 198 | for name, value in iter(config): 199 | stmt = select(WillowConfigTable).where( 200 | WillowConfigTable.config_type == WillowConfigType.config, 201 | WillowConfigTable.config_name == name, 202 | ) 203 | record = session.exec(stmt).first() 204 | 205 | if record is None: 206 | record = WillowConfigTable( 207 | config_type=WillowConfigType.config, 208 | config_name=name, 209 | config_value=convert_str_or_none(value) 210 | ) 211 | 212 | else: 213 | if record.config_value == convert_str_or_none(value): 214 | continue 215 | 216 | record.config_value = convert_str_or_none(value) 217 | 218 | session.add(record) 219 | 220 | try: 221 | session.commit() 222 | session.refresh(record) 223 | except IntegrityError as e: 224 | log.warning(e) 225 | session.rollback() 226 | 227 | 228 | def save_nvs_to_db(config): 229 | config = WillowNvsConfig.parse_obj(config) 230 | 
log.debug(f"save_nvs_to_db: {config}") 231 | 232 | with Session(engine) as session: 233 | for name, value in iter(config.WAS): 234 | stmt = select(WillowConfigTable).where( 235 | WillowConfigTable.config_type == WillowConfigType.nvs, 236 | WillowConfigTable.config_name == name, 237 | WillowConfigTable.config_namespace == WillowConfigNamespaceType.WAS, 238 | ) 239 | record = session.exec(stmt).first() 240 | 241 | if record is None: 242 | record = WillowConfigTable( 243 | config_type=WillowConfigType.nvs, 244 | config_name=name, 245 | config_namespace=WillowConfigNamespaceType.WAS, 246 | config_value=str(value), 247 | ) 248 | 249 | else: 250 | if record.config_value == str(value): 251 | continue 252 | record.config_value = str(value) 253 | 254 | session.add(record) 255 | 256 | for name, value in iter(config.WIFI): 257 | stmt = select(WillowConfigTable).where( 258 | WillowConfigTable.config_type == WillowConfigType.nvs, 259 | WillowConfigTable.config_name == name, 260 | WillowConfigTable.config_namespace == WillowConfigNamespaceType.WIFI, 261 | ) 262 | record = session.exec(stmt).first() 263 | 264 | if record is None: 265 | record = WillowConfigTable( 266 | config_type=WillowConfigType.nvs, 267 | config_name=name, 268 | config_namespace=WillowConfigNamespaceType.WIFI, 269 | config_value=str(value), 270 | ) 271 | 272 | else: 273 | if record.config_value == str(value): 274 | continue 275 | record.config_value = str(value) 276 | 277 | session.add(record) 278 | 279 | try: 280 | session.commit() 281 | session.refresh(record) 282 | except IntegrityError as e: 283 | log.warning(e) 284 | session.rollback() 285 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
--------------------------------------------------------------------------------
/app/internal/was.py:
--------------------------------------------------------------------------------
import json
import os
import random
import re
import socket
import time
import urllib.parse

import magic
import requests
import urllib3

from fastapi import HTTPException
from hashlib import sha256
from logging import getLogger

from num2words import num2words

from app.db.main import get_config_db, get_nvs_db, save_config_to_db, save_nvs_to_db

from ..const import (
    DIR_OTA,
    STORAGE_TZ,
    STORAGE_USER_CLIENT_CONFIG,
    STORAGE_USER_CONFIG,
    STORAGE_USER_MULTINET,
    STORAGE_USER_NVS,
    STORAGE_USER_WAS,
    URL_WILLOW_RELEASES,
    URL_WILLOW_TZ,
)


log = getLogger("WAS")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


def build_msg(config, container):
    # Wrap a payload under its container key ("config", "nvs", ...) so
    # clients can tell the message types apart.
    try:
        msg = json.dumps({container: config}, sort_keys=True)
        return msg
    except Exception as e:
        log.error(f"Failed to build {container} message: {e}")
        return None


def construct_url(host, port, tls=False, ws=False):
    if tls:
        scheme = "wss" if ws else "https"
    else:
        scheme = "ws" if ws else "http"

    return f"{scheme}://{host}:{port}"


def construct_wis_tts_url(url):
    parsed = urllib.parse.urlparse(url)
    if len(parsed.query) == 0:
        return urllib.parse.urljoin(url, "?text=")
    else:
        params = urllib.parse.parse_qs(parsed.query)
        log.debug(f"construct_wis_tts_url: parsed={parsed} - params={params}")
        if "text" in params:
            log.warning("removing text parameter from WIS TTS URL")
            del params["text"]
        params["text"] = ""
        parsed = parsed._replace(query=urllib.parse.urlencode(params, doseq=True))
        log.debug(f"construct_wis_tts_url: parsed={parsed} - params={params}")
        return urllib.parse.urlunparse(parsed)


async def device_command(connmgr, data, command):
    if "hostname" not in data:
        log.error(f"Failed to send {command} command: no hostname in request")
        return "Error"

    hostname = data["hostname"]
    msg = json.dumps({"cmd": command})
    try:
        ws = connmgr.get_client_by_hostname(hostname)
        await ws.send_text(msg)
        return "Success"
    except Exception as e:
        log.error(f"Failed to send {command} command to {hostname} ({e})")
        return "Error"


def do_get_request(url, verify=False, timeout=(1, 60)):
    try:
        parsed_url = urllib.parse.urlparse(url)

        if parsed_url.username and parsed_url.password:
            # Request with HTTP basic auth credentials embedded in the URL
            basic_auth = requests.auth.HTTPBasicAuth(parsed_url.username, parsed_url.password)
            response = requests.get(url, verify=verify, timeout=timeout, auth=basic_auth)
        else:
            # Request without auth
            response = requests.get(url, verify=verify, timeout=timeout)
        return response

    except Exception as e:
        log.error(f"GET request failed: {e}")
        return None


def get_config():
    return get_json_from_file(STORAGE_USER_CONFIG)


def get_devices():
    devices = []

    if os.path.isfile(STORAGE_USER_CLIENT_CONFIG):
        with open(STORAGE_USER_CLIENT_CONFIG, "r") as devices_file:
            devices = json.load(devices_file)
    else:
        # Create the client config file on first run
        with open(STORAGE_USER_CLIENT_CONFIG, "x") as devices_file:
            json.dump(devices, devices_file)

    return devices
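

# Hedged usage sketch (not part of WAS itself): what the URL helpers above
# return for some invented hostnames, in the spirit of the tests under
# app/pytest/. Wrapped in a function so nothing runs at import time.
def _example_url_helpers():  # pragma: no cover - illustrative only
    assert construct_url("wis.local", 19000, tls=True) == "https://wis.local:19000"
    assert construct_url("was.local", 8502, ws=True) == "ws://was.local:8502"
    # any pre-existing text= value is dropped and re-added empty
    assert (
        construct_wis_tts_url("https://wis.local:19000/api/tts?speaker=CLB&text=old")
        == "https://wis.local:19000/api/tts?speaker=CLB&text="
    )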
def get_ha_commands_for_entity(entity):
    commands = []
    pattern = r'[^A-Za-z- ]'

    # Spell out every number in the entity name (e.g. "2" -> "two") so the
    # generated phrases only contain characters the multinet model can match.
    for number in re.findall(r'\d+', entity):
        entity = entity.replace(number, f" {num2words(int(number))} ", 1)

    entity = entity.replace('_', ' ')
    entity = re.sub(pattern, '', entity)
    entity = " ".join(entity.split())
    entity = entity.upper()

    on = f'TURN ON {entity}'
    off = f'TURN OFF {entity}'

    # ESP_MN_MAX_PHRASE_LEN=63
    if len(off) < 63:
        commands.extend([on, off])

    return commands


def get_ha_entities(url, token):
    if token is None:
        return json.dumps({'error': 'HA token not set'})

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    url = f"{url}/api/states"
    # timeout keeps us from hanging forever on an unresponsive HA instance
    response = requests.get(url, headers=headers, timeout=(1, 60))
    data = response.json()
    data.sort(key=lambda x: x['entity_id'])
    return data


def get_ip():
    # Connecting a UDP socket to a public IP sends no traffic; it just lets
    # the kernel pick the outbound interface so we can read our own address.
    sk = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sk.connect(("1.1.1.1", 53))
    ip = sk.getsockname()[0]
    sk.close()
    return ip


def get_json_from_file(path):
    try:
        with open(path, "r") as file:
            data = json.load(file)
    except Exception:
        data = {}

    return data


def get_mime_type(filename):
    mime_type = magic.Magic(mime=True).from_file(filename)
    return mime_type


def get_multinet():
    return get_json_from_file(STORAGE_USER_MULTINET)


def get_nvs():
    return get_json_from_file(STORAGE_USER_NVS)


def get_release_url(was_url, version, platform):
    parsed = urllib.parse.urlparse(was_url)

    # The WAS URL stored in NVS uses the ws/wss scheme; map it to the
    # matching HTTP scheme and fall back to the scheme as-is otherwise.
    match parsed.scheme:
        case "ws":
            scheme = "http"
        case "wss":
            scheme = "https"
        case _:
            scheme = parsed.scheme

    url = f"{scheme}://{parsed.netloc}/api/ota?version={version}&platform={platform}"
    return url


# TODO: Find a better way but we need to handle every error possible
def get_releases_local():
    local_dir = f"{DIR_OTA}/local"
    assets = []
    if not os.path.exists(local_dir):
        return assets

    url = "https://heywillow.io"

    for asset_name in os.listdir(local_dir):
        if asset_name.endswith('.bin'):
            file = f"{DIR_OTA}/local/{asset_name}"
            created_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(os.path.getctime(file)))
            with open(file, "rb") as f:
                checksum = sha256(f.read()).hexdigest()
            asset = {}
            asset["name"] = f"willow-ota-{asset_name}"
            asset["tag_name"] = f"willow-ota-{asset_name}"
            asset["platform"] = asset_name.replace('.bin', '')
            asset["platform_name"] = asset["platform"]
            asset["platform_image"] = "https://heywillow.io/images/esp32_s3_box.png"
            asset["build_type"] = "ota"
            asset["url"] = url
            asset["id"] = random.randint(10, 99)
            asset["content_type"] = "raw"
            asset["size"] = os.path.getsize(file)
            asset["created_at"] = created_at
            asset["browser_download_url"] = url
            asset["sha256"] = checksum
            assets.append(asset)

    if not assets:
        return []

    return [{"name": "local",
             "tag_name": "local",
             "id": random.randint(10, 99),
             "url": url,
             "html_url": url,
             "assets": assets}]
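

# Hedged sketch (illustrative only; the version and platform strings are
# made up): how get_release_url rewrites the stored ws/wss WAS URL into the
# OTA endpoint, and what get_ha_commands_for_entity generates for a
# numbered entity name.
def _example_release_and_ha_helpers():  # pragma: no cover - illustrative only
    assert (
        get_release_url("wss://was.local:8502/ws", "1.0.0", "ESP32-S3-BOX-3")
        == "https://was.local:8502/api/ota?version=1.0.0&platform=ESP32-S3-BOX-3"
    )
    # "2" is spelled out by num2words and the phrase is upper-cased
    assert get_ha_commands_for_entity("kitchen_light_2") == [
        "TURN ON KITCHEN LIGHT TWO",
        "TURN OFF KITCHEN LIGHT TWO",
    ]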
def get_releases_willow():
    # timeout keeps us from hanging forever on an unresponsive endpoint
    releases = requests.get(URL_WILLOW_RELEASES, timeout=(1, 60))
    releases = releases.json()
    try:
        releases_local = get_releases_local()
    except Exception:
        pass
    else:
        releases = releases_local + releases
    return releases


def get_safe_path(basedir, path, follow_symlinks=True):
    path = os.path.normpath(path)
    # resolves symbolic links
    if follow_symlinks:
        matchpath = os.path.realpath(path)
    else:
        matchpath = os.path.abspath(path)

    # commonpath avoids the classic startswith pitfall where a sibling such
    # as /app/storage-other would pass a startswith("/app/storage") check
    if os.path.commonpath([matchpath, os.path.abspath(basedir)]) != os.path.abspath(basedir):
        raise HTTPException(status_code=400, detail=f"invalid asset path {path}")

    return matchpath


def get_tz_config(refresh=False):
    if refresh:
        tz = requests.get(URL_WILLOW_TZ, timeout=(1, 60)).json()
        with open(STORAGE_TZ, "w") as tz_file:
            json.dump(tz, tz_file)

    return get_json_from_file(STORAGE_TZ)


def get_was_config():
    return get_json_from_file(STORAGE_USER_WAS)


def get_was_url():
    try:
        nvs = get_nvs_db()
        return nvs["WAS"]["URL"]
    except Exception:
        return False


def merge_dict(dict_1, dict_2):
    # dict union: keys from dict_2 win on conflict (Python 3.9+)
    return dict_1 | dict_2


async def post_config(request, apply=False):
    data = await request.json()
    if "hostname" in data:
        hostname = data["hostname"]
        data = get_config_db()
        msg = build_msg(data, "config")
        try:
            ws = request.app.connmgr.get_client_by_hostname(hostname)
            await ws.send_text(msg)
            return "Success"
        except Exception as e:
            log.error(f"Failed to apply config to {hostname} ({e})")
            return "Error"
    else:
        if "wis_tts_url" in data:
            # Migrate the legacy wis_tts_url field to wis_tts_url_v2, which
            # always carries an empty text= query parameter.
            data["wis_tts_url_v2"] = construct_wis_tts_url(data["wis_tts_url"])
            del data["wis_tts_url"]
            log.debug(f"wis_tts_url_v2: {data['wis_tts_url_v2']}")

        save_config_to_db(data)
        msg = build_msg(data, "config")
        log.debug(str(msg))
        if apply:
            await request.app.connmgr.broadcast(msg)
        return "Success"


async def post_nvs(request, apply=False):
    data = await request.json()
    if "hostname" in data:
        hostname = data["hostname"]
        data = get_nvs_db()
        msg = build_msg(data, "nvs")
        try:
            ws = request.app.connmgr.get_client_by_hostname(hostname)
            await ws.send_text(msg)
            return "Success"
        except Exception as e:
            log.error(f"Failed to apply NVS config to {hostname} ({e})")
            return "Error"
    else:
        save_nvs_to_db(data)
        msg = build_msg(data, "nvs")
        log.debug(str(msg))
        if apply:
            await request.app.connmgr.broadcast(msg)
        return "Success"


async def post_was(request, apply=False):
    # apply is accepted for parity with post_config/post_nvs but WAS
    # settings are only persisted, never pushed to clients.
    data = await request.json()
    data = json.dumps(data)
    save_json_to_file(STORAGE_USER_WAS, data)
    return "Success"


def save_json_to_file(path, content):
    with open(path, "w") as config_file:
        config_file.write(content)
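

# Hedged sketch (illustrative only; the config keys are examples, not a
# schema): merge_dict is a plain dict union with right-hand precedence, and
# build_msg wraps the merged payload under its container key with sorted
# keys, which is the shape broadcast to connected clients.
def _example_config_message():  # pragma: no cover - illustrative only
    merged = merge_dict({"speech_rec_mode": "WIS"}, {"wake_word": "alexa"})
    assert merged == {"speech_rec_mode": "WIS", "wake_word": "alexa"}
    assert build_msg(merged, "config") == (
        '{"config": {"speech_rec_mode": "WIS", "wake_word": "alexa"}}'
    )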


def warm_tts(data):
    # Pre-fetch the TTS audio so WIS has generated and cached it before the
    # clients request it themselves.
    try:
        if "/api/tts" in data["audio_url"]:
            do_get_request(data["audio_url"])
            log.debug("TTS ready - passing to clients")
    except Exception as e:
        log.debug(f"TTS warm-up failed: {e}")
--------------------------------------------------------------------------------