From bf50b5fc7df96bba995281090a4ed38975c07e74 Mon Sep 17 00:00:00 2001 From: joguSD Date: Thu, 24 Oct 2019 17:22:55 -0700 Subject: [PATCH 1/2] Add initial support for SSO credentials This inclues: * aws configure SSO to interatively configure an SSO profile * aws sso login to invoke initialize the authorization flow * aws sso logout to clear cached sessions and temporary credentials --- awscli/customizations/configure/__init__.py | 2 + awscli/customizations/configure/configure.py | 2 + awscli/customizations/configure/sso.py | 315 ++++++++++++++ awscli/customizations/sso/__init__.py | 43 ++ awscli/customizations/sso/login.py | 65 +++ awscli/customizations/sso/logout.py | 97 +++++ awscli/customizations/sso/utils.py | 71 ++++ awscli/formatter.py | 21 +- awscli/handlers.py | 2 + tests/functional/sso/__init__.py | 51 +++ tests/functional/sso/test_login.py | 131 ++++++ tests/functional/sso/test_logout.py | 130 ++++++ .../unit/customizations/configure/__init__.py | 9 + .../configure/test_configure.py | 5 + .../unit/customizations/configure/test_sso.py | 393 ++++++++++++++++++ tests/unit/customizations/sso/__init__.py | 12 + tests/unit/customizations/sso/test_sso.py | 60 +++ tests/unit/customizations/sso/test_utils.py | 123 ++++++ 18 files changed, 1523 insertions(+), 9 deletions(-) create mode 100644 awscli/customizations/configure/sso.py create mode 100644 awscli/customizations/sso/__init__.py create mode 100644 awscli/customizations/sso/login.py create mode 100644 awscli/customizations/sso/logout.py create mode 100644 awscli/customizations/sso/utils.py create mode 100644 tests/functional/sso/__init__.py create mode 100644 tests/functional/sso/test_login.py create mode 100644 tests/functional/sso/test_logout.py create mode 100644 tests/unit/customizations/configure/test_sso.py create mode 100644 tests/unit/customizations/sso/__init__.py create mode 100644 tests/unit/customizations/sso/test_sso.py create mode 100644 tests/unit/customizations/sso/test_utils.py diff --git a/awscli/customizations/configure/__init__.py b/awscli/customizations/configure/__init__.py index ef073e409ff9..6055529251d4 100644 --- a/awscli/customizations/configure/__init__.py +++ b/awscli/customizations/configure/__init__.py @@ -44,6 +44,8 @@ def mask_value(current_value): def profile_to_section(profile_name): """Converts a profile name to a section header to be used in the config.""" + if profile_name == 'default': + return profile_name if any(c in _WHITESPACE for c in profile_name): profile_name = shlex_quote(profile_name) return 'profile %s' % profile_name diff --git a/awscli/customizations/configure/configure.py b/awscli/customizations/configure/configure.py index a4ca2af394e2..ed18fd0fc628 100644 --- a/awscli/customizations/configure/configure.py +++ b/awscli/customizations/configure/configure.py @@ -24,6 +24,7 @@ from awscli.customizations.configure.writer import ConfigFileWriter from awscli.customizations.configure.importer import ConfigureImportCommand from awscli.customizations.configure.listprofiles import ListProfilesCommand +from awscli.customizations.configure.sso import ConfigureSSOCommand from . import mask_value, profile_to_section @@ -78,6 +79,7 @@ class ConfigureCommand(BasicCommand): {'name': 'add-model', 'command_class': AddModelCommand}, {'name': 'import', 'command_class': ConfigureImportCommand}, {'name': 'list-profiles', 'command_class': ListProfilesCommand}, + {'name': 'sso', 'command_class': ConfigureSSOCommand}, ] # If you want to add new values to prompt, update this list here. diff --git a/awscli/customizations/configure/sso.py b/awscli/customizations/configure/sso.py new file mode 100644 index 000000000000..4d2eeb8ac50f --- /dev/null +++ b/awscli/customizations/configure/sso.py @@ -0,0 +1,315 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import logging + +from botocore import UNSIGNED +from botocore.config import Config +from botocore.configprovider import ConstantProvider +from botocore.exceptions import ProfileNotFound +from botocore.utils import is_valid_endpoint_url + +from prompt_toolkit import prompt as ptk_prompt +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.validation import Validator +from prompt_toolkit.validation import ValidationError + +from awscli.customizations.utils import uni_print +from awscli.customizations.commands import BasicCommand +from awscli.customizations.configure import profile_to_section +from awscli.customizations.configure.writer import ConfigFileWriter +from awscli.customizations.wizard.selectmenu import select_menu +from awscli.customizations.sso.utils import do_sso_login +from awscli.formatter import CLI_OUTPUT_FORMATS + + +logger = logging.getLogger(__name__) + + +class StartUrlValidator(Validator): + def __init__(self, default=None): + super(StartUrlValidator, self).__init__() + self._default = default + + def validate(self, document): + # If there's a default, allow an empty prompt + if not document.text and self._default: + return + if not is_valid_endpoint_url(document.text): + index = len(document.text) + raise ValidationError(index, 'Not a valid Start URL') + + +class PTKPrompt(object): + def __init__(self, prompter=None): + if prompter is None: + prompter = ptk_prompt + self._prompter = prompter + + def _create_completer(self, completions): + if completions is None: + completions = [] + if isinstance(completions, dict): + meta_dict = completions + completions = list(meta_dict.keys()) + completer = WordCompleter( + completions, + sentence=True, + meta_dict=meta_dict, + ) + else: + completer = WordCompleter(completions, sentence=True) + return completer + + def get_value(self, current_value, prompt_text='', + completions=None, validator=None): + completer = self._create_completer(completions) + prompt_string = u'{} [{}]: '.format(prompt_text, current_value) + response = self._prompter( + prompt_string, + validator=validator, + validate_while_typing=False, + completer=completer, + complete_while_typing=True, + ) + if not response: + # If the user hits enter, we return the current/default value + response = current_value + return response + + +class ConfigureSSOCommand(BasicCommand): + NAME = 'sso' + SYNOPSIS = ('aws configure sso [--profile profile-name]') + DESCRIPTION = ( + 'The ``aws configure sso`` command interactively prompts for the ' + 'configuration values required to create a profile that sources ' + 'temporary AWS credentials from AWS Single Sign-On. To keep an ' + 'existing value, hit enter when prompted for the value. When you ' + 'are prompted for information, the current value will be displayed in ' + '[brackets]. If the config item has no value, it is displayed as ' + '[None]. When providing the ``--profile`` parameter the named profile ' + 'will be created or updated. When a profile is not explicitly set ' + 'the profile name will be prompted for.' + '\n\nNote: The configuration is saved in the shared configuration ' + 'file. By default, ``~/.aws/config``.' + ) + # TODO: Add CLI parameters to skip prompted values, --start-url, etc. + + def __init__(self, session, prompter=None, selector=None, + config_writer=None, sso_token_cache=None): + super(ConfigureSSOCommand, self).__init__(session) + if prompter is None: + prompter = PTKPrompt() + self._prompter = prompter + if selector is None: + selector = select_menu + self._selector = selector + if config_writer is None: + config_writer = ConfigFileWriter() + self._config_writer = config_writer + self._sso_token_cache = sso_token_cache + + self._new_values = {} + self._original_profile_name = self._session.profile + try: + self._config = self._session.get_scoped_config() + except ProfileNotFound: + self._config = {} + # The profile provided to the CLI as --profile may not exist. + # This means we cannot use the session as is to create clients. + # By overriding the profile provider we ensure that a non-existant + # profile won't cause us to fail to create clients. + # No configuration from the profile is needed for the SSO APIs. + # It might be good to see if we can address this in a better way + # in botocore. + config_store = self._session.get_component('config_store') + config_store.set_config_provider('profile', ConstantProvider(None)) + + def _prompt_for(self, config_name, text, + completions=None, validator_cls=None): + current_value = self._config.get(config_name) + if validator_cls is None: + validator = None + else: + validator = validator_cls(current_value) + new_value = self._prompter.get_value( + current_value, text, + completions=completions, + validator=validator, + ) + if new_value: + self._new_values[config_name] = new_value + return new_value + + def _handle_single_account(self, accounts): + sso_account_id = accounts[0]['accountId'] + single_account_msg = ( + 'The only AWS account available to you is: {}\n' + ) + uni_print(single_account_msg.format(sso_account_id)) + return sso_account_id + + def _display_account(self, account): + return '{accountName}, {emailAddress} ({accountId})'.format(**account) + + def _handle_multiple_accounts(self, accounts): + available_accounts_msg = ( + 'There are {} AWS accounts available to you.\n' + ) + uni_print(available_accounts_msg.format(len(accounts))) + selected_account = self._selector(accounts, self._display_account) + sso_account_id = selected_account['accountId'] + return sso_account_id + + def _get_all_accounts(self, sso, sso_token): + paginator = sso.get_paginator('list_accounts') + results = paginator.paginate(accessToken=sso_token['accessToken']) + return results.build_full_result() + + def _prompt_for_account(self, sso, sso_token): + accounts = self._get_all_accounts(sso, sso_token)['accountList'] + if not accounts: + raise RuntimeError('No AWS accounts are available to you.') + if len(accounts) == 1: + sso_account_id = self._handle_single_account(accounts) + else: + sso_account_id = self._handle_multiple_accounts(accounts) + uni_print('Using the account ID {}\n'.format(sso_account_id)) + self._new_values['sso_account_id'] = sso_account_id + return sso_account_id + + def _handle_single_role(self, roles): + sso_role_name = roles[0]['roleName'] + available_roles_msg = 'The only role available to you is: {}\n' + uni_print(available_roles_msg.format(sso_role_name)) + return sso_role_name + + def _handle_multiple_roles(self, roles): + available_roles_msg = 'There are {} roles available to you.\n' + uni_print(available_roles_msg.format(len(roles))) + role_names = [r['roleName'] for r in roles] + sso_role_name = self._selector(role_names) + return sso_role_name + + def _get_all_roles(self, sso, sso_token, sso_account_id): + paginator = sso.get_paginator('list_account_roles') + results = paginator.paginate( + accountId=sso_account_id, + accessToken=sso_token['accessToken'] + ) + return results.build_full_result() + + def _prompt_for_role(self, sso, sso_token, sso_account_id): + roles = self._get_all_roles(sso, sso_token, sso_account_id)['roleList'] + if not roles: + error_msg = 'No roles are available for the account {}' + raise RuntimeError(error_msg.format(sso_account_id)) + if len(roles) == 1: + sso_role_name = self._handle_single_role(roles) + else: + sso_role_name = self._handle_multiple_roles(roles) + uni_print('Using the role name "{}"\n'.format(sso_role_name)) + self._new_values['sso_role_name'] = sso_role_name + return sso_role_name + + def _prompt_for_profile(self, sso_account_id, sso_role_name): + if self._original_profile_name: + profile_name = self._original_profile_name + else: + default_profile = '{}-{}'.format(sso_role_name, sso_account_id) + text = 'CLI profile name' + profile_name = self._prompter.get_value(default_profile, text) + return profile_name + + def _get_potential_start_urls(self): + profiles = self._session.full_config.get('profiles', []) + potential_start_urls = set() + for profile, config in profiles.items(): + if 'sso_start_url' in config: + start_url = config['sso_start_url'] + potential_start_urls.add(start_url) + return list(potential_start_urls) + + def _prompt_for_start_url(self): + potential_start_urls = self._get_potential_start_urls() + start_url = self._prompt_for( + 'sso_start_url', 'SSO start URL', + completions=potential_start_urls, + validator_cls=StartUrlValidator, + ) + return start_url + + def _get_potential_sso_regions(self): + return self._session.get_available_regions('sso-oidc') + + def _prompt_for_sso_region(self): + potential_sso_regions = self._get_potential_sso_regions() + sso_region = self._prompt_for( + 'sso_region', 'SSO Region', + completions=potential_sso_regions, + ) + return sso_region + + def _prompt_for_cli_default_region(self): + # TODO: figure out a way to get a list of reasonable client regions + return self._prompt_for('region', 'CLI default client Region') + + def _prompt_for_cli_output_format(self): + return self._prompt_for( + 'output', 'CLI default output format', + completions=list(CLI_OUTPUT_FORMATS.keys()), + ) + + def _run_main(self, parsed_args, parsed_globals): + start_url = self._prompt_for_start_url() + sso_region = self._prompt_for_sso_region() + sso_token = do_sso_login( + self._session, + sso_region, + start_url, + token_cache=self._sso_token_cache, + ) + + # Construct an SSO client to explore the accounts / roles + client_config = Config( + signature_version=UNSIGNED, + region_name=sso_region, + ) + sso = self._session.create_client('sso', config=client_config) + + sso_account_id = self._prompt_for_account(sso, sso_token) + sso_role_name = self._prompt_for_role(sso, sso_token, sso_account_id) + + # General CLI configuration + self._prompt_for_cli_default_region() + self._prompt_for_cli_output_format() + + profile_name = self._prompt_for_profile(sso_account_id, sso_role_name) + + usage_msg = ( + '\nTo use this profile, specify the profile name using ' + '--profile, as shown:\n\n' + 'aws s3 ls --profile {}\n' + ) + uni_print(usage_msg.format(profile_name)) + + self._write_new_config(profile_name) + + def _write_new_config(self, profile): + config_path = self._session.get_config_variable('config_file') + config_path = os.path.expanduser(config_path) + if self._new_values: + section = profile_to_section(profile) + self._new_values['__section__'] = section + self._config_writer.update_config(self._new_values, config_path) diff --git a/awscli/customizations/sso/__init__.py b/awscli/customizations/sso/__init__.py new file mode 100644 index 000000000000..b5e2a6cc2219 --- /dev/null +++ b/awscli/customizations/sso/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from botocore.exceptions import ProfileNotFound +from botocore.exceptions import UnknownCredentialError +from botocore.credentials import JSONFileCache + +from awscli.customizations.sso.login import LoginCommand +from awscli.customizations.sso.logout import LogoutCommand +from awscli.customizations.sso.utils import AWS_CREDS_CACHE_DIR + + +def register_sso_commands(event_emitter): + event_emitter.register( + 'building-command-table.sso', add_sso_commands, + ) + event_emitter.register( + 'session-initialized', inject_json_file_cache, + unique_id='inject_sso_json_file_cache' + ) + + +def add_sso_commands(command_table, session, **kwargs): + command_table['login'] = LoginCommand(session) + command_table['logout'] = LogoutCommand(session) + + +def inject_json_file_cache(session, **kwargs): + try: + cred_chain = session.get_component('credential_provider') + sso_provider = cred_chain.get_provider('sso') + sso_provider.cache = JSONFileCache(AWS_CREDS_CACHE_DIR) + except (ProfileNotFound, UnknownCredentialError): + return diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py new file mode 100644 index 000000000000..ec3b9cba3ef3 --- /dev/null +++ b/awscli/customizations/sso/login.py @@ -0,0 +1,65 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from awscli.customizations.commands import BasicCommand +from awscli.customizations.sso.utils import do_sso_login +from awscli.customizations.utils import uni_print + + +class InvalidSSOConfigError(Exception): + pass + + +class LoginCommand(BasicCommand): + NAME = 'login' + DESCRIPTION = ( + 'Retrieves and caches an AWS SSO access token to exchange for AWS ' + 'credentials. To login, the requested profile must have first been ' + 'setup using ``aws configure sso``. Each time the ``login`` command ' + 'is called, a new SSO access token will be retrieved.' + ) + ARG_TABLE = [] + _REQUIRED_SSO_CONFIG_VARS = [ + 'sso_start_url', + 'sso_region', + 'sso_role_name', + 'sso_account_id', + ] + + def _run_main(self, parsed_args, parsed_globals): + sso_config = self._get_sso_config() + do_sso_login( + session=self._session, + sso_region=sso_config['sso_region'], + start_url=sso_config['sso_start_url'], + force_refresh=True + ) + success_msg = 'Successully logged into Start URL: %s\n' + uni_print(success_msg % sso_config['sso_start_url']) + return 0 + + def _get_sso_config(self): + scoped_config = self._session.get_scoped_config() + sso_config = {} + missing_vars = [] + for config_var in self._REQUIRED_SSO_CONFIG_VARS: + if config_var not in scoped_config: + missing_vars.append(config_var) + else: + sso_config[config_var] = scoped_config[config_var] + if missing_vars: + raise InvalidSSOConfigError( + 'Missing the following required SSO configuration values: %s. ' + 'To make sure this profile is properly configured to use SSO, ' + 'please run: aws configure sso' % ', '.join(missing_vars) + ) + return sso_config diff --git a/awscli/customizations/sso/logout.py b/awscli/customizations/sso/logout.py new file mode 100644 index 000000000000..2c906e78fd95 --- /dev/null +++ b/awscli/customizations/sso/logout.py @@ -0,0 +1,97 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import json +import logging +import os + +from botocore.exceptions import ClientError + +from awscli.customizations.commands import BasicCommand +from awscli.customizations.sso.utils import SSO_TOKEN_DIR +from awscli.customizations.sso.utils import AWS_CREDS_CACHE_DIR + + +LOG = logging.getLogger(__name__) + + +class LogoutCommand(BasicCommand): + NAME = 'logout' + DESCRIPTION = ( + 'Removes all cached AWS SSO access tokens and any cached temporary ' + 'AWS credentials retrieved with SSO access tokens across all ' + 'profiles. To use these profiles again, run: ``aws sso login``' + ) + ARG_TABLE = [] + + def _run_main(self, parsed_args, parsed_globals): + SSOTokenSweeper(self._session).delete_credentials(SSO_TOKEN_DIR) + SSOCredentialSweeper().delete_credentials(AWS_CREDS_CACHE_DIR) + return 0 + + +class BaseCredentialSweeper(object): + def delete_credentials(self, creds_dir): + if not os.path.isdir(creds_dir): + return + filenames = os.listdir(creds_dir) + for filename in filenames: + filepath = os.path.join(creds_dir, filename) + contents = self._get_json_contents(filepath) + if contents is None: + continue + if self._should_delete(contents): + self._before_deletion(contents) + os.remove(filepath) + + def _should_delete(self, filename): + raise NotImplementedError('_should_delete') + + def _get_json_contents(self, filename): + try: + with open(filename, 'r') as f: + return json.load(f) + except Exception: + # We do not want to include the traceback in the exception + # so that we do not accidentally log sensitive contents because + # of the exception or its Traceback. + LOG.debug('Failed to load: %s', filename) + return None + + def _before_deletion(self, contents): + pass + + +class SSOTokenSweeper(BaseCredentialSweeper): + def __init__(self, session): + self._session = session + + def _should_delete(self, contents): + return 'accessToken' in contents + + def _before_deletion(self, contents): + # If the sso region is present in the cached token, construct a client + # and invoke the logout api to invalidate the token before deleting it. + sso_region = contents.get('region') + if sso_region: + sso = self._session.create_client('sso', region_name=sso_region) + try: + sso.logout(accessToken=contents['accessToken']) + except ClientError: + # The token may alread be expired or otherwise invalid. If we + # get a client error on logout just log and continue on + LOG.debug('Failed to call logout API:', exc_info=True) + + +class SSOCredentialSweeper(BaseCredentialSweeper): + def _should_delete(self, contents): + return contents.get('ProviderType') == 'sso' diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py new file mode 100644 index 000000000000..eb4b98057d94 --- /dev/null +++ b/awscli/customizations/sso/utils.py @@ -0,0 +1,71 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import logging +import webbrowser + +from botocore.utils import SSOTokenFetcher +from botocore.credentials import JSONFileCache + +from awscli.customizations.utils import uni_print +from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR + +LOG = logging.getLogger(__name__) + +SSO_TOKEN_DIR = os.path.expanduser( + os.path.join('~', '.aws', 'sso', 'cache') +) + + +def do_sso_login(session, sso_region, start_url, token_cache=None, + on_pending_authorization=None, force_refresh=False): + if token_cache is None: + token_cache = JSONFileCache(SSO_TOKEN_DIR) + if on_pending_authorization is None: + on_pending_authorization = OpenBrowserHandler() + token_fetcher = SSOTokenFetcher( + sso_region=sso_region, + client_creator=session.create_client, + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + return token_fetcher.fetch_token( + start_url=start_url, + force_refresh=force_refresh + ) + + +class OpenBrowserHandler(object): + def __init__(self, outfile=None, open_browser=None): + self._outfile = outfile + if open_browser is None: + open_browser = webbrowser.open_new_tab + self._open_browser = open_browser + + def __call__(self, userCode, verificationUri, + verificationUriComplete, **kwargs): + opening_msg = ( + 'Attempting to automatically open the SSO authorization page in ' + 'your default browser.\nIf the browser does not open or you wish ' + 'to use a different device to authorize this request, open the ' + 'following URL:\n' + '\n%s\n' + '\nThen enter the code:\n' + '\n%s\n' + ) + uni_print(opening_msg % (verificationUri, userCode), self._outfile) + if self._open_browser: + try: + return self._open_browser(verificationUriComplete) + except Exception: + LOG.debug('Failed to open browser:', exc_info=True) diff --git a/awscli/formatter.py b/awscli/formatter.py index 456e85d92285..3f04a5b484c8 100644 --- a/awscli/formatter.py +++ b/awscli/formatter.py @@ -300,13 +300,16 @@ def _format_response(self, response, stream): text.format_text(response, stream) +CLI_OUTPUT_FORMATS = { + 'json': JSONFormatter, + 'text': TextFormatter, + 'table': TableFormatter, + 'yaml': YAMLFormatter, +} + + def get_formatter(format_type, args): - if format_type == 'json': - return JSONFormatter(args) - elif format_type == 'text': - return TextFormatter(args) - elif format_type == 'table': - return TableFormatter(args) - elif format_type == 'yaml': - return YAMLFormatter(args) - raise ValueError("Unknown output type: %s" % format_type) + if format_type not in CLI_OUTPUT_FORMATS: + raise ValueError("Unknown output type: %s" % format_type) + format_type_cls = CLI_OUTPUT_FORMATS[format_type] + return format_type_cls(args) diff --git a/awscli/handlers.py b/awscli/handlers.py index 1171320d5c98..561985466378 100644 --- a/awscli/handlers.py +++ b/awscli/handlers.py @@ -76,6 +76,7 @@ from awscli.customizations.s3errormsg import register_s3_error_msg from awscli.customizations.scalarparse import register_scalar_parser from awscli.customizations.sessendemail import register_ses_send_email +from awscli.customizations.sso import register_sso_commands from awscli.customizations.streamingoutputarg import add_streaming_output_arg from awscli.customizations.translate import register_translate_import_terminology from awscli.customizations.toplevelbool import register_bool_params @@ -178,3 +179,4 @@ def awscli_initialize(event_handlers): register_dev_commands(event_handlers) register_wizard_commands(event_handlers) register_sms_voice_hide(event_handlers) + register_sso_commands(event_handlers) diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py new file mode 100644 index 000000000000..8c77a2db35a6 --- /dev/null +++ b/tests/functional/sso/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from awscli.testutils import mock +from awscli.testutils import create_clidriver +from awscli.testutils import FileCreator +from awscli.testutils import BaseAWSCommandParamsTest + + +class BaseSSOTest(BaseAWSCommandParamsTest): + def setUp(self): + super(BaseSSOTest, self).setUp() + self.files = FileCreator() + self.start_url = 'https://mysigin.com' + self.sso_region = 'us-west-2' + self.account = '012345678912' + self.role_name = 'SSORole' + self.config_file = self.files.full_path('config') + self.environ['AWS_CONFIG_FILE'] = self.config_file + self.set_config_file_content() + self.access_token = 'foo.token.string' + + def tearDown(self): + super(BaseSSOTest, self).tearDown() + self.files.remove_all() + + def set_config_file_content(self, content=None): + if content is None: + content = ( + '[default]\n' + 'sso_start_url=%s\n' + 'sso_region=%s\n' + 'sso_role_name=%s\n' + 'sso_account_id=%s\n' % ( + self.start_url, self.sso_region, self.role_name, + self.account + ) + ) + self.files.create_file(self.config_file, content) + # We need to recreate the driver (which includes its session) in order + # for the config changes to be pulled in by the session. + self.driver = create_clidriver() diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py new file mode 100644 index 000000000000..abc825c56098 --- /dev/null +++ b/tests/functional/sso/test_login.py @@ -0,0 +1,131 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import hashlib +import json +import os +import time + +from awscli.testutils import mock +from tests.functional.sso import BaseSSOTest + + +class TestLoginCommand(BaseSSOTest): + def setUp(self): + super(TestLoginCommand, self).setUp() + self.token_cache_dir = self.files.full_path('token-cache') + self.token_cache_dir_patch = mock.patch( + 'awscli.customizations.sso.utils.SSO_TOKEN_DIR', + self.token_cache_dir + ) + self.token_cache_dir_patch.start() + + def tearDown(self): + super(TestLoginCommand, self).tearDown() + self.token_cache_dir_patch.stop() + + def add_oidc_workflow_responses(self, access_token, + include_register_response=True): + responses = [ + # StartDeviceAuthorization response + { + 'interval': 1, + 'expiresIn': 600, + 'userCode': 'foo', + 'deviceCode': 'foo-device-code', + 'verificationUri': 'https://sso.fake/device', + 'verificationUriComplete': 'https://sso.verify', + }, + # CreateToken response + { + 'expiresIn': 28800, + 'tokenType': 'Bearer', + 'accessToken': access_token, + } + ] + if include_register_response: + responses.insert( + 0, + { + 'clientSecretExpiresAt': time.time() + 1000, + 'clientId': 'foo-client-id', + 'clientSecret': 'foo-client-secret', + } + ) + self.parsed_responses = responses + + def assert_used_expected_sso_region(self, expected_region): + self.assertIn(expected_region, self.last_request_dict['url']) + + def assert_cache_contains_token(self, start_url, expected_token): + cached_files = os.listdir(self.token_cache_dir) + # The registration and cached access token + self.assertEqual(len(cached_files), 2) + cached_token_filename = self._get_cached_token_filename(start_url) + self.assertIn(cached_token_filename, cached_files) + self.assertEqual( + self._get_token(cached_token_filename), + expected_token + ) + + def _get_cached_token_filename(self, start_url): + return hashlib.sha1(start_url.encode('utf-8')).hexdigest() + '.json' + + def _get_token(self, token_filename): + token_path = os.path.join(self.token_cache_dir, token_filename) + with open(token_path, 'r') as f: + cached_response = json.loads(f.read()) + return cached_response['accessToken'] + + def test_login(self): + self.add_oidc_workflow_responses(self.access_token) + self.run_cmd('sso login') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token + ) + + def test_login_forces_refresh(self): + self.add_oidc_workflow_responses(self.access_token) + self.run_cmd('sso login') + # The register response from the first login should have been + # cached. + self.add_oidc_workflow_responses( + 'new.token', include_register_response=False) + self.run_cmd('sso login') + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token' + ) + + def test_login_no_sso_configuration(self): + self.set_config_file_content(content='') + _, stderr, _ = self.run_cmd('sso login', expected_rc=255) + self.assertIn( + 'Missing the following required SSO configuration', + stderr + ) + + def test_login_partially_missing_sso_configuration(self): + content = ( + '[default]\n' + 'sso_start_url=%s\n' % self.start_url + ) + self.set_config_file_content(content=content) + _, stderr, _ = self.run_cmd('sso login', expected_rc=255) + self.assertIn( + 'Missing the following required SSO configuration', + stderr + ) + self.assertIn('sso_region', stderr) + self.assertNotIn('sso_start_url', stderr) diff --git a/tests/functional/sso/test_logout.py b/tests/functional/sso/test_logout.py new file mode 100644 index 000000000000..b25128a9b397 --- /dev/null +++ b/tests/functional/sso/test_logout.py @@ -0,0 +1,130 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import json +import os + +from awscli.testutils import mock +from tests.functional.sso import BaseSSOTest + + +class TestLogoutCommand(BaseSSOTest): + def setUp(self): + super(TestLogoutCommand, self).setUp() + self.token_cache_dir = self.files.full_path('token-cache') + self.token_cache_dir_patch = mock.patch( + 'awscli.customizations.sso.logout.SSO_TOKEN_DIR', + self.token_cache_dir + ) + self.token_cache_dir_patch.start() + self.aws_creds_cache_dir = self.files.full_path('aws-creds-cache') + self.aws_creds_cache_dir_patch = mock.patch( + 'awscli.customizations.sso.logout.AWS_CREDS_CACHE_DIR', + self.aws_creds_cache_dir + ) + self.aws_creds_cache_dir_patch.start() + + def tearDown(self): + super(TestLogoutCommand, self).tearDown() + self.token_cache_dir_patch.stop() + self.aws_creds_cache_dir_patch.stop() + + def add_cached_token(self, filename): + token_path = os.path.join(self.token_cache_dir, filename) + token_contents = { + 'region': self.sso_region, + 'startUrl': self.start_url, + 'expiresAt': '2019-10-26T05:19:09UTC', + 'accessToken': self.access_token, + } + self.files.create_file(token_path, json.dumps(token_contents)) + return token_path + + def add_cached_aws_credentials(self, filename, from_sso=True): + creds_contents = { + 'Credentials': { + 'AccessKeyId': 'access-key', + 'SecretAccessKey': 'secret-key', + "SessionToken": 'session-token', + "Expiration": '2020-01-23T20:48:59UTC' + }, + } + if from_sso: + creds_contents['ProviderType'] = 'sso' + creds_path = os.path.join(self.aws_creds_cache_dir, filename) + self.files.create_file(creds_path, json.dumps(creds_contents)) + return creds_path + + def assert_file_exists(self, filename): + self.assertTrue(os.path.exists(filename)) + + def assert_file_does_not_exist(self, filename): + self.assertFalse(os.path.exists(filename)) + + def test_logout(self): + token = self.add_cached_token('token.json') + creds = self.add_cached_aws_credentials('sso-creds.json') + self.run_cmd('sso logout') + self.assert_file_does_not_exist(token) + self.assert_file_does_not_exist(creds) + + def test_logout_multiple_cached_files(self): + token = self.add_cached_token('token.json') + token2 = self.add_cached_token('token2.json') + creds = self.add_cached_aws_credentials('sso-creds.json') + creds2 = self.add_cached_aws_credentials('sso-creds2.json') + self.run_cmd('sso logout') + self.assert_file_does_not_exist(token) + self.assert_file_does_not_exist(token2) + self.assert_file_does_not_exist(creds) + self.assert_file_does_not_exist(creds2) + + def test_logout_ignores_non_sso_tokens(self): + registration_token = os.path.join( + self.token_cache_dir, 'botocore-client-id.json') + self.files.create_file( + registration_token, json.dumps({'clientId': 'myid'})) + self.run_cmd('sso logout') + self.assert_file_exists(registration_token) + + def test_logout_ignores_non_sso_retrieved_aws_creds(self): + creds = self.add_cached_aws_credentials('creds.json', from_sso=False) + self.run_cmd('sso logout') + self.assert_file_exists(creds) + + def test_ignores_invalid_json_files(self): + invalid_json = os.path.join(self.token_cache_dir, 'invalid.json') + self.files.create_file(invalid_json, '{not-json') + self.run_cmd('sso logout') + self.assert_file_exists(invalid_json) + + def test_does_not_fail_when_cache_does_not_exist(self): + self.assertFalse(os.path.exists(self.token_cache_dir)) + self.assertFalse(os.path.exists(self.aws_creds_cache_dir)) + self.run_cmd('sso logout', expected_rc=0) + + def test_calls_sso_logout_with_token(self): + token = self.add_cached_token('token.json') + self.run_cmd('sso logout') + self.assert_file_does_not_exist(token) + self.assertEqual(len(self.operations_called), 1) + self.assertEqual(self.operations_called[0][0].name, 'Logout') + expected_logout_params = { + 'accessToken': self.access_token, + } + self.assertEqual(self.operations_called[0][1], expected_logout_params) + + def test_calls_sso_logout_and_handles_client_error(self): + self.http_response.status_code = 400 + token = self.add_cached_token('token.json') + self.run_cmd('sso logout', expected_rc=0) + self.assert_file_does_not_exist(token) diff --git a/tests/unit/customizations/configure/__init__.py b/tests/unit/customizations/configure/__init__.py index 64a6287fcd34..7c3102ff48a9 100644 --- a/tests/unit/customizations/configure/__init__.py +++ b/tests/unit/customizations/configure/__init__.py @@ -13,6 +13,11 @@ from botocore.exceptions import ProfileNotFound +class FakeConfigStore(object): + def set_config_provider(self, *args, **kwargs): + pass + + class FakeSession(object): def __init__(self, all_variables, profile_does_not_exist=False, @@ -45,6 +50,10 @@ def get_scoped_config(self): raise ProfileNotFound(profile='foo') return self.config + def get_component(self, component, *args, **kwargs): + if component == 'config_store': + return FakeConfigStore() + def get_config_variable(self, name, methods=None): if name == 'credentials_file': # The credentials_file var doesn't require a diff --git a/tests/unit/customizations/configure/test_configure.py b/tests/unit/customizations/configure/test_configure.py index cb244e2c93fe..e60cdcfff5f9 100644 --- a/tests/unit/customizations/configure/test_configure.py +++ b/tests/unit/customizations/configure/test_configure.py @@ -261,6 +261,11 @@ def test_profile_with_consecutive_spaces(self): section = profile_to_section(profile) self.assertEqual('profile \' \'', section) + def test_default_profile(self): + profile = 'default' + section = profile_to_section(profile) + self.assertEqual(profile, section) + class PrecannedPrompter(object): diff --git a/tests/unit/customizations/configure/test_sso.py b/tests/unit/customizations/configure/test_sso.py new file mode 100644 index 000000000000..3babaff92d98 --- /dev/null +++ b/tests/unit/customizations/configure/test_sso.py @@ -0,0 +1,393 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import mock +from datetime import datetime, timedelta +from dateutil.tz import tzlocal + +from prompt_toolkit import prompt as ptk_prompt +from prompt_toolkit.document import Document +from prompt_toolkit.validation import DummyValidator +from prompt_toolkit.validation import ValidationError + +from botocore.session import Session +from botocore.stub import Stubber +from botocore.exceptions import ProfileNotFound + +from awscli.testutils import unittest +from awscli.customizations.configure.sso import select_menu +from awscli.customizations.configure.sso import PTKPrompt +from awscli.customizations.configure.sso import ConfigureSSOCommand +from awscli.customizations.configure.sso import StartUrlValidator +from awscli.customizations.configure.writer import ConfigFileWriter +from awscli.formatter import CLI_OUTPUT_FORMATS + + +class TestPTKPrompt(unittest.TestCase): + def setUp(self): + self.mock_prompter = mock.Mock(spec=ptk_prompt) + self.prompter = PTKPrompt(prompter=self.mock_prompter) + + def test_returns_input(self): + self.mock_prompter.return_value = 'new_value' + response = self.prompter.get_value('default_value', 'Prompt Text') + self.assertEqual(response, 'new_value') + + def test_user_hits_enter_returns_current(self): + self.mock_prompter.return_value = '' + response = self.prompter.get_value('default_value', 'Prompt Text') + # We convert the empty string to the default value + self.assertEqual(response, 'default_value') + + def assert_expected_completions(self, completions): + _, kwargs = self.mock_prompter.call_args_list[0] + self.assertEqual(kwargs['completer'].words, completions) + + def assert_expected_meta_dict(self, meta_dict): + _, kwargs = self.mock_prompter.call_args_list[0] + self.assertEqual(kwargs['completer'].meta_dict, meta_dict) + + def assert_expected_validator(self, validator): + _, kwargs = self.mock_prompter.call_args_list[0] + self.assertEqual(kwargs['validator'], validator) + + def test_handles_list_completions(self): + completions = ['a', 'b'] + self.prompter.get_value('', '', completions=completions) + self.assert_expected_completions(completions) + + def test_handles_dict_completions(self): + descriptions = { + 'a': 'the letter a', + 'b': 'the letter b', + } + expected_completions = ['a', 'b'] + self.prompter.get_value('', '', completions=descriptions) + self.assert_expected_completions(expected_completions) + self.assert_expected_meta_dict(descriptions) + + def test_passes_validator(self): + validator = DummyValidator() + self.prompter.get_value('', '', validator=validator) + self.assert_expected_validator(validator) + + +class TestStartUrlValidator(unittest.TestCase): + def setUp(self): + self.document = mock.Mock(spec=Document) + self.validator = StartUrlValidator() + + def _validate_text(self, text): + self.document.text = text + self.validator.validate(self.document) + + def assert_text_not_allowed(self, text): + with self.assertRaises(ValidationError): + self._validate_text(text) + + def test_disallowed_text(self): + not_start_urls = [ + '', + 'd-abc123', + 'foo bar baz', + ] + for text in not_start_urls: + self.assert_text_not_allowed(text) + + def test_allowed_text(self): + valid_start_urls = [ + 'https://d-abc123.awsapps.com/start', + 'https://d-abc123.awsapps.com/start#', + 'https://d-abc123.awsapps.com/start/', + 'https://d-abc123.awsapps.com/start-beta', + 'https://start.url', + ] + for text in valid_start_urls: + self._validate_text(text) + + def test_allows_empty_string_if_default(self): + default = 'https://some.default' + self.validator = StartUrlValidator(default) + self._validate_text('') + + +class TestConfigureSSOCommand(unittest.TestCase): + def setUp(self): + self.global_args = mock.Mock() + self._session = Session() + self.sso_client = self._session.create_client('sso') + self.sso_stub = Stubber(self.sso_client) + self.profile = 'a-profile' + self.scoped_config = {} + self.full_config = { + 'profiles': { + self.profile: self.scoped_config + } + } + self.mock_session = mock.Mock(spec=Session) + self.mock_session.get_scoped_config.return_value = self.scoped_config + self.mock_session.full_config = self.full_config + self.mock_session.create_client.return_value = self.sso_client + self.mock_session.profile = self.profile + self.config_path = '/some/path' + self.session_config = { + 'config_file': self.config_path, + } + self.mock_session.get_config_variable = self.session_config.get + self.mock_session.get_available_regions.return_value = ['us-east-1'] + self.token_cache = {} + self.writer = mock.Mock(spec=ConfigFileWriter) + self.prompter = mock.Mock(spec=PTKPrompt) + self.selector = mock.Mock(spec=select_menu) + self.configure_sso = ConfigureSSOCommand( + self.mock_session, + prompter=self.prompter, + selector=self.selector, + config_writer=self.writer, + sso_token_cache=self.token_cache, + ) + self.region = 'us-west-2' + self.output = 'json' + self.sso_region = 'us-east-1' + self.start_url = 'https://d-92671207e4.awsapps.com/start' + self.account_id = '0123456789' + self.role_name = 'roleA' + self.cached_token_key = '13f9d35043871d073ab260e020f0ffde092cb14b' + self.expires_at = datetime.now(tzlocal()) + timedelta(hours=24) + self.access_token = { + 'accessToken': 'access.token.string', + 'expiresAt': self.expires_at, + } + self.token_cache[self.cached_token_key] = self.access_token + + def _add_list_accounts_response(self, accounts): + params = { + 'accessToken': self.access_token['accessToken'], + } + response = { + 'accountList': accounts, + } + self.sso_stub.add_response('list_accounts', response, params) + + def _add_list_account_roles_response(self, roles): + params = { + 'accountId': self.account_id, + 'accessToken': self.access_token['accessToken'], + } + response = { + 'roleList': roles, + } + self.sso_stub.add_response('list_account_roles', response, params) + + def _add_prompt_responses(self): + self.prompter.get_value.side_effect = [ + self.start_url, + self.sso_region, + self.region, + self.output, + ] + + def _add_simple_single_item_responses(self): + selected_account = { + 'accountId': self.account_id, + 'emailAddress': 'account@site.com', + } + self._add_list_accounts_response([selected_account]) + self._add_list_account_roles_response([{'roleName': self.role_name}]) + + def assert_config_updates(self, config=None): + if config is None: + config = { + '__section__': 'profile %s' % self.profile, + 'sso_start_url': self.start_url, + 'sso_region': self.sso_region, + 'sso_account_id': self.account_id, + 'sso_role_name': self.role_name, + 'region': self.region, + 'output': self.output, + } + self.writer.update_config.assert_called_with(config, self.config_path) + + def test_basic_configure_sso_flow(self): + self._add_prompt_responses() + selected_account = { + 'accountId': self.account_id, + 'emailAddress': 'account@site.com', + } + self.selector.side_effect = [ + selected_account, + self.role_name, + ] + accounts = [ + selected_account, + {'accountId': '1234567890', 'emailAddress': 'account2@site.com'}, + ] + self._add_list_accounts_response(accounts) + roles = [ + {'roleName': self.role_name}, + {'roleName': 'roleB'}, + ] + self._add_list_account_roles_response(roles) + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + self.assert_config_updates() + + def test_single_account_single_role_flow(self): + self._add_prompt_responses() + self._add_simple_single_item_responses() + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + self.assert_config_updates() + # Account / Role should be auto selected if only one is returned + self.assertEqual(self.selector.call_count, 0) + + def test_no_accounts_flow_raises_error(self): + self.prompter.get_value.side_effect = [self.start_url, self.sso_region] + self._add_list_accounts_response([]) + with self.assertRaises(RuntimeError): + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + + def test_no_roles_flow_raises_error(self): + self._add_prompt_responses() + selected_account = { + 'accountId': self.account_id, + 'emailAddress': 'account@site.com', + } + self._add_list_accounts_response([selected_account]) + self._add_list_account_roles_response([]) + with self.assertRaises(RuntimeError): + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + + def assert_default_prompt_args(self, defaults): + calls = self.prompter.get_value.call_args_list + self.assertEqual(len(calls), len(defaults)) + for call, default in zip(calls, defaults): + # The default to the prompt call is the first positional param + self.assertEqual(call[0][0], default) + + def assert_prompt_completions(self, completions): + calls = self.prompter.get_value.call_args_list + self.assertEqual(len(calls), len(completions)) + for call, completions in zip(calls, completions): + _, kwargs = call + self.assertEqual(kwargs['completions'], completions) + + def test_defaults_to_scoped_config(self): + self.scoped_config['sso_start_url'] = 'default-url' + self.scoped_config['sso_region'] = 'default-sso-region' + self.scoped_config['region'] = 'default-region' + self.scoped_config['output'] = 'default-output' + self._add_prompt_responses() + self._add_simple_single_item_responses() + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + self.assert_config_updates() + expected_defaults = [ + 'default-url', + 'default-sso-region', + 'default-region', + 'default-output', + ] + self.assert_default_prompt_args(expected_defaults) + + def test_handles_no_profile(self): + expected_profile = 'profile-a' + self.profile = None + self.mock_session.profile = None + self.configure_sso = ConfigureSSOCommand( + self.mock_session, + prompter=self.prompter, + selector=self.selector, + config_writer=self.writer, + sso_token_cache=self.token_cache, + ) + # If there is no profile, it will be prompted for as the last value + self.prompter.get_value.side_effect = [ + self.start_url, + self.sso_region, + self.region, + self.output, + expected_profile, + ] + self._add_simple_single_item_responses() + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + self.profile = expected_profile + self.assert_config_updates() + + def test_handles_non_existant_profile(self): + not_found_exception = ProfileNotFound(profile=self.profile) + self.mock_session.get_scoped_config.side_effect = not_found_exception + self.configure_sso = ConfigureSSOCommand( + self.mock_session, + prompter=self.prompter, + selector=self.selector, + config_writer=self.writer, + sso_token_cache=self.token_cache, + ) + self._add_prompt_responses() + self._add_simple_single_item_responses() + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + self.assert_config_updates() + + def test_cli_config_is_none_not_written(self): + self.prompter.get_value.side_effect = [ + self.start_url, + self.sso_region, + # The CLI region and output format shouldn't be written + # to the config as they are None + None, + None + ] + self._add_simple_single_item_responses() + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + expected_config = { + '__section__': 'profile %s' % self.profile, + 'sso_start_url': self.start_url, + 'sso_region': self.sso_region, + 'sso_account_id': self.account_id, + 'sso_role_name': self.role_name, + } + self.assert_config_updates(config=expected_config) + + def test_prompts_suggest_values(self): + self.full_config['profiles']['another_profile'] = { + 'sso_start_url': self.start_url, + } + self._add_prompt_responses() + self._add_simple_single_item_responses() + with self.sso_stub: + self.configure_sso(args=[], parsed_globals=self.global_args) + self.sso_stub.assert_no_pending_responses() + expected_start_urls = [self.start_url] + expected_sso_regions = ['us-east-1'] + expected_cli_regions = None + expected_cli_outputs = list(CLI_OUTPUT_FORMATS.keys()) + expected_completions = [ + expected_start_urls, + expected_sso_regions, + expected_cli_regions, + expected_cli_outputs, + ] + self.assert_prompt_completions(expected_completions) diff --git a/tests/unit/customizations/sso/__init__.py b/tests/unit/customizations/sso/__init__.py new file mode 100644 index 000000000000..cf81c339a106 --- /dev/null +++ b/tests/unit/customizations/sso/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/tests/unit/customizations/sso/test_sso.py b/tests/unit/customizations/sso/test_sso.py new file mode 100644 index 000000000000..f170a7f28dde --- /dev/null +++ b/tests/unit/customizations/sso/test_sso.py @@ -0,0 +1,60 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import botocore.credentials +import botocore.session +from botocore.exceptions import ProfileNotFound, UnknownCredentialError + +from awscli.testutils import mock +from awscli.testutils import unittest +from awscli.customizations.sso import inject_json_file_cache + + +class TestInjectJSONFileCache(unittest.TestCase): + def setUp(self): + self.mock_sso_provider = mock.Mock( + spec=botocore.credentials.SSOProvider) + self.mock_resolver = mock.Mock(botocore.credentials.CredentialResolver) + self.mock_resolver.get_provider.return_value = self.mock_sso_provider + self.session = mock.Mock(spec=botocore.session.Session) + self.session.get_component.return_value = self.mock_resolver + + def test_inject_json_file_cache(self): + inject_json_file_cache( + self.session, event_name='session-initialized' + ) + self.session.get_component.assert_called_with('credential_provider') + self.assertIsInstance( + self.mock_sso_provider.cache, + botocore.credentials.JSONFileCache, + ) + + def test_profile_not_found_is_not_propagated(self): + self.session.get_component.side_effect = ProfileNotFound( + profile='unknown') + try: + inject_json_file_cache( + self.session, event_name='session-initialized' + ) + except ProfileNotFound: + self.fail('ProfileNotFound should not have been raised.') + + def test_provider_not_found_error_is_not_propagated(self): + self.mock_resolver.get_provider.side_effect = UnknownCredentialError( + name='sso' + ) + try: + inject_json_file_cache( + self.session, event_name='session-initialized' + ) + except UnknownCredentialError: + self.fail('UnknownCredentialError should not have been raised.') diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py new file mode 100644 index 000000000000..c8d14d6eb355 --- /dev/null +++ b/tests/unit/customizations/sso/test_utils.py @@ -0,0 +1,123 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import webbrowser +from awscli.testutils import mock +from awscli.testutils import unittest + +from botocore.session import Session + +from awscli.compat import StringIO +from awscli.customizations.sso.utils import do_sso_login +from awscli.customizations.sso.utils import OpenBrowserHandler + + +class TestDoSSOLogin(unittest.TestCase): + def setUp(self): + self.region = 'us-west-2' + self.start_url = 'https://mystarturl.com' + self.token_cache = {} + self.on_pending_authorization_mock = mock.Mock() + self.session = mock.Mock(Session) + self.oidc_client = self.get_mock_sso_oidc_client() + self.session.create_client.return_value = self.oidc_client + + def get_mock_sso_oidc_client(self): + client = mock.Mock() + client.register_client.return_value = { + 'clientSecretExpiresAt': 1000, + 'clientId': 'foo-client-id', + 'clientSecret': 'foo-client-secret', + } + client.start_device_authorization.return_value = { + 'interval': 1, + 'expiresIn': 600, + 'userCode': 'foo', + 'deviceCode': 'foo-device-code', + 'verificationUri': 'https://sso.fake/device', + 'verificationUriComplete': 'https://sso.verify', + } + client.create_token.return_value = { + 'expiresIn': 28800, + 'tokenType': 'Bearer', + 'accessToken': 'access.token', + } + return client + + def assert_client_called_with_start_url(self): + call_args = self.oidc_client.start_device_authorization.call_args + self.assertEqual(call_args[1]['startUrl'], self.start_url) + + def assert_used_sso_region(self): + config = self.session.create_client.call_args[1]['config'] + self.assertEqual(config.region_name, self.region) + + def assert_token_cache_was_filled(self): + self.assertGreater(len(self.token_cache), 0) + + def assert_on_pending_authorization_called(self): + self.assertEqual( + len(self.on_pending_authorization_mock.call_args_list), 1) + + def test_do_sso_login(self): + do_sso_login( + session=self.session, sso_region=self.region, + start_url=self.start_url, token_cache=self.token_cache, + on_pending_authorization=self.on_pending_authorization_mock + ) + # We just want to make some quick checks to make sure all of the + # parameters were plumbed in correctly. + self.assert_client_called_with_start_url() + self.assert_used_sso_region() + self.assert_token_cache_was_filled() + self.assert_on_pending_authorization_called() + + +class TestOpenBrowserHandler(unittest.TestCase): + def setUp(self): + self.stream = StringIO() + self.user_code = '12345' + self.verification_uri = 'https://verification.com' + self.verification_uri_complete = 'https://verification.com?code=12345' + self.pending_authorization = { + 'userCode': self.user_code, + 'verificationUri': self.verification_uri, + 'verificationUriComplete': self.verification_uri_complete, + } + self.open_browser = mock.Mock(spec=webbrowser.open_new_tab) + self.handler = OpenBrowserHandler( + self.stream, + open_browser=self.open_browser, + ) + + def assert_text_in_output(self, *args): + output = self.stream.getvalue() + for text in args: + self.assertIn(text, output) + + def test_call_no_browser(self): + handler = OpenBrowserHandler(self.stream, open_browser=False) + handler(**self.pending_authorization) + self.assert_text_in_output(self.user_code, self.verification_uri) + + def test_call_browser_success(self): + self.handler(**self.pending_authorization) + self.open_browser.assert_called_with(self.verification_uri_complete) + self.assert_text_in_output('automatically', 'open') + # assert the URI and user coe are still displayed + self.assert_text_in_output(self.user_code, self.verification_uri) + + def test_call_browser_fails(self): + self.open_browser.side_effect = webbrowser.Error() + self.handler(**self.pending_authorization) + self.assert_text_in_output(self.user_code, self.verification_uri) + self.open_browser.assert_called_with(self.verification_uri_complete) From bbe43a25889561543801fe1b8efa21d8d4a042a8 Mon Sep 17 00:00:00 2001 From: Jordan Guymon Date: Thu, 7 Nov 2019 13:33:42 -0800 Subject: [PATCH 2/2] Add region name to client for sso unit test --- tests/functional/sso/test_login.py | 8 ++++++++ tests/unit/customizations/configure/test_sso.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index abc825c56098..f41a53f9e992 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -17,6 +17,7 @@ from awscli.testutils import mock from tests.functional.sso import BaseSSOTest +from awscli.customizations.sso.utils import OpenBrowserHandler class TestLoginCommand(BaseSSOTest): @@ -28,10 +29,17 @@ def setUp(self): self.token_cache_dir ) self.token_cache_dir_patch.start() + self.open_browser_mock = mock.Mock(spec=OpenBrowserHandler) + self.open_browser_patch = mock.patch( + 'awscli.customizations.sso.utils.OpenBrowserHandler', + self.open_browser_mock, + ) + self.open_browser_patch.start() def tearDown(self): super(TestLoginCommand, self).tearDown() self.token_cache_dir_patch.stop() + self.open_browser_patch.stop() def add_oidc_workflow_responses(self, access_token, include_register_response=True): diff --git a/tests/unit/customizations/configure/test_sso.py b/tests/unit/customizations/configure/test_sso.py index 3babaff92d98..c4971adabc97 100644 --- a/tests/unit/customizations/configure/test_sso.py +++ b/tests/unit/customizations/configure/test_sso.py @@ -124,7 +124,10 @@ class TestConfigureSSOCommand(unittest.TestCase): def setUp(self): self.global_args = mock.Mock() self._session = Session() - self.sso_client = self._session.create_client('sso') + self.sso_client = self._session.create_client( + 'sso', + region_name='us-west-2', + ) self.sso_stub = Stubber(self.sso_client) self.profile = 'a-profile' self.scoped_config = {}