120 lines
4.1 KiB
Python
120 lines
4.1 KiB
Python
"""Helpers to implement include/exclude filtering of entities and domains."""
|
|
from typing import Callable, Dict, List
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.core import split_entity_id
|
|
from homeassistant.helpers import config_validation as cv
|
|
|
|
# Keys of the filter configuration mapping. FILTER_SCHEMA accepts each of
# them as an optional list, and _convert_filter reads all four from the
# validated mapping.
CONF_INCLUDE_DOMAINS = "include_domains"
CONF_INCLUDE_ENTITIES = "include_entities"
CONF_EXCLUDE_DOMAINS = "exclude_domains"
CONF_EXCLUDE_ENTITIES = "exclude_entities"
|
|
|
|
|
|
def _convert_filter(config: Dict[str, List[str]]) -> Callable[[str], bool]:
    """Turn a filter config mapping into an entity filter callable.

    The four include/exclude lists are read from ``config`` and handed to
    ``generate_filter``. Two attributes are then attached to the returned
    callable:

    * ``config`` - the mapping that was passed in.
    * ``empty_filter`` - True when the lengths of all values in ``config``
      add up to zero.
    """
    entity_filter = generate_filter(
        include_domains=config[CONF_INCLUDE_DOMAINS],
        include_entities=config[CONF_INCLUDE_ENTITIES],
        exclude_domains=config[CONF_EXCLUDE_DOMAINS],
        exclude_entities=config[CONF_EXCLUDE_ENTITIES],
    )

    # Attributes are stored on the function object itself so callers can
    # inspect the filter they were given.
    setattr(entity_filter, "config", config)
    total_entries = sum(map(len, config.values()))
    setattr(entity_filter, "empty_filter", total_entries == 0)

    return entity_filter
|
|
|
|
|
|
# Validation schema for an include/exclude filter configuration.
#
# Every key is optional and defaults to an empty list. Domain lists are
# passed through cv.ensure_list and then each item must satisfy cv.string;
# entity lists are validated by cv.entity_ids. The outer vol.All then feeds
# the validated mapping to _convert_filter, so validating with this schema
# yields the filter callable rather than the mapping itself.
FILTER_SCHEMA = vol.All(
    vol.Schema(
        {
            vol.Optional(CONF_EXCLUDE_DOMAINS, default=[]): vol.All(
                cv.ensure_list, [cv.string]
            ),
            vol.Optional(CONF_EXCLUDE_ENTITIES, default=[]): cv.entity_ids,
            vol.Optional(CONF_INCLUDE_DOMAINS, default=[]): vol.All(
                cv.ensure_list, [cv.string]
            ),
            vol.Optional(CONF_INCLUDE_ENTITIES, default=[]): cv.entity_ids,
        }
    ),
    # Final step: convert the validated mapping into a filter callable.
    _convert_filter,
)
|
|
|
|
|
|
def generate_filter(
    include_domains: List[str],
    include_entities: List[str],
    exclude_domains: List[str],
    exclude_entities: List[str],
) -> Callable[[str], bool]:
    """Build a predicate that decides whether an entity id passes the filter.

    The four lists are converted to sets once. The returned callable takes
    an entity id and returns True when that entity passes. Where a domain is
    needed it is the first element of ``split_entity_id(entity_id)``.

    Which rule set applies depends on which lists are non-empty:

    1. Nothing configured: every entity passes.
    2. Only includes: pass when the entity or its domain is included.
    3. Only excludes: pass unless the entity or its domain is excluded.
    4. Both includes and excludes:

       a. Include domains given: entities of an included domain pass unless
          individually excluded; all other entities pass only when
          individually included. Exclude domains are not consulted.
       b. No include domains, but exclude domains given: entities of an
          excluded domain pass only when individually included; all other
          entities pass unless individually excluded.
       c. No domains on either side: pass only when individually included;
          entity excludes are not consulted.
    """
    allowed_domains = set(include_domains)
    allowed_entities = set(include_entities)
    blocked_domains = set(exclude_domains)
    blocked_entities = set(exclude_entities)

    any_excludes = bool(blocked_entities or blocked_domains)
    any_includes = bool(allowed_entities or allowed_domains)

    # Case 1 - nothing configured, so nothing is filtered out.
    if not (any_includes or any_excludes):
        return lambda entity_id: True

    # Case 2 - includes only (case 1 already ruled out "neither").
    if not any_excludes:

        def entity_filter_2(entity_id: str) -> bool:
            """Pass entities that are listed or whose domain is listed."""
            domain = split_entity_id(entity_id)[0]
            if entity_id in allowed_entities:
                return True
            return domain in allowed_domains

        return entity_filter_2

    # Case 3 - excludes only.
    if not any_includes:

        def entity_filter_3(entity_id: str) -> bool:
            """Pass entities unless they or their domain are excluded."""
            domain = split_entity_id(entity_id)[0]
            return not (entity_id in blocked_entities or domain in blocked_domains)

        return entity_filter_3

    # Case 4 - from here on both includes and excludes are present.

    # Case 4a - include domains take precedence; exclude domains are ignored.
    if allowed_domains:

        def entity_filter_4a(entity_id: str) -> bool:
            """Apply entity excludes inside included domains, includes outside."""
            domain = split_entity_id(entity_id)[0]
            return (
                entity_id not in blocked_entities
                if domain in allowed_domains
                else entity_id in allowed_entities
            )

        return entity_filter_4a

    # Case 4b - no include domains, but exclude domains are present.
    if blocked_domains:

        def entity_filter_4b(entity_id: str) -> bool:
            """Apply entity includes inside excluded domains, excludes outside."""
            domain = split_entity_id(entity_id)[0]
            return (
                entity_id in allowed_entities
                if domain in blocked_domains
                else entity_id not in blocked_entities
            )

        return entity_filter_4b

    # Case 4c - neither include nor exclude domains: only the entity include
    # list matters, entity excludes are ignored.
    def entity_filter_4c(entity_id: str) -> bool:
        """Pass only entities that are individually included."""
        return entity_id in allowed_entities

    return entity_filter_4c
|