Clean up entity component (#11691)
* Clean up entity component * Lint * List -> Tuple * Add Entity.async_remove back * Unflake setting up group test
This commit is contained in:
parent
d478517c51
commit
183e0543b4
14 changed files with 230 additions and 191 deletions
|
@ -1,6 +1,7 @@
|
|||
"""Helpers for components that manage entities."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from itertools import chain
|
||||
|
||||
from homeassistant import config as conf_util
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
@ -9,7 +10,6 @@ from homeassistant.const import (
|
|||
DEVICE_DEFAULT_NAME)
|
||||
from homeassistant.core import callback, valid_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.loader import get_component
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers.entity import async_generate_entity_id
|
||||
from homeassistant.helpers.event import (
|
||||
|
@ -27,7 +27,15 @@ PLATFORM_NOT_READY_RETRIES = 10
|
|||
|
||||
|
||||
class EntityComponent(object):
|
||||
"""Helper class that will help a component manage its entities."""
|
||||
"""The EntityComponent manages platforms that manages entities.
|
||||
|
||||
This class has the following responsibilities:
|
||||
- Process the configuration and set up a platform based component.
|
||||
- Manage the platforms and their entities.
|
||||
- Help extract the entities from a service call.
|
||||
- Maintain a group that tracks all platform entities.
|
||||
- Listen for discovery events for platforms related to the domain.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, domain, hass,
|
||||
scan_interval=DEFAULT_SCAN_INTERVAL, group_name=None):
|
||||
|
@ -40,7 +48,6 @@ class EntityComponent(object):
|
|||
self.scan_interval = scan_interval
|
||||
self.group_name = group_name
|
||||
|
||||
self.entities = {}
|
||||
self.config = None
|
||||
|
||||
self._platforms = {
|
||||
|
@ -49,6 +56,20 @@ class EntityComponent(object):
|
|||
self.async_add_entities = self._platforms['core'].async_add_entities
|
||||
self.add_entities = self._platforms['core'].add_entities
|
||||
|
||||
@property
|
||||
def entities(self):
|
||||
"""Return an iterable that returns all entities."""
|
||||
return chain.from_iterable(platform.entities.values() for platform
|
||||
in self._platforms.values())
|
||||
|
||||
def get_entity(self, entity_id):
|
||||
"""Helper method to get an entity."""
|
||||
for platform in self._platforms.values():
|
||||
entity = platform.entities.get(entity_id)
|
||||
if entity is not None:
|
||||
return entity
|
||||
return None
|
||||
|
||||
def setup(self, config):
|
||||
"""Set up a full entity component.
|
||||
|
||||
|
@ -77,11 +98,10 @@ class EntityComponent(object):
|
|||
|
||||
# Generic discovery listener for loading platform dynamically
|
||||
# Refer to: homeassistant.components.discovery.load_platform()
|
||||
@callback
|
||||
@asyncio.coroutine
|
||||
def component_platform_discovered(platform, info):
|
||||
"""Handle the loading of a platform."""
|
||||
self.hass.async_add_job(
|
||||
self._async_setup_platform(platform, {}, info))
|
||||
yield from self._async_setup_platform(platform, {}, info)
|
||||
|
||||
discovery.async_listen_platform(
|
||||
self.hass, self.domain, component_platform_discovered)
|
||||
|
@ -107,13 +127,11 @@ class EntityComponent(object):
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
if ATTR_ENTITY_ID not in service.data:
|
||||
return [entity for entity in self.entities.values()
|
||||
if entity.available]
|
||||
return [entity for entity in self.entities if entity.available]
|
||||
|
||||
return [self.entities[entity_id] for entity_id
|
||||
in extract_entity_ids(self.hass, service, expand_group)
|
||||
if entity_id in self.entities and
|
||||
self.entities[entity_id].available]
|
||||
entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
|
||||
return [entity for entity in self.entities
|
||||
if entity.available and entity.entity_id in entity_ids]
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_setup_platform(self, platform_type, platform_config,
|
||||
|
@ -193,80 +211,23 @@ class EntityComponent(object):
|
|||
finally:
|
||||
warn_task.cancel()
|
||||
|
||||
def add_entity(self, entity, platform=None, update_before_add=False):
|
||||
"""Add entity to component."""
|
||||
return run_coroutine_threadsafe(
|
||||
self.async_add_entity(entity, platform, update_before_add),
|
||||
self.hass.loop
|
||||
).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_add_entity(self, entity, platform=None, update_before_add=False):
|
||||
"""Add entity to component.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if entity is None or entity in self.entities.values():
|
||||
return False
|
||||
|
||||
entity.hass = self.hass
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
if update_before_add:
|
||||
try:
|
||||
yield from entity.async_device_update(warning=False)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception("Error on device update!")
|
||||
return False
|
||||
|
||||
# Write entity_id to entity
|
||||
if getattr(entity, 'entity_id', None) is None:
|
||||
object_id = entity.name or DEVICE_DEFAULT_NAME
|
||||
|
||||
if platform is not None and platform.entity_namespace is not None:
|
||||
object_id = '{} {}'.format(platform.entity_namespace,
|
||||
object_id)
|
||||
|
||||
entity.entity_id = async_generate_entity_id(
|
||||
self.entity_id_format, object_id,
|
||||
self.entities.keys())
|
||||
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if entity.entity_id in self.entities:
|
||||
raise HomeAssistantError(
|
||||
'Entity id already exists: {}'.format(entity.entity_id))
|
||||
elif not valid_entity_id(entity.entity_id):
|
||||
raise HomeAssistantError(
|
||||
'Invalid entity id: {}'.format(entity.entity_id))
|
||||
|
||||
self.entities[entity.entity_id] = entity
|
||||
|
||||
if hasattr(entity, 'async_added_to_hass'):
|
||||
yield from entity.async_added_to_hass()
|
||||
|
||||
yield from entity.async_update_ha_state()
|
||||
|
||||
return True
|
||||
|
||||
def update_group(self):
|
||||
"""Set up and/or update component group."""
|
||||
run_callback_threadsafe(
|
||||
self.hass.loop, self.async_update_group).result()
|
||||
|
||||
@callback
|
||||
def async_update_group(self):
|
||||
"""Set up and/or update component group.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if self.group_name is not None:
|
||||
ids = sorted(self.entities,
|
||||
key=lambda x: self.entities[x].name or x)
|
||||
group = get_component('group')
|
||||
group.async_set_group(
|
||||
self.hass, slugify(self.group_name), name=self.group_name,
|
||||
visible=False, entity_ids=ids
|
||||
)
|
||||
if self.group_name is None:
|
||||
return
|
||||
|
||||
ids = [entity.entity_id for entity in
|
||||
sorted(self.entities,
|
||||
key=lambda entity: entity.name or entity.entity_id)]
|
||||
|
||||
self.hass.components.group.async_set_group(
|
||||
slugify(self.group_name), name=self.group_name,
|
||||
visible=False, entity_ids=ids
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""Remove entities and reset the entity component to initial values."""
|
||||
|
@ -287,12 +248,17 @@ class EntityComponent(object):
|
|||
self._platforms = {
|
||||
'core': self._platforms['core']
|
||||
}
|
||||
self.entities = {}
|
||||
self.config = None
|
||||
|
||||
if self.group_name is not None:
|
||||
group = get_component('group')
|
||||
group.async_remove(self.hass, slugify(self.group_name))
|
||||
self.hass.components.group.async_remove(slugify(self.group_name))
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove_entity(self, entity_id):
|
||||
"""Remove an entity managed by one of the platforms."""
|
||||
for platform in self._platforms.values():
|
||||
if entity_id in platform.entities:
|
||||
yield from platform.async_remove_entity(entity_id)
|
||||
|
||||
def prepare_reload(self):
|
||||
"""Prepare reloading this entity component."""
|
||||
|
@ -323,7 +289,7 @@ class EntityComponent(object):
|
|||
|
||||
|
||||
class EntityPlatform(object):
|
||||
"""Keep track of entities for a single platform and stay in loop."""
|
||||
"""Manage the entities for a single platform."""
|
||||
|
||||
def __init__(self, component, platform, scan_interval, parallel_updates,
|
||||
entity_namespace):
|
||||
|
@ -333,7 +299,7 @@ class EntityPlatform(object):
|
|||
self.scan_interval = scan_interval
|
||||
self.parallel_updates = None
|
||||
self.entity_namespace = entity_namespace
|
||||
self.platform_entities = []
|
||||
self.entities = {}
|
||||
self._tasks = []
|
||||
self._async_unsub_polling = None
|
||||
self._process_updates = asyncio.Lock(loop=component.hass.loop)
|
||||
|
@ -391,40 +357,88 @@ class EntityPlatform(object):
|
|||
if not new_entities:
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_process_entity(new_entity):
|
||||
"""Add entities to StateMachine."""
|
||||
new_entity.parallel_updates = self.parallel_updates
|
||||
ret = yield from self.component.async_add_entity(
|
||||
new_entity, self, update_before_add=update_before_add
|
||||
)
|
||||
if ret:
|
||||
self.platform_entities.append(new_entity)
|
||||
component_entities = set(entity.entity_id for entity
|
||||
in self.component.entities)
|
||||
|
||||
tasks = [async_process_entity(entity) for entity in new_entities]
|
||||
tasks = [
|
||||
self._async_add_entity(entity, update_before_add,
|
||||
component_entities)
|
||||
for entity in new_entities]
|
||||
|
||||
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
|
||||
self.component.async_update_group()
|
||||
|
||||
if self._async_unsub_polling is not None or \
|
||||
not any(entity.should_poll for entity
|
||||
in self.platform_entities):
|
||||
in self.entities.values()):
|
||||
return
|
||||
|
||||
self._async_unsub_polling = async_track_time_interval(
|
||||
self.component.hass, self._update_entity_states, self.scan_interval
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_add_entity(self, entity, update_before_add, component_entities):
|
||||
"""Helper method to add an entity to the platform."""
|
||||
if entity is None:
|
||||
raise ValueError('Entity cannot be None')
|
||||
|
||||
# Do nothing if entity has already been added based on unique id.
|
||||
if entity in self.component.entities:
|
||||
return
|
||||
|
||||
entity.hass = self.component.hass
|
||||
entity.platform = self
|
||||
entity.parallel_updates = self.parallel_updates
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
if update_before_add:
|
||||
try:
|
||||
yield from entity.async_device_update(warning=False)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.component.logger.exception(
|
||||
"%s: Error on device update!", self.platform)
|
||||
return
|
||||
|
||||
# Write entity_id to entity
|
||||
if getattr(entity, 'entity_id', None) is None:
|
||||
object_id = entity.name or DEVICE_DEFAULT_NAME
|
||||
|
||||
if self.entity_namespace is not None:
|
||||
object_id = '{} {}'.format(self.entity_namespace,
|
||||
object_id)
|
||||
|
||||
entity.entity_id = async_generate_entity_id(
|
||||
self.component.entity_id_format, object_id,
|
||||
component_entities)
|
||||
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if not valid_entity_id(entity.entity_id):
|
||||
raise HomeAssistantError(
|
||||
'Invalid entity id: {}'.format(entity.entity_id))
|
||||
elif entity.entity_id in component_entities:
|
||||
raise HomeAssistantError(
|
||||
'Entity id already exists: {}'.format(entity.entity_id))
|
||||
|
||||
self.entities[entity.entity_id] = entity
|
||||
component_entities.add(entity.entity_id)
|
||||
|
||||
if hasattr(entity, 'async_added_to_hass'):
|
||||
yield from entity.async_added_to_hass()
|
||||
|
||||
yield from entity.async_update_ha_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_reset(self):
|
||||
"""Remove all entities and reset data.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if not self.platform_entities:
|
||||
if not self.entities:
|
||||
return
|
||||
|
||||
tasks = [entity.async_remove() for entity in self.platform_entities]
|
||||
tasks = [self._async_remove_entity(entity_id)
|
||||
for entity_id in self.entities]
|
||||
|
||||
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
|
||||
|
||||
|
@ -432,6 +446,28 @@ class EntityPlatform(object):
|
|||
self._async_unsub_polling()
|
||||
self._async_unsub_polling = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove_entity(self, entity_id):
|
||||
"""Remove entity id from platform."""
|
||||
yield from self._async_remove_entity(entity_id)
|
||||
|
||||
# Clean up polling job if no longer needed
|
||||
if (self._async_unsub_polling is not None and
|
||||
not any(entity.should_poll for entity
|
||||
in self.entities.values())):
|
||||
self._async_unsub_polling()
|
||||
self._async_unsub_polling = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_remove_entity(self, entity_id):
|
||||
"""Remove entity id from platform."""
|
||||
entity = self.entities.pop(entity_id)
|
||||
|
||||
if hasattr(entity, 'async_will_remove_from_hass'):
|
||||
yield from entity.async_will_remove_from_hass()
|
||||
|
||||
self.component.hass.states.async_remove(entity_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _update_entity_states(self, now):
|
||||
"""Update the states of all the polling entities.
|
||||
|
@ -450,7 +486,7 @@ class EntityPlatform(object):
|
|||
|
||||
with (yield from self._process_updates):
|
||||
tasks = []
|
||||
for entity in self.platform_entities:
|
||||
for entity in self.entities.values():
|
||||
if not entity.should_poll:
|
||||
continue
|
||||
tasks.append(entity.async_update_ha_state(True))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue