Start type annotating/testing helpers (#17858)
* Add type hints to helpers.intent and location * Test typing for helpers.icon, json, and typing * Add type hints to helpers.state * Add type hints to helpers.translation
This commit is contained in:
parent
0f877711a0
commit
c9c707e368
5 changed files with 92 additions and 55 deletions
|
@ -1,8 +1,12 @@
|
|||
"""Helpers that help with state related things."""
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from types import TracebackType
|
||||
from typing import ( # noqa: F401 pylint: disable=unused-import
|
||||
Awaitable, Dict, Iterable, List, Optional, Tuple, Type, Union)
|
||||
|
||||
from homeassistant.loader import bind_hass
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
@ -42,6 +46,7 @@ from homeassistant.const import (
|
|||
STATE_UNLOCKED, SERVICE_SELECT_OPTION)
|
||||
from homeassistant.core import State
|
||||
from homeassistant.util.async_ import run_coroutine_threadsafe
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -102,43 +107,50 @@ class AsyncTrackStates:
|
|||
Must be run within the event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass: HomeAssistantType) -> None:
|
||||
"""Initialize a TrackStates block."""
|
||||
self.hass = hass
|
||||
self.states = []
|
||||
self.states = [] # type: List[State]
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> List[State]:
|
||||
"""Record time from which to track changes."""
|
||||
self.now = dt_util.utcnow()
|
||||
return self.states
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType]) -> None:
|
||||
"""Add changes states to changes list."""
|
||||
self.states.extend(get_changed_since(self.hass.states.async_all(),
|
||||
self.now))
|
||||
|
||||
|
||||
def get_changed_since(states, utc_point_in_time):
|
||||
def get_changed_since(states: Iterable[State],
|
||||
utc_point_in_time: dt.datetime) -> List[State]:
|
||||
"""Return list of states that have been changed since utc_point_in_time."""
|
||||
return [state for state in states
|
||||
if state.last_updated >= utc_point_in_time]
|
||||
|
||||
|
||||
@bind_hass
|
||||
def reproduce_state(hass, states, blocking=False):
|
||||
def reproduce_state(hass: HomeAssistantType,
|
||||
states: Union[State, Iterable[State]],
|
||||
blocking: bool = False) -> None:
|
||||
"""Reproduce given state."""
|
||||
return run_coroutine_threadsafe(
|
||||
return run_coroutine_threadsafe( # type: ignore
|
||||
async_reproduce_state(hass, states, blocking), hass.loop).result()
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_reproduce_state(hass, states, blocking=False):
|
||||
async def async_reproduce_state(hass: HomeAssistantType,
|
||||
states: Union[State, Iterable[State]],
|
||||
blocking: bool = False) -> None:
|
||||
"""Reproduce given state."""
|
||||
if isinstance(states, State):
|
||||
states = [states]
|
||||
|
||||
to_call = defaultdict(list)
|
||||
to_call = defaultdict(list) # type: Dict[Tuple[str, str, str], List[str]]
|
||||
|
||||
for state in states:
|
||||
|
||||
|
@ -182,7 +194,7 @@ async def async_reproduce_state(hass, states, blocking=False):
|
|||
json.dumps(dict(state.attributes), sort_keys=True))
|
||||
to_call[key].append(state.entity_id)
|
||||
|
||||
domain_tasks = {}
|
||||
domain_tasks = {} # type: Dict[str, List[Awaitable[Optional[bool]]]]
|
||||
for (service_domain, service, service_data), entity_ids in to_call.items():
|
||||
data = json.loads(service_data)
|
||||
data[ATTR_ENTITY_ID] = entity_ids
|
||||
|
@ -194,7 +206,8 @@ async def async_reproduce_state(hass, states, blocking=False):
|
|||
hass.services.async_call(service_domain, service, data, blocking)
|
||||
)
|
||||
|
||||
async def async_handle_service_calls(coro_list):
|
||||
async def async_handle_service_calls(
|
||||
coro_list: Iterable[Awaitable]) -> None:
|
||||
"""Handle service calls by domain sequence."""
|
||||
for coro in coro_list:
|
||||
await coro
|
||||
|
@ -205,7 +218,7 @@ async def async_reproduce_state(hass, states, blocking=False):
|
|||
await asyncio.wait(execute_tasks, loop=hass.loop)
|
||||
|
||||
|
||||
def state_as_number(state):
|
||||
def state_as_number(state: State) -> float:
|
||||
"""
|
||||
Try to coerce our state to a number.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue