Introduce new OAuth2 config flow helper (#27727)
* Refactor the Somfy auth implementation * Typing * Migrate Somfy to OAuth2 flow helper * Add tests * Add more tests * Fix tests * Fix type error * More tests * Remove side effect from constructor * implementation -> auth_implementation * Make get_implementation async * Minor cleanup + Allow picking implementations. * Add support for extra authorize data
This commit is contained in:
parent
6157be23dc
commit
b6c26cb363
15 changed files with 900 additions and 214 deletions
420
homeassistant/helpers/config_entry_oauth2_flow.py
Normal file
420
homeassistant/helpers/config_entry_oauth2_flow.py
Normal file
|
@ -0,0 +1,420 @@
|
|||
"""Config Flow using OAuth2.
|
||||
|
||||
This module exists of the following parts:
|
||||
- OAuth2 config flow which supports multiple OAuth2 implementations
|
||||
- OAuth2 implementation that works with local provided client ID/secret
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
from abc import ABCMeta, ABC, abstractmethod
|
||||
import logging
|
||||
from typing import Optional, Any, Dict, cast
|
||||
import time
|
||||
|
||||
import async_timeout
|
||||
from aiohttp import web, client
|
||||
import jwt
|
||||
import voluptuous as vol
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth.util import generate_secret
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
|
||||
from .aiohttp_client import async_get_clientsession
|
||||
|
||||
|
||||
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
||||
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
||||
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||
|
||||
|
||||
class AbstractOAuth2Implementation(ABC):
|
||||
"""Base class to abstract OAuth2 authentication."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Name of the implementation."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def domain(self) -> str:
|
||||
"""Domain that is providing the implementation."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize.
|
||||
|
||||
This step is called when a config flow is initialized. It should redirect the
|
||||
user to the vendor website where they can authorize Home Assistant.
|
||||
|
||||
The implementation is responsible to get notified when the user is authorized
|
||||
and pass this to the specified config flow. Do as little work as possible once
|
||||
notified. You can do the work inside async_resolve_external_data. This will
|
||||
give the best UX.
|
||||
|
||||
Pass external data in with:
|
||||
|
||||
```python
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flow_id=flow_id, user_input=external_data
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||
"""Resolve external data to tokens.
|
||||
|
||||
Turn the data that the implementation passed to the config flow as external
|
||||
step data into tokens. These tokens will be stored as 'token' in the
|
||||
config entry data.
|
||||
"""
|
||||
|
||||
async def async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh a token and update expires info."""
|
||||
new_token = await self._async_refresh_token(token)
|
||||
new_token["expires_at"] = time.time() + new_token["expires_in"]
|
||||
return new_token
|
||||
|
||||
@abstractmethod
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh a token."""
|
||||
|
||||
|
||||
class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
"""Local OAuth2 implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
):
|
||||
"""Initialize local auth implementation."""
|
||||
self.hass = hass
|
||||
self._domain = domain
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Name of the implementation."""
|
||||
return "Configuration.yaml"
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
"""Domain providing the implementation."""
|
||||
return self._domain
|
||||
|
||||
@property
|
||||
def redirect_uri(self) -> str:
|
||||
"""Return the redirect uri."""
|
||||
return f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}" # type: ignore
|
||||
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize."""
|
||||
return str(
|
||||
URL(self.authorize_url).with_query(
|
||||
{
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": _encode_jwt(self.hass, {"flow_id": flow_id}),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||
"""Resolve the authorization code to tokens."""
|
||||
return await self._token_request(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": external_data,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
)
|
||||
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh tokens."""
|
||||
new_token = await self._token_request(
|
||||
{
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": self.client_id,
|
||||
"refresh_token": token["refresh_token"],
|
||||
}
|
||||
)
|
||||
return {**token, **new_token}
|
||||
|
||||
async def _token_request(self, data: dict) -> dict:
|
||||
"""Make a token request."""
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
data["client_id"] = self.client_id
|
||||
|
||||
if self.client_secret is not None:
|
||||
data["client_secret"] = self.client_secret
|
||||
|
||||
resp = await session.post(self.token_url, data=data)
|
||||
resp.raise_for_status()
|
||||
return cast(dict, await resp.json())
|
||||
|
||||
|
||||
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
||||
"""Handle a config flow."""
|
||||
|
||||
DOMAIN = ""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_UNKNOWN
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Instantiate config flow."""
|
||||
if self.DOMAIN == "":
|
||||
raise TypeError(
|
||||
f"Can't instantiate class {self.__class__.__name__} without DOMAIN being set"
|
||||
)
|
||||
|
||||
self.external_data: Any = None
|
||||
self.flow_impl: AbstractOAuth2Implementation = None # type: ignore
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
|
||||
@property
|
||||
def extra_authorize_data(self) -> dict:
|
||||
"""Extra data that needs to be appended to the authorize url."""
|
||||
return {}
|
||||
|
||||
async def async_step_pick_implementation(self, user_input: dict = None) -> dict:
|
||||
"""Handle a flow start."""
|
||||
assert self.hass
|
||||
implementations = await async_get_implementations(self.hass, self.DOMAIN)
|
||||
|
||||
if user_input is not None:
|
||||
self.flow_impl = implementations[user_input["implementation"]]
|
||||
return await self.async_step_auth()
|
||||
|
||||
if not implementations:
|
||||
return self.async_abort(reason="missing_configuration")
|
||||
|
||||
if len(implementations) == 1:
|
||||
# Pick first implementation as we have only one.
|
||||
self.flow_impl = list(implementations.values())[0]
|
||||
return await self.async_step_auth()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="pick_implementation",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required(
|
||||
"implementation", default=list(implementations.keys())[0]
|
||||
): vol.In({key: impl.name for key, impl in implementations.items()})
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
async def async_step_auth(self, user_input: dict = None) -> dict:
|
||||
"""Create an entry for auth."""
|
||||
# Flow has been triggered by external data
|
||||
if user_input:
|
||||
self.external_data = user_input
|
||||
return self.async_external_step_done(next_step_id="creation")
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(10):
|
||||
url = await self.flow_impl.async_generate_authorize_url(self.flow_id)
|
||||
except asyncio.TimeoutError:
|
||||
return self.async_abort(reason="authorize_url_timeout")
|
||||
|
||||
url = str(URL(url).update_query(self.extra_authorize_data))
|
||||
|
||||
return self.async_external_step(step_id="auth", url=url)
|
||||
|
||||
async def async_step_creation(self, user_input: dict = None) -> dict:
|
||||
"""Create config entry from external data."""
|
||||
token = await self.flow_impl.async_resolve_external_data(self.external_data)
|
||||
token["expires_at"] = time.time() + token["expires_in"]
|
||||
|
||||
self.logger.info("Successfully authenticated")
|
||||
|
||||
return await self.async_oauth_create_entry(
|
||||
{"auth_implementation": self.flow_impl.domain, "token": token}
|
||||
)
|
||||
|
||||
async def async_oauth_create_entry(self, data: dict) -> dict:
|
||||
"""Create an entry for the flow.
|
||||
|
||||
Ok to override if you want to fetch extra info or even add another step.
|
||||
"""
|
||||
return self.async_create_entry(title=self.flow_impl.name, data=data)
|
||||
|
||||
async_step_user = async_step_pick_implementation
|
||||
async_step_ssdp = async_step_pick_implementation
|
||||
async_step_zeroconf = async_step_pick_implementation
|
||||
async_step_homekit = async_step_pick_implementation
|
||||
|
||||
@classmethod
|
||||
def async_register_implementation(
|
||||
cls, hass: HomeAssistant, local_impl: LocalOAuth2Implementation
|
||||
) -> None:
|
||||
"""Register a local implementation."""
|
||||
async_register_implementation(hass, cls.DOMAIN, local_impl)
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_implementation(
|
||||
hass: HomeAssistant, domain: str, implementation: AbstractOAuth2Implementation
|
||||
) -> None:
|
||||
"""Register an OAuth2 flow implementation for an integration."""
|
||||
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
|
||||
DATA_VIEW_REGISTERED, False
|
||||
):
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
|
||||
hass.data[DATA_VIEW_REGISTERED] = True
|
||||
|
||||
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
|
||||
implementations.setdefault(domain, {})[implementation.domain] = implementation
|
||||
|
||||
|
||||
async def async_get_implementations(
|
||||
hass: HomeAssistant, domain: str
|
||||
) -> Dict[str, AbstractOAuth2Implementation]:
|
||||
"""Return OAuth2 implementations for specified domain."""
|
||||
return cast(
|
||||
Dict[str, AbstractOAuth2Implementation],
|
||||
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
|
||||
)
|
||||
|
||||
|
||||
async def async_get_config_entry_implementation(
|
||||
hass: HomeAssistant, config_entry: config_entries.ConfigEntry
|
||||
) -> AbstractOAuth2Implementation:
|
||||
"""Return the implementation for this config entry."""
|
||||
implementations = await async_get_implementations(hass, config_entry.domain)
|
||||
implementation = implementations.get(config_entry.data["auth_implementation"])
|
||||
|
||||
if implementation is None:
|
||||
raise ValueError("Implementation not available")
|
||||
|
||||
return implementation
|
||||
|
||||
|
||||
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||
"""OAuth2 Authorization Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = AUTH_CALLBACK_PATH
|
||||
name = "auth:external:callback"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive authorization code."""
|
||||
if "code" not in request.query or "state" not in request.query:
|
||||
return web.Response(
|
||||
text=f"Missing code or state parameter in {request.url}"
|
||||
)
|
||||
|
||||
hass = request.app["hass"]
|
||||
|
||||
state = _decode_jwt(hass, request.query["state"])
|
||||
|
||||
if state is None:
|
||||
return web.Response(text=f"Invalid state")
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flow_id=state["flow_id"], user_input=request.query["code"]
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<script>window.close()</script>",
|
||||
)
|
||||
|
||||
|
||||
class OAuth2Session:
|
||||
"""Session to make requests authenticated with OAuth2."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: config_entries.ConfigEntry,
|
||||
implementation: AbstractOAuth2Implementation,
|
||||
):
|
||||
"""Initialize an OAuth2 session."""
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.implementation = implementation
|
||||
|
||||
async def async_ensure_token_valid(self) -> None:
|
||||
"""Ensure that the current token is valid."""
|
||||
token = self.config_entry.data["token"]
|
||||
|
||||
if token["expires_at"] > time.time():
|
||||
return
|
||||
|
||||
new_token = await self.implementation.async_refresh_token(token)
|
||||
|
||||
self.hass.config_entries.async_update_entry( # type: ignore
|
||||
self.config_entry, data={**self.config_entry.data, "token": new_token}
|
||||
)
|
||||
|
||||
async def async_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> client.ClientResponse:
|
||||
"""Make a request."""
|
||||
await self.async_ensure_token_valid()
|
||||
return await async_oauth2_request(
|
||||
self.hass, self.config_entry.data["token"], method, url, **kwargs
|
||||
)
|
||||
|
||||
|
||||
async def async_oauth2_request(
|
||||
hass: HomeAssistant, token: dict, method: str, url: str, **kwargs: Any
|
||||
) -> client.ClientResponse:
|
||||
"""Make an OAuth2 authenticated request.
|
||||
|
||||
This method will not refresh tokens. Use OAuth2 session for that.
|
||||
"""
|
||||
session = async_get_clientsession(hass)
|
||||
|
||||
return await session.request(
|
||||
method,
|
||||
url,
|
||||
**kwargs,
|
||||
headers={
|
||||
**kwargs.get("headers", {}),
|
||||
"authorization": f"Bearer {token['access_token']}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
|
||||
"""JWT encode data."""
|
||||
secret = hass.data.get(DATA_JWT_SECRET)
|
||||
|
||||
if secret is None:
|
||||
secret = hass.data[DATA_JWT_SECRET] = generate_secret()
|
||||
|
||||
return jwt.encode(data, secret, algorithm="HS256").decode()
|
||||
|
||||
|
||||
@callback
|
||||
def _decode_jwt(hass: HomeAssistant, encoded: str) -> Optional[dict]:
|
||||
"""JWT encode data."""
|
||||
secret = cast(str, hass.data.get(DATA_JWT_SECRET))
|
||||
|
||||
try:
|
||||
return jwt.decode(encoded, secret, algorithms=["HS256"])
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue