From d265160f48a8838081606958c880c5aeedd8f314 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 1 Sep 2020 09:01:43 -0700 Subject: [PATCH] Add redirect_uri argument to InteractiveBrowserCredential --- sdk/identity/azure-identity/CHANGELOG.md | 3 + .../azure/identity/_credentials/browser.py | 33 +++++--- .../_internal/auth_code_redirect_handler.py | 24 +++--- .../tests/test_browser_credential.py | 83 +++++++++++-------- 4 files changed, 84 insertions(+), 59 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 1de29521c222..0a87d5108a23 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -13,6 +13,9 @@ (`azure.identity.aio.CertificateCredential`) will support this in a future version. ([#10816](https://github.com/Azure/azure-sdk-for-python/issues/10816)) +- `InteractiveBrowserCredential` keyword argument `redirect_uri` enables + authentication with a user-specified application having a custom redirect URI + ([#13344](https://github.com/Azure/azure-sdk-for-python/issues/13344)) ## 1.4.0 (2020-08-10) ### Added diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py index cf860f5b39f1..751e30fca141 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py @@ -36,6 +36,10 @@ class InteractiveBrowserCredential(InteractiveCredential): authenticate work or school accounts. :keyword str client_id: Client ID of the Azure Active Directory application users will sign in to. If unspecified, the Azure CLI's ID will be used. + :keyword str redirect_uri: a redirect URI for the application identified by `client_id` as configured in Azure + Active Directory, for example "http://localhost:8400". This is only required when passing a value for + `client_id`, and must match a redirect URI in the application's registration. The credential must be able to + bind a socket to this URI. :keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate` :keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise :class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False. @@ -48,8 +52,9 @@ class InteractiveBrowserCredential(InteractiveCredential): def __init__(self, **kwargs): # type: (**Any) -> None + self._redirect_uri = kwargs.pop("redirect_uri", None) self._timeout = kwargs.pop("timeout", 300) - self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking + self._server_class = kwargs.pop("_server_class", AuthCodeRedirectServer) client_id = kwargs.pop("client_id", AZURE_CLI_CLIENT_ID) super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs) @@ -57,17 +62,24 @@ def __init__(self, **kwargs): def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> dict - # start an HTTP server on localhost to receive the redirect - redirect_uri = None - for port in range(8400, 9000): + # start an HTTP server to receive the redirect + server = None + redirect_uri = self._redirect_uri + if redirect_uri: try: - server = self._server_class(port, timeout=self._timeout) - redirect_uri = "http://localhost:{}".format(port) - break + server = self._server_class(redirect_uri, timeout=self._timeout) except socket.error: - continue # keep looking for an open port - - if not redirect_uri: + raise CredentialUnavailableError(message="Couldn't start an HTTP server on " + redirect_uri) + else: + for port in range(8400, 9000): + try: + redirect_uri = "http://localhost:{}".format(port) + server = self._server_class(redirect_uri, timeout=self._timeout) + break + except socket.error: + continue # keep looking for an open port + + if not server: raise CredentialUnavailableError(message="Couldn't start an HTTP server on localhost") # get the url the user must visit to authenticate @@ -93,7 +105,6 @@ def _request_token(self, *scopes, **kwargs): code = self._parse_response(request_state, response) return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs) - @staticmethod def _parse_response(request_state, response): # type: (str, Mapping[str, Any]) -> List[str] diff --git a/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py b/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py index 9463e3c6412b..67906fcc568b 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py @@ -2,23 +2,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -try: - from typing import TYPE_CHECKING -except ImportError: - TYPE_CHECKING = False +from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Any, Mapping, Optional +from six.moves.urllib_parse import parse_qs, urlparse try: from http.server import HTTPServer, BaseHTTPRequestHandler except ImportError: from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler # type: ignore -try: - from urllib.parse import parse_qs -except ImportError: - from urlparse import parse_qs # type: ignore +if TYPE_CHECKING: + # pylint:disable=ungrouped-imports + from typing import Any, Mapping class AuthCodeRedirectHandler(BaseHTTPRequestHandler): @@ -46,13 +41,14 @@ def log_message(self, format, *args): # pylint: disable=redefined-builtin,unuse class AuthCodeRedirectServer(HTTPServer): - """HTTP server that listens on localhost for the redirect request following an authorization code authentication""" + """HTTP server that listens for the redirect request following an authorization code authentication""" query_params = {} # type: Mapping[str, Any] - def __init__(self, port, timeout): - # type: (int, int) -> None - HTTPServer.__init__(self, ("localhost", port), AuthCodeRedirectHandler) + def __init__(self, uri, timeout): + # type: (str, int) -> None + parsed = urlparse(uri) + HTTPServer.__init__(self, (parsed.hostname, parsed.port), AuthCodeRedirectHandler) self.timeout = timeout def wait_for_redirect(self): diff --git a/sdk/identity/azure-identity/tests/test_browser_credential.py b/sdk/identity/azure-identity/tests/test_browser_credential.py index 25ea77f71b66..cedcdfa61474 100644 --- a/sdk/identity/azure-identity/tests/test_browser_credential.py +++ b/sdk/identity/azure-identity/tests/test_browser_credential.py @@ -22,14 +22,15 @@ build_id_token, get_discovery_response, mock_response, + msal_validating_transport, Request, validating_transport, ) try: - from unittest.mock import Mock, patch + from unittest.mock import ANY, Mock, patch except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore + from mock import ANY, Mock, patch # type: ignore WEBBROWSER_OPEN = InteractiveBrowserCredential.__module__ + ".webbrowser.open" @@ -77,7 +78,7 @@ def test_authenticate(): _cache=TokenCache(), authority=environment, client_id=client_id, - server_class=server_class, + _server_class=server_class, tenant_id=tenant_id, transport=transport, ) @@ -126,7 +127,7 @@ def test_policies_configurable(): server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) credential = InteractiveBrowserCredential( - policies=[policy], client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache() + policies=[policy], client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache() ) with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): @@ -152,7 +153,7 @@ def test_user_agent(): server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) credential = InteractiveBrowserCredential( - client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache() + client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache() ) with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): @@ -160,7 +161,8 @@ def test_user_agent(): @patch("azure.identity._credentials.browser.webbrowser.open") -def test_interactive_credential(mock_open): +@pytest.mark.parametrize("redirect_url", ("https://localhost:8042", None)) +def test_interactive_credential(mock_open, redirect_url): mock_open.side_effect = _validate_auth_request_url oauth_state = "state" client_id = "client-id" @@ -171,17 +173,15 @@ def test_interactive_credential(mock_open): tenant_id = "tenant_id" endpoint = "https://{}/{}".format(authority, tenant_id) - discovery_response = get_discovery_response(endpoint=endpoint) - transport = validating_transport( - requests=[Request(url_substring=endpoint)] * 3 + transport = msal_validating_transport( + endpoint="https://{}/{}".format(authority, tenant_id), + requests=[Request(url_substring=endpoint)] + [ Request( authority=authority, url_substring=endpoint, required_data={"refresh_token": expected_refresh_token} ) ], responses=[ - discovery_response, # instance discovery - discovery_response, # tenant discovery mock_response( json_payload=build_aad_response( access_token=expected_token, @@ -203,16 +203,18 @@ def test_interactive_credential(mock_open): auth_code_response = {"code": "authorization-code", "state": [oauth_state]} server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) - credential = InteractiveBrowserCredential( - authority=authority, - tenant_id=tenant_id, - client_id=client_id, - server_class=server_class, - transport=transport, - instance_discovery=False, - validate_authority=False, - _cache=TokenCache(), - ) + args = { + "authority": authority, + "tenant_id": tenant_id, + "client_id": client_id, + "transport": transport, + "_cache": TokenCache(), + "_server_class": server_class, + } + if redirect_url: # avoid passing redirect_url=None + args["redirect_uri"] = redirect_url + + credential = InteractiveBrowserCredential(**args) # The credential's auth code request includes a uuid which must be included in the redirect. Patching to # set the uuid requires less code here than a proper mock server. @@ -220,20 +222,19 @@ def test_interactive_credential(mock_open): token = credential.get_token("scope") assert token.token == expected_token assert mock_open.call_count == 1 + assert server_class.call_count == 1 + + if redirect_url: + server_class.assert_called_once_with(redirect_url, timeout=ANY) # token should be cached, get_token shouldn't prompt again token = credential.get_token("scope") assert token.token == expected_token assert mock_open.call_count == 1 - - # As of MSAL 1.0.0, applications build a new client every time they redeem a refresh token. - # Here we patch the private method they use for the sake of test coverage. - # TODO: this will probably break when this MSAL behavior changes - app = credential._get_app() - app._build_client = lambda *_: app.client # pylint:disable=protected-access - now = time.time() + assert server_class.call_count == 1 # expired access token -> credential should use refresh token instead of prompting again + now = time.time() with patch("time.time", lambda: now + expires_in): token = credential.get_token("scope") assert token.token == expected_token @@ -259,7 +260,7 @@ def test_interactive_credential_timeout(): credential = InteractiveBrowserCredential( client_id="guid", - server_class=server_class, + _server_class=server_class, timeout=timeout, transport=transport, instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request @@ -277,7 +278,8 @@ def test_redirect_server(): for _ in range(4): try: port = random.randint(1024, 65535) - server = AuthCodeRedirectServer(port, timeout=10) + url = "http://127.0.0.1:{}".format(port) + server = AuthCodeRedirectServer(url, timeout=10) break except socket.error: continue # keep looking for an open port @@ -293,8 +295,7 @@ def test_redirect_server(): thread.start() # send a request, verify the server exposes the query - url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec - response = urllib.request.urlopen(url) # nosec + response = urllib.request.urlopen(url + "?{}={}".format(expected_param, expected_value)) # nosec assert response.code == 200 assert server.query_params[expected_param] == [expected_value] @@ -304,7 +305,7 @@ def test_redirect_server(): def test_no_browser(): transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2) credential = InteractiveBrowserCredential( - client_id="client-id", server_class=Mock(), transport=transport, _cache=TokenCache() + client_id="client-id", _server_class=Mock(), transport=transport, _cache=TokenCache() ) with pytest.raises(ClientAuthenticationError, match=r".*browser.*"): credential.get_token("scope") @@ -313,11 +314,25 @@ def test_no_browser(): def test_cannot_bind_port(): """get_token should raise CredentialUnavailableError when the redirect listener can't bind a port""" - credential = InteractiveBrowserCredential(server_class=Mock(side_effect=socket.error)) + credential = InteractiveBrowserCredential(_server_class=Mock(side_effect=socket.error)) with pytest.raises(CredentialUnavailableError): credential.get_token("scope") +def test_cannot_bind_redirect_uri(): + """When a user specifies a redirect URI, the credential shouldn't attempt to bind another""" + + expected_uri = "http://localhost:42" + + server = Mock(side_effect=socket.error) + credential = InteractiveBrowserCredential(redirect_uri=expected_uri, _server_class=server) + + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") + + server.assert_called_once_with(expected_uri, timeout=ANY) + + def _validate_auth_request_url(url): parsed_url = urllib_parse.urlparse(url) params = urllib_parse.parse_qs(parsed_url.query)