Entity layer cleanup (#12237)

* Simplify entity update

* Split entity platform from entity component

* Decouple entity platform from entity component

* Always include unit of measurement again

* Lint

* Fix test
This commit is contained in:
Paulus Schoutsen 2018-02-08 03:16:51 -08:00 committed by Pascal Vizeli
parent 8523933605
commit 5601fbdc7a
7 changed files with 905 additions and 857 deletions

View file

@ -4,67 +4,27 @@ import asyncio
from collections import OrderedDict
import logging
import unittest
from unittest.mock import patch, Mock, MagicMock
from unittest.mock import patch, Mock
from datetime import timedelta
import homeassistant.core as ha
import homeassistant.loader as loader
from homeassistant.exceptions import PlatformNotReady
from homeassistant.components import group
from homeassistant.helpers.entity import Entity, generate_entity_id
from homeassistant.helpers.entity_component import (
EntityComponent, DEFAULT_SCAN_INTERVAL, SLOW_SETUP_WARNING)
from homeassistant.helpers import entity_component
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import setup_component
from homeassistant.helpers import discovery
import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, MockModule, fire_time_changed,
mock_coro, async_fire_time_changed, mock_registry)
get_test_home_assistant, MockPlatform, MockModule, mock_coro,
async_fire_time_changed, MockEntity)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
class EntityTest(Entity):
"""Test for the Entity component."""
def __init__(self, **values):
"""Initialize an entity."""
self._values = values
if 'entity_id' in values:
self.entity_id = values['entity_id']
@property
def name(self):
"""Return the name of the entity."""
return self._handle('name')
@property
def should_poll(self):
"""Return the ste of the polling."""
return self._handle('should_poll')
@property
def unique_id(self):
"""Return the unique ID of the entity."""
return self._handle('unique_id')
@property
def available(self):
"""Return True if entity is available."""
return self._handle('available')
def _handle(self, attr):
"""Helper for the attributes."""
if attr in self._values:
return self._values[attr]
return getattr(super(), attr)
class TestHelpersEntityComponent(unittest.TestCase):
"""Test homeassistant.helpers.entity_component module."""
@ -85,7 +45,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
# No group after setup
assert len(self.hass.states.entity_ids()) == 0
component.add_entities([EntityTest()])
component.add_entities([MockEntity()])
self.hass.block_till_done()
# group exists
@ -98,7 +58,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
('test_domain.unnamed_device',)
# group extended
component.add_entities([EntityTest(name='goodbye')])
component.add_entities([MockEntity(name='goodbye')])
self.hass.block_till_done()
assert len(self.hass.states.entity_ids()) == 3
@ -108,151 +68,6 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert group.attributes.get('entity_id') == \
('test_domain.goodbye', 'test_domain.unnamed_device')
def test_polling_only_updates_entities_it_should_poll(self):
"""Test the polling of only updated entities."""
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
no_poll_ent = EntityTest(should_poll=False)
no_poll_ent.async_update = Mock()
poll_ent = EntityTest(should_poll=True)
poll_ent.async_update = Mock()
component.add_entities([no_poll_ent, poll_ent])
no_poll_ent.async_update.reset_mock()
poll_ent.async_update.reset_mock()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done()
assert not no_poll_ent.async_update.called
assert poll_ent.async_update.called
def test_polling_updates_entities_with_exception(self):
"""Test the updated entities that not break with an exception."""
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
update_ok = []
update_err = []
def update_mock():
"""Mock normal update."""
update_ok.append(None)
def update_mock_err():
"""Mock error update."""
update_err.append(None)
raise AssertionError("Fake error update")
ent1 = EntityTest(should_poll=True)
ent1.update = update_mock_err
ent2 = EntityTest(should_poll=True)
ent2.update = update_mock
ent3 = EntityTest(should_poll=True)
ent3.update = update_mock
ent4 = EntityTest(should_poll=True)
ent4.update = update_mock
component.add_entities([ent1, ent2, ent3, ent4])
update_ok.clear()
update_err.clear()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done()
assert len(update_ok) == 3
assert len(update_err) == 1
def test_update_state_adds_entities(self):
"""Test if updating poll entities cause an entity to be added works."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent1 = EntityTest()
ent2 = EntityTest(should_poll=True)
component.add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids())
ent2.update = lambda *_: component.add_entities([ent1])
fire_time_changed(
self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL
)
self.hass.block_till_done()
assert 2 == len(self.hass.states.entity_ids())
def test_update_state_adds_entities_with_update_before_add_true(self):
"""Test if call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = EntityTest()
ent.update = Mock(spec_set=True)
component.add_entities([ent], True)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert ent.update.called
def test_update_state_adds_entities_with_update_before_add_false(self):
"""Test if not call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = EntityTest()
ent.update = Mock(spec_set=True)
component.add_entities([ent], False)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert not ent.update.called
def test_extract_from_service_returns_all_if_no_entity_id(self):
"""Test the extraction of everything from service."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service')
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.extract_from_service(call))
def test_extract_from_service_filter_out_non_existing_entities(self):
"""Test the extraction of non existing entities from service."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['test_domain.test_2', 'test_domain.non_exist']
})
assert ['test_domain.test_2'] == \
[ent.entity_id for ent in component.extract_from_service(call)]
def test_extract_from_service_no_group_expand(self):
"""Test not expanding a group."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
test_group = group.Group.create_group(
self.hass, 'test_group', ['light.Ceiling', 'light.Kitchen'])
component.add_entities([test_group])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['group.test_group']
})
extracted = component.extract_from_service(call, expand_group=False)
self.assertEqual([test_group], extracted)
def test_setup_loads_platforms(self):
"""Test the loading of the platforms."""
component_setup = Mock(return_value=True)
@ -320,13 +135,13 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert ('platform_test', {}, {'msg': 'discovery_info'}) == \
mock_setup.call_args[0]
@patch('homeassistant.helpers.entity_component.'
@patch('homeassistant.helpers.entity_platform.'
'async_track_time_interval')
def test_set_scan_interval_via_config(self, mock_track):
"""Test the setting of the scan interval via configuration."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([EntityTest(should_poll=True)])
add_devices([MockEntity(should_poll=True)])
loader.set_component('test_domain.platform',
MockPlatform(platform_setup))
@ -344,38 +159,13 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
@patch('homeassistant.helpers.entity_component.'
'async_track_time_interval')
def test_set_scan_interval_via_platform(self, mock_track):
"""Test the setting of the scan interval via platform."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([EntityTest(should_poll=True)])
platform = MockPlatform(platform_setup)
platform.SCAN_INTERVAL = timedelta(seconds=30)
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
}
})
self.hass.block_till_done()
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
def test_set_entity_namespace_via_config(self):
"""Test setting an entity namespace."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([
EntityTest(name='beer'),
EntityTest(name=None),
MockEntity(name='beer'),
MockEntity(name=None),
])
platform = MockPlatform(platform_setup)
@ -396,83 +186,16 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert sorted(self.hass.states.entity_ids()) == \
['test_domain.yummy_beer', 'test_domain.yummy_unnamed_device']
def test_adding_entities_with_generator_and_thread_callback(self):
"""Test generator in add_entities that calls thread method.
We should make sure we resolve the generator to a list before passing
it into an async context.
"""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
def create_entity(number):
"""Create entity helper."""
entity = EntityTest()
entity.entity_id = generate_entity_id(component.entity_id_format,
'Number', hass=self.hass)
return entity
component.add_entities(create_entity(i) for i in range(2))
@asyncio.coroutine
def test_platform_warn_slow_setup(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
with patch.object(hass.loop, 'call_later', MagicMock()) \
as mock_call:
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
assert mock_call.called
timeout, logger_method = mock_call.mock_calls[0][1][:2]
assert timeout == SLOW_SETUP_WARNING
assert logger_method == _LOGGER.warning
assert mock_call().cancel.called
@asyncio.coroutine
def test_platform_error_slow_setup(hass, caplog):
"""Don't block startup more than SLOW_SETUP_MAX_WAIT."""
with patch.object(entity_component, 'SLOW_SETUP_MAX_WAIT', 0):
called = []
@asyncio.coroutine
def setup_platform(*args):
called.append(1)
yield from asyncio.sleep(1, loop=hass.loop)
platform = MockPlatform(async_setup_platform=setup_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
loader.set_component('test_domain.test_platform', platform)
yield from component.async_setup({
DOMAIN: {
'platform': 'test_platform',
}
})
assert len(called) == 1
assert 'test_domain.test_platform' not in hass.config.components
assert 'test_platform is taking longer than 0 seconds' in caplog.text
@asyncio.coroutine
def test_extract_from_service_available_device(hass):
"""Test the extraction of entity from service and device is available."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2', available=False),
EntityTest(name='test_3'),
EntityTest(name='test_4', available=False),
MockEntity(name='test_1'),
MockEntity(name='test_2', available=False),
MockEntity(name='test_3'),
MockEntity(name='test_4', available=False),
])
call_1 = ha.ServiceCall('test', 'service')
@ -490,26 +213,6 @@ def test_extract_from_service_available_device(hass):
component.async_extract_from_service(call_2))
@asyncio.coroutine
def test_updated_state_used_for_entity_id(hass):
"""Test that first update results used for entity ID generation."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
class EntityTestNameFetcher(EntityTest):
"""Mock entity that fetches a friendly name."""
@asyncio.coroutine
def async_update(self):
"""Mock update that assigns a name."""
self._values['name'] = "Living Room"
yield from component.async_add_entities([EntityTestNameFetcher()], True)
entity_ids = hass.states.async_entity_ids()
assert 1 == len(entity_ids)
assert entity_ids[0] == "test_domain.living_room"
@asyncio.coroutine
def test_platform_not_ready(hass):
"""Test that we retry when platform not ready."""
@ -555,188 +258,50 @@ def test_platform_not_ready(hass):
@asyncio.coroutine
def test_parallel_updates_async_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
loader.set_component('test_domain.platform', platform)
def test_extract_from_service_returns_all_if_no_entity_id(hass):
"""Test the extraction of everything from service."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_add_entities([
MockEntity(name='test_1'),
MockEntity(name='test_2'),
])
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
call = ha.ServiceCall('test', 'service')
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call))
@asyncio.coroutine
def test_extract_from_service_filter_out_non_existing_entities(hass):
"""Test the extraction of non existing entities from service."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
MockEntity(name='test_1'),
MockEntity(name='test_2'),
])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['test_domain.test_2', 'test_domain.non_exist']
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
assert ['test_domain.test_2'] == \
[ent.entity_id for ent
in component.async_extract_from_service(call)]
@asyncio.coroutine
def test_parallel_updates_async_platform_with_constant(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
platform.PARALLEL_UPDATES = 1
loader.set_component('test_domain.platform', platform)
def test_extract_from_service_no_group_expand(hass):
"""Test not expanding a group."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
test_group = yield from group.Group.async_create_group(
hass, 'test_group', ['light.Ceiling', 'light.Kitchen'])
yield from component.async_add_entities([test_group])
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
call = ha.ServiceCall('test', 'service', {
'entity_id': ['group.test_group']
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is not None
@asyncio.coroutine
def test_parallel_updates_sync_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is not None
@asyncio.coroutine
def test_raise_error_on_update(hass):
"""Test the add entity if they raise an error on update."""
updates = []
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = EntityTest(name='test_1')
entity2 = EntityTest(name='test_2')
def _raise():
"""Helper to raise an exception."""
raise AssertionError
entity1.update = _raise
entity2.update = lambda: updates.append(1)
yield from component.async_add_entities([entity1, entity2], True)
assert len(updates) == 1
assert 1 in updates
@asyncio.coroutine
def test_async_remove_with_platform(hass):
"""Remove an entity from a platform."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = EntityTest(name='test_1')
yield from component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
yield from entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0
@asyncio.coroutine
def test_not_adding_duplicate_entities_with_unique_id(hass):
"""Test for not adding duplicate entities."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(name='test1', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1
yield from component.async_add_entities([
EntityTest(name='test2', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1
@asyncio.coroutine
def test_using_prescribed_entity_id(hass):
"""Test for using predefined entity ID."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(name='bla', entity_id='hello.world')])
assert 'hello.world' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_using_prescribed_entity_id_with_unique_id(hass):
"""Test for ammending predefined entity ID because currently exists."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(entity_id='test_domain.world')])
yield from component.async_add_entities([
EntityTest(entity_id='test_domain.world', unique_id='bla')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_using_prescribed_entity_id_which_is_registered(hass):
"""Test not allowing predefined entity ID that already registered."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass)
# Register test_domain.world
registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world')
# This entity_id will be rewritten
yield from component.async_add_entities([
EntityTest(entity_id='test_domain.world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_name_which_conflict_with_registered(hass):
"""Test not generating conflicting entity ID based on name."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass)
# Register test_domain.world
registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world')
yield from component.async_add_entities([
EntityTest(name='world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_entity_with_name_and_entity_id_getting_registered(hass):
"""Ensure that entity ID is used for registration."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(unique_id='1234', name='bla',
entity_id='test_domain.world')])
assert 'test_domain.world' in hass.states.async_entity_ids()
extracted = component.async_extract_from_service(call, expand_group=False)
assert extracted == [test_group]