Start type annotating/testing helpers (#17858)

* Add type hints to helpers.intent and location

* Test typing for helpers.icon, json, and typing

* Add type hints to helpers.state

* Add type hints to helpers.translation
This commit is contained in:
Ville Skyttä 2018-10-28 21:12:52 +02:00 committed by GitHub
parent 0f877711a0
commit c9c707e368
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 92 additions and 55 deletions

View file

@ -1,8 +1,12 @@
"""Helpers that help with state related things."""
import asyncio
import datetime as dt
import json
import logging
from collections import defaultdict
from types import TracebackType
from typing import ( # noqa: F401 pylint: disable=unused-import
Awaitable, Dict, Iterable, List, Optional, Tuple, Type, Union)
from homeassistant.loader import bind_hass
import homeassistant.util.dt as dt_util
@ -42,6 +46,7 @@ from homeassistant.const import (
STATE_UNLOCKED, SERVICE_SELECT_OPTION)
from homeassistant.core import State
from homeassistant.util.async_ import run_coroutine_threadsafe
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
@ -102,43 +107,50 @@ class AsyncTrackStates:
Must be run within the event loop.
"""
def __init__(self, hass):
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize a TrackStates block."""
self.hass = hass
self.states = []
self.states = [] # type: List[State]
# pylint: disable=attribute-defined-outside-init
def __enter__(self):
def __enter__(self) -> List[State]:
"""Record time from which to track changes."""
self.now = dt_util.utcnow()
return self.states
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> None:
"""Add changes states to changes list."""
self.states.extend(get_changed_since(self.hass.states.async_all(),
self.now))
def get_changed_since(states, utc_point_in_time):
def get_changed_since(states: Iterable[State],
utc_point_in_time: dt.datetime) -> List[State]:
"""Return list of states that have been changed since utc_point_in_time."""
return [state for state in states
if state.last_updated >= utc_point_in_time]
@bind_hass
def reproduce_state(hass, states, blocking=False):
def reproduce_state(hass: HomeAssistantType,
states: Union[State, Iterable[State]],
blocking: bool = False) -> None:
"""Reproduce given state."""
return run_coroutine_threadsafe(
return run_coroutine_threadsafe( # type: ignore
async_reproduce_state(hass, states, blocking), hass.loop).result()
@bind_hass
async def async_reproduce_state(hass, states, blocking=False):
async def async_reproduce_state(hass: HomeAssistantType,
states: Union[State, Iterable[State]],
blocking: bool = False) -> None:
"""Reproduce given state."""
if isinstance(states, State):
states = [states]
to_call = defaultdict(list)
to_call = defaultdict(list) # type: Dict[Tuple[str, str, str], List[str]]
for state in states:
@ -182,7 +194,7 @@ async def async_reproduce_state(hass, states, blocking=False):
json.dumps(dict(state.attributes), sort_keys=True))
to_call[key].append(state.entity_id)
domain_tasks = {}
domain_tasks = {} # type: Dict[str, List[Awaitable[Optional[bool]]]]
for (service_domain, service, service_data), entity_ids in to_call.items():
data = json.loads(service_data)
data[ATTR_ENTITY_ID] = entity_ids
@ -194,7 +206,8 @@ async def async_reproduce_state(hass, states, blocking=False):
hass.services.async_call(service_domain, service, data, blocking)
)
async def async_handle_service_calls(coro_list):
async def async_handle_service_calls(
coro_list: Iterable[Awaitable]) -> None:
"""Handle service calls by domain sequence."""
for coro in coro_list:
await coro
@ -205,7 +218,7 @@ async def async_reproduce_state(hass, states, blocking=False):
await asyncio.wait(execute_tasks, loop=hass.loop)
def state_as_number(state):
def state_as_number(state: State) -> float:
"""
Try to coerce our state to a number.