Clean up entity component (#11691)

* Clean up entity component

* Lint

* List -> Tuple

* Add Entity.async_remove back

* Unflake setting up group test
This commit is contained in:
Paulus Schoutsen 2018-01-22 22:54:41 -08:00 committed by GitHub
parent d478517c51
commit 183e0543b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 230 additions and 191 deletions

View file

@ -1,6 +1,7 @@
"""Helpers for components that manage entities."""
import asyncio
from datetime import timedelta
from itertools import chain
from homeassistant import config as conf_util
from homeassistant.setup import async_prepare_setup_platform
@ -9,7 +10,6 @@ from homeassistant.const import (
DEVICE_DEFAULT_NAME)
from homeassistant.core import callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.loader import get_component
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.event import (
@ -27,7 +27,15 @@ PLATFORM_NOT_READY_RETRIES = 10
class EntityComponent(object):
"""Helper class that will help a component manage its entities."""
"""The EntityComponent manages platforms that manages entities.
This class has the following responsibilities:
- Process the configuration and set up a platform based component.
- Manage the platforms and their entities.
- Help extract the entities from a service call.
- Maintain a group that tracks all platform entities.
- Listen for discovery events for platforms related to the domain.
"""
def __init__(self, logger, domain, hass,
scan_interval=DEFAULT_SCAN_INTERVAL, group_name=None):
@ -40,7 +48,6 @@ class EntityComponent(object):
self.scan_interval = scan_interval
self.group_name = group_name
self.entities = {}
self.config = None
self._platforms = {
@ -49,6 +56,20 @@ class EntityComponent(object):
self.async_add_entities = self._platforms['core'].async_add_entities
self.add_entities = self._platforms['core'].add_entities
@property
def entities(self):
"""Return an iterable that returns all entities."""
return chain.from_iterable(platform.entities.values() for platform
in self._platforms.values())
def get_entity(self, entity_id):
"""Helper method to get an entity."""
for platform in self._platforms.values():
entity = platform.entities.get(entity_id)
if entity is not None:
return entity
return None
def setup(self, config):
"""Set up a full entity component.
@ -77,11 +98,10 @@ class EntityComponent(object):
# Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.components.discovery.load_platform()
@callback
@asyncio.coroutine
def component_platform_discovered(platform, info):
"""Handle the loading of a platform."""
self.hass.async_add_job(
self._async_setup_platform(platform, {}, info))
yield from self._async_setup_platform(platform, {}, info)
discovery.async_listen_platform(
self.hass, self.domain, component_platform_discovered)
@ -107,13 +127,11 @@ class EntityComponent(object):
This method must be run in the event loop.
"""
if ATTR_ENTITY_ID not in service.data:
return [entity for entity in self.entities.values()
if entity.available]
return [entity for entity in self.entities if entity.available]
return [self.entities[entity_id] for entity_id
in extract_entity_ids(self.hass, service, expand_group)
if entity_id in self.entities and
self.entities[entity_id].available]
entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
return [entity for entity in self.entities
if entity.available and entity.entity_id in entity_ids]
@asyncio.coroutine
def _async_setup_platform(self, platform_type, platform_config,
@ -193,80 +211,23 @@ class EntityComponent(object):
finally:
warn_task.cancel()
def add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component."""
return run_coroutine_threadsafe(
self.async_add_entity(entity, platform, update_before_add),
self.hass.loop
).result()
@asyncio.coroutine
def async_add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component.
This method must be run in the event loop.
"""
if entity is None or entity in self.entities.values():
return False
entity.hass = self.hass
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.logger.exception("Error on device update!")
return False
# Write entity_id to entity
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
if platform is not None and platform.entity_namespace is not None:
object_id = '{} {}'.format(platform.entity_namespace,
object_id)
entity.entity_id = async_generate_entity_id(
self.entity_id_format, object_id,
self.entities.keys())
# Make sure it is valid in case an entity set the value themselves
if entity.entity_id in self.entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
elif not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
return True
def update_group(self):
"""Set up and/or update component group."""
run_callback_threadsafe(
self.hass.loop, self.async_update_group).result()
@callback
def async_update_group(self):
"""Set up and/or update component group.
This method must be run in the event loop.
"""
if self.group_name is not None:
ids = sorted(self.entities,
key=lambda x: self.entities[x].name or x)
group = get_component('group')
group.async_set_group(
self.hass, slugify(self.group_name), name=self.group_name,
visible=False, entity_ids=ids
)
if self.group_name is None:
return
ids = [entity.entity_id for entity in
sorted(self.entities,
key=lambda entity: entity.name or entity.entity_id)]
self.hass.components.group.async_set_group(
slugify(self.group_name), name=self.group_name,
visible=False, entity_ids=ids
)
def reset(self):
"""Remove entities and reset the entity component to initial values."""
@ -287,12 +248,17 @@ class EntityComponent(object):
self._platforms = {
'core': self._platforms['core']
}
self.entities = {}
self.config = None
if self.group_name is not None:
group = get_component('group')
group.async_remove(self.hass, slugify(self.group_name))
self.hass.components.group.async_remove(slugify(self.group_name))
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove an entity managed by one of the platforms."""
for platform in self._platforms.values():
if entity_id in platform.entities:
yield from platform.async_remove_entity(entity_id)
def prepare_reload(self):
"""Prepare reloading this entity component."""
@ -323,7 +289,7 @@ class EntityComponent(object):
class EntityPlatform(object):
"""Keep track of entities for a single platform and stay in loop."""
"""Manage the entities for a single platform."""
def __init__(self, component, platform, scan_interval, parallel_updates,
entity_namespace):
@ -333,7 +299,7 @@ class EntityPlatform(object):
self.scan_interval = scan_interval
self.parallel_updates = None
self.entity_namespace = entity_namespace
self.platform_entities = []
self.entities = {}
self._tasks = []
self._async_unsub_polling = None
self._process_updates = asyncio.Lock(loop=component.hass.loop)
@ -391,40 +357,88 @@ class EntityPlatform(object):
if not new_entities:
return
@asyncio.coroutine
def async_process_entity(new_entity):
"""Add entities to StateMachine."""
new_entity.parallel_updates = self.parallel_updates
ret = yield from self.component.async_add_entity(
new_entity, self, update_before_add=update_before_add
)
if ret:
self.platform_entities.append(new_entity)
component_entities = set(entity.entity_id for entity
in self.component.entities)
tasks = [async_process_entity(entity) for entity in new_entities]
tasks = [
self._async_add_entity(entity, update_before_add,
component_entities)
for entity in new_entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
self.component.async_update_group()
if self._async_unsub_polling is not None or \
not any(entity.should_poll for entity
in self.platform_entities):
in self.entities.values()):
return
self._async_unsub_polling = async_track_time_interval(
self.component.hass, self._update_entity_states, self.scan_interval
)
@asyncio.coroutine
def _async_add_entity(self, entity, update_before_add, component_entities):
"""Helper method to add an entity to the platform."""
if entity is None:
raise ValueError('Entity cannot be None')
# Do nothing if entity has already been added based on unique id.
if entity in self.component.entities:
return
entity.hass = self.component.hass
entity.platform = self
entity.parallel_updates = self.parallel_updates
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.component.logger.exception(
"%s: Error on device update!", self.platform)
return
# Write entity_id to entity
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
if self.entity_namespace is not None:
object_id = '{} {}'.format(self.entity_namespace,
object_id)
entity.entity_id = async_generate_entity_id(
self.component.entity_id_format, object_id,
component_entities)
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
elif entity.entity_id in component_entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
component_entities.add(entity.entity_id)
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
@asyncio.coroutine
def async_reset(self):
"""Remove all entities and reset data.
This method must be run in the event loop.
"""
if not self.platform_entities:
if not self.entities:
return
tasks = [entity.async_remove() for entity in self.platform_entities]
tasks = [self._async_remove_entity(entity_id)
for entity_id in self.entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
@ -432,6 +446,28 @@ class EntityPlatform(object):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
yield from self._async_remove_entity(entity_id)
# Clean up polling job if no longer needed
if (self._async_unsub_polling is not None and
not any(entity.should_poll for entity
in self.entities.values())):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def _async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
entity = self.entities.pop(entity_id)
if hasattr(entity, 'async_will_remove_from_hass'):
yield from entity.async_will_remove_from_hass()
self.component.hass.states.async_remove(entity_id)
@asyncio.coroutine
def _update_entity_states(self, now):
"""Update the states of all the polling entities.
@ -450,7 +486,7 @@ class EntityPlatform(object):
with (yield from self._process_updates):
tasks = []
for entity in self.platform_entities:
for entity in self.entities.values():
if not entity.should_poll:
continue
tasks.append(entity.async_update_ha_state(True))