Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add redirect_uri argument to InteractiveBrowserCredential #13480

Merged
merged 1 commit into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
(`azure.identity.aio.CertificateCredential`) will support this in a
future version.
([#10816](https://github.com/Azure/azure-sdk-for-python/issues/10816))
- `InteractiveBrowserCredential` keyword argument `redirect_uri` enables
authentication with a user-specified application having a custom redirect URI
([#13344](https://github.com/Azure/azure-sdk-for-python/issues/13344))

## 1.4.0 (2020-08-10)
### Added
Expand Down
33 changes: 22 additions & 11 deletions sdk/identity/azure-identity/azure/identity/_credentials/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class InteractiveBrowserCredential(InteractiveCredential):
authenticate work or school accounts.
:keyword str client_id: Client ID of the Azure Active Directory application users will sign in to. If
unspecified, the Azure CLI's ID will be used.
:keyword str redirect_uri: a redirect URI for the application identified by `client_id` as configured in Azure
Active Directory, for example "http://localhost:8400". This is only required when passing a value for
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is a redirect_uri actually required when a client_id is provided? Or would the behavior just be the same as before this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's required if the application hasn't registered "http://localhost" or "http://localhost:8400" as a redirect URI.

`client_id`, and must match a redirect URI in the application's registration. The credential must be able to
bind a socket to this URI.
:keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate`
:keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise
:class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False.
Expand All @@ -48,26 +52,34 @@ class InteractiveBrowserCredential(InteractiveCredential):

def __init__(self, **kwargs):
# type: (**Any) -> None
self._redirect_uri = kwargs.pop("redirect_uri", None)
self._timeout = kwargs.pop("timeout", 300)
self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking
self._server_class = kwargs.pop("_server_class", AuthCodeRedirectServer)
client_id = kwargs.pop("client_id", AZURE_CLI_CLIENT_ID)
super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs)

@wrap_exceptions
def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict

# start an HTTP server on localhost to receive the redirect
redirect_uri = None
for port in range(8400, 9000):
# start an HTTP server to receive the redirect
server = None
redirect_uri = self._redirect_uri
if redirect_uri:
try:
server = self._server_class(port, timeout=self._timeout)
redirect_uri = "http://localhost:{}".format(port)
break
server = self._server_class(redirect_uri, timeout=self._timeout)
except socket.error:
continue # keep looking for an open port

if not redirect_uri:
raise CredentialUnavailableError(message="Couldn't start an HTTP server on " + redirect_uri)
else:
for port in range(8400, 9000):
try:
redirect_uri = "http://localhost:{}".format(port)
server = self._server_class(redirect_uri, timeout=self._timeout)
break
except socket.error:
continue # keep looking for an open port

if not server:
raise CredentialUnavailableError(message="Couldn't start an HTTP server on localhost")

# get the url the user must visit to authenticate
Expand All @@ -93,7 +105,6 @@ def _request_token(self, *scopes, **kwargs):
code = self._parse_response(request_state, response)
return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs)


@staticmethod
def _parse_response(request_state, response):
# type: (str, Mapping[str, Any]) -> List[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Mapping, Optional
from six.moves.urllib_parse import parse_qs, urlparse

try:
from http.server import HTTPServer, BaseHTTPRequestHandler
except ImportError:
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler # type: ignore

try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs # type: ignore
if TYPE_CHECKING:
# pylint:disable=ungrouped-imports
from typing import Any, Mapping


class AuthCodeRedirectHandler(BaseHTTPRequestHandler):
Expand Down Expand Up @@ -46,13 +41,14 @@ def log_message(self, format, *args): # pylint: disable=redefined-builtin,unuse


class AuthCodeRedirectServer(HTTPServer):
"""HTTP server that listens on localhost for the redirect request following an authorization code authentication"""
"""HTTP server that listens for the redirect request following an authorization code authentication"""

query_params = {} # type: Mapping[str, Any]

def __init__(self, port, timeout):
# type: (int, int) -> None
HTTPServer.__init__(self, ("localhost", port), AuthCodeRedirectHandler)
def __init__(self, uri, timeout):
# type: (str, int) -> None
parsed = urlparse(uri)
HTTPServer.__init__(self, (parsed.hostname, parsed.port), AuthCodeRedirectHandler)
self.timeout = timeout

def wait_for_redirect(self):
Expand Down
83 changes: 49 additions & 34 deletions sdk/identity/azure-identity/tests/test_browser_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
build_id_token,
get_discovery_response,
mock_response,
msal_validating_transport,
Request,
validating_transport,
)

try:
from unittest.mock import Mock, patch
from unittest.mock import ANY, Mock, patch
except ImportError: # python < 3.3
from mock import Mock, patch # type: ignore
from mock import ANY, Mock, patch # type: ignore


WEBBROWSER_OPEN = InteractiveBrowserCredential.__module__ + ".webbrowser.open"
Expand Down Expand Up @@ -77,7 +78,7 @@ def test_authenticate():
_cache=TokenCache(),
authority=environment,
client_id=client_id,
server_class=server_class,
_server_class=server_class,
tenant_id=tenant_id,
transport=transport,
)
Expand Down Expand Up @@ -126,7 +127,7 @@ def test_policies_configurable():
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
policies=[policy], client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache()
policies=[policy], client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache()
)

with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
Expand All @@ -152,15 +153,16 @@ def test_user_agent():
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache()
client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache()
)

with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
credential.get_token("scope")


@patch("azure.identity._credentials.browser.webbrowser.open")
def test_interactive_credential(mock_open):
@pytest.mark.parametrize("redirect_url", ("https://localhost:8042", None))
def test_interactive_credential(mock_open, redirect_url):
mock_open.side_effect = _validate_auth_request_url
oauth_state = "state"
client_id = "client-id"
Expand All @@ -171,17 +173,15 @@ def test_interactive_credential(mock_open):
tenant_id = "tenant_id"
endpoint = "https://{}/{}".format(authority, tenant_id)

discovery_response = get_discovery_response(endpoint=endpoint)
transport = validating_transport(
requests=[Request(url_substring=endpoint)] * 3
transport = msal_validating_transport(
endpoint="https://{}/{}".format(authority, tenant_id),
requests=[Request(url_substring=endpoint)]
+ [
Request(
authority=authority, url_substring=endpoint, required_data={"refresh_token": expected_refresh_token}
)
],
responses=[
discovery_response, # instance discovery
discovery_response, # tenant discovery
mock_response(
json_payload=build_aad_response(
access_token=expected_token,
Expand All @@ -203,37 +203,38 @@ def test_interactive_credential(mock_open):
auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
authority=authority,
tenant_id=tenant_id,
client_id=client_id,
server_class=server_class,
transport=transport,
instance_discovery=False,
validate_authority=False,
_cache=TokenCache(),
)
args = {
"authority": authority,
"tenant_id": tenant_id,
"client_id": client_id,
"transport": transport,
"_cache": TokenCache(),
"_server_class": server_class,
}
if redirect_url: # avoid passing redirect_url=None
args["redirect_uri"] = redirect_url

credential = InteractiveBrowserCredential(**args)

# The credential's auth code request includes a uuid which must be included in the redirect. Patching to
# set the uuid requires less code here than a proper mock server.
with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
token = credential.get_token("scope")
assert token.token == expected_token
assert mock_open.call_count == 1
assert server_class.call_count == 1

if redirect_url:
server_class.assert_called_once_with(redirect_url, timeout=ANY)

# token should be cached, get_token shouldn't prompt again
token = credential.get_token("scope")
assert token.token == expected_token
assert mock_open.call_count == 1

# As of MSAL 1.0.0, applications build a new client every time they redeem a refresh token.
# Here we patch the private method they use for the sake of test coverage.
# TODO: this will probably break when this MSAL behavior changes
app = credential._get_app()
app._build_client = lambda *_: app.client # pylint:disable=protected-access
now = time.time()
assert server_class.call_count == 1

# expired access token -> credential should use refresh token instead of prompting again
now = time.time()
with patch("time.time", lambda: now + expires_in):
token = credential.get_token("scope")
assert token.token == expected_token
Expand All @@ -259,7 +260,7 @@ def test_interactive_credential_timeout():

credential = InteractiveBrowserCredential(
client_id="guid",
server_class=server_class,
_server_class=server_class,
timeout=timeout,
transport=transport,
instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request
Expand All @@ -277,7 +278,8 @@ def test_redirect_server():
for _ in range(4):
try:
port = random.randint(1024, 65535)
server = AuthCodeRedirectServer(port, timeout=10)
url = "http://127.0.0.1:{}".format(port)
server = AuthCodeRedirectServer(url, timeout=10)
break
except socket.error:
continue # keep looking for an open port
Expand All @@ -293,8 +295,7 @@ def test_redirect_server():
thread.start()

# send a request, verify the server exposes the query
url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec
response = urllib.request.urlopen(url) # nosec
response = urllib.request.urlopen(url + "?{}={}".format(expected_param, expected_value)) # nosec

assert response.code == 200
assert server.query_params[expected_param] == [expected_value]
Expand All @@ -304,7 +305,7 @@ def test_redirect_server():
def test_no_browser():
transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2)
credential = InteractiveBrowserCredential(
client_id="client-id", server_class=Mock(), transport=transport, _cache=TokenCache()
client_id="client-id", _server_class=Mock(), transport=transport, _cache=TokenCache()
)
with pytest.raises(ClientAuthenticationError, match=r".*browser.*"):
credential.get_token("scope")
Expand All @@ -313,11 +314,25 @@ def test_no_browser():
def test_cannot_bind_port():
"""get_token should raise CredentialUnavailableError when the redirect listener can't bind a port"""

credential = InteractiveBrowserCredential(server_class=Mock(side_effect=socket.error))
credential = InteractiveBrowserCredential(_server_class=Mock(side_effect=socket.error))
with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_cannot_bind_redirect_uri():
"""When a user specifies a redirect URI, the credential shouldn't attempt to bind another"""

expected_uri = "http://localhost:42"

server = Mock(side_effect=socket.error)
credential = InteractiveBrowserCredential(redirect_uri=expected_uri, _server_class=server)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")

server.assert_called_once_with(expected_uri, timeout=ANY)


def _validate_auth_request_url(url):
parsed_url = urllib_parse.urlparse(url)
params = urllib_parse.parse_qs(parsed_url.query)
Expand Down