"""Tests for the Somfy config flow.""" import asyncio import logging import time from unittest.mock import patch import pytest from homeassistant import config_entries, data_entry_flow, setup from homeassistant.helpers import config_entry_oauth2_flow from tests.common import MockConfigEntry, mock_platform TEST_DOMAIN = "oauth2_test" CLIENT_SECRET = "5678" CLIENT_ID = "1234" REFRESH_TOKEN = "mock-refresh-token" ACCESS_TOKEN_1 = "mock-access-token-1" ACCESS_TOKEN_2 = "mock-access-token-2" AUTHORIZE_URL = "https://example.como/auth/authorize" TOKEN_URL = "https://example.como/auth/token" @pytest.fixture async def local_impl(hass): """Local implementation.""" assert await setup.async_setup_component(hass, "http", {}) return config_entry_oauth2_flow.LocalOAuth2Implementation( hass, TEST_DOMAIN, CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL ) @pytest.fixture def flow_handler(hass): """Return a registered config flow.""" mock_platform(hass, f"{TEST_DOMAIN}.config_flow") class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler): """Test flow handler.""" DOMAIN = TEST_DOMAIN @property def logger(self) -> logging.Logger: """Return logger.""" return logging.getLogger(__name__) @property def extra_authorize_data(self) -> dict: """Extra data that needs to be appended to the authorize url.""" return {"scope": "read write"} with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}): yield TestFlowHandler class MockOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation): """Mock implementation for testing.""" @property def name(self) -> str: """Name of the implementation.""" return "Mock" @property def domain(self) -> str: """Domain that is providing the implementation.""" return "test" async def async_generate_authorize_url(self, flow_id: str) -> str: """Generate a url for the user to authorize.""" return "http://example.com/auth" async def async_resolve_external_data(self, external_data) -> dict: """Resolve external data to tokens.""" return external_data async def _async_refresh_token(self, token: dict) -> dict: """Refresh a token.""" raise NotImplementedError() def test_inherit_enforces_domain_set(): """Test we enforce setting DOMAIN.""" class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler): """Test flow handler.""" @property def logger(self) -> logging.Logger: """Return logger.""" return logging.getLogger(__name__) with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}): with pytest.raises(TypeError): TestFlowHandler() async def test_abort_if_no_implementation(hass, flow_handler): """Check flow abort when no implementations.""" flow = flow_handler() flow.hass = hass result = await flow.async_step_user() assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["reason"] == "missing_configuration" async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl): """Check timeout generating authorization url.""" flow_handler.async_register_implementation(hass, local_impl) flow = flow_handler() flow.hass = hass with patch.object( local_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError ): result = await flow.async_step_user() assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["reason"] == "authorize_url_timeout" async def test_full_flow( hass, flow_handler, local_impl, aiohttp_client, aioclient_mock ): """Check full flow.""" hass.config.api.base_url = "https://example.com" flow_handler.async_register_implementation(hass, local_impl) config_entry_oauth2_flow.async_register_implementation( hass, TEST_DOMAIN, MockOAuth2Implementation() ) result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "pick_implementation" # Pick implementation result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={"implementation": TEST_DOMAIN} ) state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["url"] == ( f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}" "&redirect_uri=https://example.com/auth/external/callback" f"&state={state}&scope=read+write" ) client = await aiohttp_client(hass.http.app) resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") assert resp.status == 200 assert resp.headers["content-type"] == "text/html; charset=utf-8" aioclient_mock.post( TOKEN_URL, json={ "refresh_token": REFRESH_TOKEN, "access_token": ACCESS_TOKEN_1, "type": "bearer", "expires_in": 60, }, ) result = await hass.config_entries.flow.async_configure(result["flow_id"]) assert result["data"]["auth_implementation"] == TEST_DOMAIN result["data"]["token"].pop("expires_at") assert result["data"]["token"] == { "refresh_token": REFRESH_TOKEN, "access_token": ACCESS_TOKEN_1, "type": "bearer", "expires_in": 60, } entry = hass.config_entries.async_entries(TEST_DOMAIN)[0] assert ( await config_entry_oauth2_flow.async_get_config_entry_implementation( hass, entry ) is local_impl ) async def test_local_refresh_token(hass, local_impl, aioclient_mock): """Test we can refresh token.""" aioclient_mock.post( TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100} ) new_tokens = await local_impl.async_refresh_token( { "refresh_token": REFRESH_TOKEN, "access_token": ACCESS_TOKEN_1, "type": "bearer", "expires_in": 60, } ) new_tokens.pop("expires_at") assert new_tokens == { "refresh_token": REFRESH_TOKEN, "access_token": ACCESS_TOKEN_2, "type": "bearer", "expires_in": 100, } assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == { "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET, "grant_type": "refresh_token", "refresh_token": REFRESH_TOKEN, } async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock): """Test the OAuth2 session helper.""" flow_handler.async_register_implementation(hass, local_impl) aioclient_mock.post( TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100} ) aioclient_mock.post("https://example.com", status=201) config_entry = MockConfigEntry( domain=TEST_DOMAIN, data={ "auth_implementation": TEST_DOMAIN, "token": { "refresh_token": REFRESH_TOKEN, "access_token": ACCESS_TOKEN_1, "expires_in": 10, "expires_at": 0, # Forces a refresh, "token_type": "bearer", "random_other_data": "should_stay", }, }, ) now = time.time() session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl) resp = await session.async_request("post", "https://example.com") assert resp.status == 201 # Refresh token, make request assert len(aioclient_mock.mock_calls) == 2 assert ( aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}" ) assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2 assert config_entry.data["token"]["expires_in"] == 100 assert config_entry.data["token"]["random_other_data"] == "should_stay" assert round(config_entry.data["token"]["expires_at"] - now) == 100 async def test_implementation_provider(hass, local_impl): """Test providing an implementation provider.""" assert ( await config_entry_oauth2_flow.async_get_implementations(hass, TEST_DOMAIN) == {} ) mock_domain_with_impl = "some_domain" config_entry_oauth2_flow.async_register_implementation( hass, mock_domain_with_impl, local_impl ) assert await config_entry_oauth2_flow.async_get_implementations( hass, mock_domain_with_impl ) == {TEST_DOMAIN: local_impl} provider_source = {} async def async_provide_implementation(hass, domain): """Mock implementation provider.""" return provider_source.get(domain) config_entry_oauth2_flow.async_add_implementation_provider( hass, "cloud", async_provide_implementation ) assert await config_entry_oauth2_flow.async_get_implementations( hass, mock_domain_with_impl ) == {TEST_DOMAIN: local_impl} provider_source[ mock_domain_with_impl ] = config_entry_oauth2_flow.LocalOAuth2Implementation( hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL ) assert await config_entry_oauth2_flow.async_get_implementations( hass, mock_domain_with_impl ) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]}