diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 688825d5019..0ab8abbe59d 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,5 +1,22 @@ Release History =============== +0.2.2 +----- +* Validate that target machine exists before attempting to connect. +* ssh config accepts relative path for --file. +* Make --local-user mandatory for Windows target machines. +* For ssh config, relay information is stored under az_ssh_config folder. +* New optional parameter --arc-proxy-folder to determine where arc proxy is stored. +* Relay information lifetime is synced with certificate lifetime for AAD login. + +0.2.1 +----- +* SSHArc Private Preview 2 + +0.2.0 +----- +* SSHArc Private Preview 1 + 0.1.9 ----- * Add support for connecting to Arc Servers using AAD issued certificates. diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index a01c787d61f..f4c0856090f 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -27,6 +27,9 @@ def load_arguments(self, _): help=('This is an internal argument. This argument is used by Azure Portal to provide a one click ' 'SSH login experience in Cloud shell.'), deprecate_info=c.deprecate(hide=True), action='store_true') + c.argument('ssh_proxy_folder', options_list=['--ssh-proxy-folder'], + help=('Path to the folder where the ssh proxy should be saved. ' + 'Default to .clientsshproxy folder in user\'s home directory if not provided.')) c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') with self.argument_context('ssh config') as c: @@ -47,6 +50,9 @@ def load_arguments(self, _): c.argument('resource_type', options_list=['--resource-type'], help='Resource type should be either Microsoft.Compute or Microsoft.HybridCompute') c.argument('cert_file', options_list=['--certificate-file', '-c'], help='Path to certificate file') + c.argument('ssh_proxy_folder', options_list=['--ssh-proxy-folder'], + help=('Path to the folder where the ssh proxy should be saved. ' + 'Default to .clientsshproxy folder in user\'s home directory if not provided.')) with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], @@ -69,4 +75,7 @@ def load_arguments(self, _): help=('This is an internal argument. This argument is used by Azure Portal to provide a one click ' 'SSH login experience in Cloud shell.'), deprecate_info=c.deprecate(hide=True), action='store_true') + c.argument('ssh_proxy_folder', options_list=['--ssh-proxy-folder'], + help=('Path to the folder where the ssh proxy should be saved. ' + 'Default to .clientsshproxy folder in user\'s home directory if not provided.')) c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') diff --git a/src/ssh/azext_ssh/connectivity_utils.py b/src/ssh/azext_ssh/connectivity_utils.py new file mode 100644 index 00000000000..1b6402a1c66 --- /dev/null +++ b/src/ssh/azext_ssh/connectivity_utils.py @@ -0,0 +1,152 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import time +import stat +import os +import urllib.request +import json +import base64 +from glob import glob + +from azure.cli.core import telemetry +from azure.cli.core import azclierror +from knack import log + +from . import file_utils +from . import constants as consts + +logger = log.get_logger(__name__) + + +# Get the Access Details to connect to Arc Connectivity platform from the HybridConnectivity RP +def get_relay_information(cmd, resource_group, vm_name, certificate_validity_in_seconds): + from azext_ssh._client_factory import cf_endpoint + client = cf_endpoint(cmd.cli_ctx) + + if not certificate_validity_in_seconds or \ + certificate_validity_in_seconds > consts.RELAY_INFO_MAXIMUM_DURATION_IN_SECONDS: + certificate_validity_in_seconds = consts.RELAY_INFO_MAXIMUM_DURATION_IN_SECONDS + + try: + t0 = time.time() + result = client.list_credentials(resource_group_name=resource_group, machine_name=vm_name, + endpoint_name="default", expiresin=certificate_validity_in_seconds) + time_elapsed = time.time() - t0 + telemetry.add_extension_event('ssh', {'Context.Default.AzureCLI.SSHListCredentialsTime': time_elapsed}) + except Exception as e: + telemetry.set_exception(exception='Call to listCredentials failed', + fault_type=consts.LIST_CREDENTIALS_FAILED_FAULT_TYPE, + summary=f'listCredentials failed with error: {str(e)}.') + raise azclierror.ClientRequestError(f"Request for Azure Relay Information Failed: {str(e)}") + + return result + + +# Downloads client side proxy to connect to Arc Connectivity Platform +def get_client_side_proxy(arc_proxy_folder): + + request_uri, install_location, older_version_location = _get_proxy_filename_and_url(arc_proxy_folder) + install_dir = os.path.dirname(install_location) + + # Only download new proxy if it doesn't exist already + if not os.path.isfile(install_location): + t0 = time.time() + # download the executable + try: + with urllib.request.urlopen(request_uri) as response: + response_content = response.read() + response.close() + except Exception as e: + telemetry.set_exception(exception=e, fault_type=consts.PROXY_DOWNLOAD_FAILED_FAULT_TYPE, + summary=f'Failed to download proxy from {request_uri}') + raise azclierror.ClientRequestError(f"Failed to download client proxy executable from {request_uri}. " + "Error: " + str(e)) from e + time_elapsed = time.time() - t0 + + proxy_data = { + 'Context.Default.AzureCLI.SSHProxyDownloadTime': time_elapsed, + 'Context.Default.AzureCLI.SSHProxyVersion': consts.CLIENT_PROXY_VERSION + } + telemetry.add_extension_event('ssh', proxy_data) + + # if directory doesn't exist, create it + if not os.path.isdir(install_dir): + file_utils.create_directory(install_dir, f"Failed to create client proxy directory '{install_dir}'. ") + # if directory exists, delete any older versions of the proxy + else: + older_version_files = glob(older_version_location) + for f in older_version_files: + file_utils.delete_file(f, f"failed to delete older version file {f}", warning=True) + + # write executable in the install location + file_utils.write_to_file(install_location, 'wb', response_content, "Failed to create client proxy file. ") + os.chmod(install_location, os.stat(install_location).st_mode | stat.S_IXUSR) + + return install_location + + +def _get_proxy_filename_and_url(arc_proxy_folder): + import platform + operating_system = platform.system() + machine = platform.machine() + + logger.debug("Platform OS: %s", operating_system) + logger.debug("Platform architecture: %s", machine) + + if machine.endswith('64'): + architecture = 'amd64' + elif machine.endswith('86'): + architecture = '386' + elif machine == '': + raise azclierror.BadRequestError("Couldn't identify the platform architecture.") + else: + telemetry.set_exception(exception='Unsuported architecture for installing proxy', + fault_type=consts.PROXY_UNSUPPORTED_ARCH_FAULT_TYPE, + summary=f'{machine} is not supported for installing client proxy') + raise azclierror.BadRequestError(f"Unsuported architecture: {machine} is not currently supported") + + # define the request url and install location based on the os and architecture + proxy_name = f"sshProxy_{operating_system.lower()}_{architecture}" + request_uri = (f"{consts.CLIENT_PROXY_STORAGE_URL}/{consts.CLIENT_PROXY_RELEASE}" + f"/{proxy_name}_{consts.CLIENT_PROXY_VERSION}") + install_location = proxy_name + "_" + consts.CLIENT_PROXY_VERSION.replace('.', '_') + older_location = proxy_name + "*" + + if operating_system == 'Windows': + request_uri = request_uri + ".exe" + install_location = install_location + ".exe" + older_location = older_location + ".exe" + elif operating_system not in ('Linux', 'Darwin'): + telemetry.set_exception(exception='Unsuported OS for installing ssh client proxy', + fault_type=consts.PROXY_UNSUPPORTED_OS_FAULT_TYPE, + summary=f'{operating_system} is not supported for installing client proxy') + raise azclierror.BadRequestError(f"Unsuported OS: {operating_system} platform is not currently supported") + + if not arc_proxy_folder: + install_location = os.path.expanduser(os.path.join('~', os.path.join(".clientsshproxy", install_location))) + older_location = os.path.expanduser(os.path.join('~', os.path.join(".clientsshproxy", older_location))) + else: + install_location = os.path.join(arc_proxy_folder, install_location) + older_location = os.path.join(arc_proxy_folder, older_location) + + return request_uri, install_location, older_location + + +def format_relay_info_string(relay_info): + relay_info_string = json.dumps( + { + "relay": { + "namespaceName": relay_info.namespace_name, + "namespaceNameSuffix": relay_info.namespace_name_suffix, + "hybridConnectionName": relay_info.hybrid_connection_name, + "accessKey": relay_info.access_key, + "expiresOn": relay_info.expires_on + } + }) + result_bytes = relay_info_string.encode("ascii") + enc = base64.b64encode(result_bytes) + base64_result_string = enc.decode("ascii") + return base64_result_string diff --git a/src/ssh/azext_ssh/constants.py b/src/ssh/azext_ssh/constants.py index 1731ec2df75..b2f364670f2 100644 --- a/src/ssh/azext_ssh/constants.py +++ b/src/ssh/azext_ssh/constants.py @@ -9,7 +9,9 @@ CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120 CLEANUP_TIME_INTERVAL_IN_SECONDS = 10 CLEANUP_AWAIT_TERMINATION_IN_SECONDS = 30 +RELAY_INFO_MAXIMUM_DURATION_IN_SECONDS = 3600 PROXY_UNSUPPORTED_ARCH_FAULT_TYPE = 'client-proxy-unsupported-architecture-error' PROXY_UNSUPPORTED_OS_FAULT_TYPE = 'client-proxy-unsupported-os-error' PROXY_DOWNLOAD_FAILED_FAULT_TYPE = 'client-proxy-download-failed-error' LIST_CREDENTIALS_FAILED_FAULT_TYPE = 'get-relay-information-failed-error' + diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 8ae5feb3e4f..6c34dc8e814 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -8,28 +8,24 @@ import hashlib import json import tempfile -import urllib.request -import base64 -import stat import time -from glob import glob from knack import log from azure.cli.core import azclierror from azure.cli.core import telemetry +from azure.core.exceptions import ResourceNotFoundError, HttpResponseError from . import ip_utils from . import rsa_parser from . import ssh_utils -from . import constants as consts -from . import file_utils +from . import connectivity_utils logger = log.get_logger(__name__) def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, local_user=None, cert_file=None, port=None, - ssh_client_path=None, delete_credentials=False, resource_type=None, ssh_args=None): + ssh_client_path=None, delete_credentials=False, resource_type=None, ssh_proxy_folder=None, ssh_args=None): if delete_credentials and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0": raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.") @@ -37,20 +33,23 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ _assert_args(resource_group_name, vm_name, ssh_ip, resource_type, cert_file, local_user) credentials_folder = None do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, None, None, - ssh_client_path, ssh_args, delete_credentials, credentials_folder) + ssh_client_path, ssh_args, delete_credentials, credentials_folder, local_user) do_ssh_op(cmd, vm_name, resource_group_name, ssh_ip, public_key_file, private_key_file, local_user, - cert_file, port, use_private_ip, credentials_folder) + cert_file, port, use_private_ip, credentials_folder, ssh_proxy_folder) def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, - local_user=None, cert_file=None, port=None, resource_type=None, credentials_folder=None): + local_user=None, cert_file=None, port=None, resource_type=None, credentials_folder=None, + ssh_proxy_folder=None): if (public_key_file or private_key_file) and credentials_folder: raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with " "--public-key-file/-p or --private-key-file/-i.") _assert_args(resource_group_name, vm_name, ssh_ip, resource_type, cert_file, local_user) + config_path = os.path.abspath(config_path) + # Default credential location if not credentials_folder: config_folder = os.path.dirname(config_path) @@ -64,9 +63,9 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, config_path, overwrite, - None, None, False, credentials_folder) + None, None, False, credentials_folder, local_user) do_ssh_op(cmd, vm_name, resource_group_name, ssh_ip, public_key_file, private_key_file, local_user, - cert_file, port, use_private_ip, credentials_folder) + cert_file, port, use_private_ip, credentials_folder, ssh_proxy_folder) def ssh_cert(cmd, cert_path=None, public_key_file=None): @@ -82,11 +81,17 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): "is no longer being used.", keys_folder) public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) - print(cert_file + "\n") + try: + cert_expiration = ssh_utils.get_certificate_start_and_end_times(cert_file)[1] + print(f"Generated SSH certificate {cert_file} is valid until {cert_expiration}.") + except Exception as e: + logger.warning("Couldn't determine certificate validity. Error: %s", str(e)) + print(cert_file + "\n") def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, private_key_file=None, - local_user=None, cert_file=None, port=None, ssh_client_path=None, delete_credentials=False, ssh_args=None): + local_user=None, cert_file=None, port=None, ssh_client_path=None, delete_credentials=False, + ssh_proxy_folder=None, ssh_args=None): if delete_credentials and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0": raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.") @@ -95,37 +100,60 @@ def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, p credentials_folder = None + arc, arc_error, is_arc_server = _check_if_arc_server(cmd, resource_group_name, vm_name) + if not is_arc_server: + if isinstance(arc_error, ResourceNotFoundError): + raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group " + f"{resource_group_name} was not found. Error:\n" + f"{str(arc_error)}") + raise azclierror.BadRequestError("Unable to determine that the target machine is an Arc Server. " + f"Error:\n{str(arc_error)}") + if arc and arc.properties and arc.properties and arc.properties.os_name: + os_type = arc.properties.os_name + # Note: This is a temporary check while AAD login is not enabled for Windows. + if os_type.lower() == 'windows' and not local_user: + raise azclierror.RequiredArgumentMissingError("SSH Login to AAD user is not currently supported for Windows. " + "Please provide --local-user.") + op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args, delete_credentials=delete_credentials) _do_ssh_op(cmd, vm_name, resource_group_name, None, public_key_file, private_key_file, local_user, cert_file, port, - False, credentials_folder, op_call, True) + False, credentials_folder, ssh_proxy_folder, op_call, True) def _do_ssh_op(cmd, vm_name, resource_group_name, ssh_ip, public_key_file, private_key_file, username, - cert_file, port, use_private_ip, credentials_folder, op_call, is_arc): + cert_file, port, use_private_ip, credentials_folder, ssh_proxy_folder, op_call, is_arc): - proxy_path = None - relay_info = None - if is_arc: - proxy_path = _arc_get_client_side_proxy() - relay_info = _arc_list_access_details(cmd, resource_group_name, vm_name) - else: - ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group_name, vm_name, use_private_ip) - if not ssh_ip: - if not use_private_ip: - raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public IP address to SSH to") - raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to" - "SSH to") + if not is_arc and ssh_proxy_folder: + logger.warning("Target machine is not an Arc Server, --ssh-proxy-folder value will be ignored.") # If user provides local user, no credentials should be deleted. delete_keys = False delete_cert = False + cert_lifetime = None if not username: delete_cert = True public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) + if is_arc: + try: + cert_lifetime = ssh_utils.get_certificate_lifetime(cert_file).total_seconds() + except Exception as e: + logger.warning("Couldn't determine certificate expiration. Error: %s", str(e)) + + proxy_path = None + relay_info = None + if is_arc: + proxy_path = connectivity_utils.get_client_side_proxy(ssh_proxy_folder) + relay_info = connectivity_utils.get_relay_information(cmd, resource_group_name, vm_name, cert_lifetime) + else: + ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group_name, vm_name, use_private_ip) + if not ssh_ip: + if not use_private_ip: + raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public IP address to SSH to.") + raise azclierror.ResourceNotFoundError("Internal Error. Couldn't determine the IP address.") op_call(relay_info, proxy_path, vm_name, ssh_ip, username, cert_file, private_key_file, port, is_arc, delete_keys, delete_cert, public_key_file) @@ -192,7 +220,8 @@ def _prepare_jwk_data(public_key_file): def _assert_args(resource_group, vm_name, ssh_ip, resource_type, cert_file, username): - if resource_type and resource_type != "Microsoft.Compute" and resource_type != "Microsoft.HybridCompute": + if resource_type and resource_type.lower() != "microsoft.compute" \ + and resource_type.lower() != "microsoft.hybridcompute": raise azclierror.InvalidArgumentValueError("--resource-type must be either \"Microsoft.Compute\" " "for Azure VMs or \"Microsoft.HybridCompute\" for Arc Servers.") @@ -279,147 +308,66 @@ def _get_modulus_exponent(public_key_file): return modulus, exponent -# Downloads client side proxy to connect to Arc Connectivity Platform -def _arc_get_client_side_proxy(): - import platform - operating_system = platform.system() - machine = platform.machine() - - logger.debug("Platform OS: %s", operating_system) - logger.debug("Platform architecture: %s", machine) - - if machine.endswith('64'): - architecture = 'amd64' - elif machine.endswith('86'): - architecture = '386' - elif machine == '': - raise azclierror.BadRequestError("Couldn't identify the platform architecture.") - else: - telemetry.set_exception(exception='Unsuported architecture for installing proxy', - fault_type=consts.PROXY_UNSUPPORTED_ARCH_FAULT_TYPE, - summary=f'{machine} is not supported for installing client proxy') - raise azclierror.BadRequestError(f"Unsuported architecture: {machine} is not currently supported") - - # define the request url and install location based on the os and architecture - proxy_name = f"sshProxy_{operating_system.lower()}_{architecture}" - request_uri = (f"{consts.CLIENT_PROXY_STORAGE_URL}/{consts.CLIENT_PROXY_RELEASE}" - f"/{proxy_name}_{consts.CLIENT_PROXY_VERSION}") - install_location = os.path.join(".clientsshproxy", proxy_name + "_" + consts.CLIENT_PROXY_VERSION.replace('.', '_')) - older_version_location = os.path.join(".clientsshproxy", proxy_name + "*") - - if operating_system == 'Windows': - request_uri = request_uri + ".exe" - install_location = install_location + ".exe" - older_version_location = older_version_location + ".exe" - elif operating_system not in ('Linux', 'Darwin'): - telemetry.set_exception(exception='Unsuported OS for installing ssh client proxy', - fault_type=consts.PROXY_UNSUPPORTED_OS_FAULT_TYPE, - summary=f'{operating_system} is not supported for installing client proxy') - raise azclierror.BadRequestError(f"Unsuported OS: {operating_system} platform is not currently supported") - - install_location = os.path.expanduser(os.path.join('~', install_location)) - older_version_location = os.path.expanduser(os.path.join('~', older_version_location)) - install_dir = os.path.dirname(install_location) - - # Only download new proxy if it doesn't exist already - if not os.path.isfile(install_location): - t0 = time.time() - # download the executable - try: - with urllib.request.urlopen(request_uri) as response: - response_content = response.read() - response.close() - except Exception as e: - telemetry.set_exception(exception=e, fault_type=consts.PROXY_DOWNLOAD_FAILED_FAULT_TYPE, - summary=f'Failed to download proxy from {request_uri}') - raise azclierror.ClientRequestError(f"Failed to download client proxy executable from {request_uri}. " - "Error: " + str(e)) from e - time_elapsed = time.time() - t0 - - proxy_data = { - 'Context.Default.AzureCLI.SSHProxyDownloadTime': time_elapsed, - 'Context.Default.AzureCLI.SSHProxyVersion': consts.CLIENT_PROXY_VERSION - } - telemetry.add_extension_event('ssh', proxy_data) - - # if directory doesn't exist, create it - if not os.path.exists(install_dir): - file_utils.create_directory(install_dir, f"Failed to create client proxy directory '{install_dir}'. ") - # if directory exists, delete any older versions of the proxy - else: - older_version_files = glob(older_version_location) - for f in older_version_files: - file_utils.delete_file(f, f"failed to delete older version file {f}", warning=True) - - # write executable in the install location - file_utils.write_to_file(install_location, 'wb', response_content, "Failed to create client proxy file. ") - os.chmod(install_location, os.stat(install_location).st_mode | stat.S_IXUSR) - - return install_location - - -# Get the Access Details to connect to Arc Connectivity platform from the HybridConnectivity RP -def _arc_list_access_details(cmd, resource_group, vm_name): - from azext_ssh._client_factory import cf_endpoint - client = cf_endpoint(cmd.cli_ctx) - try: - t0 = time.time() - result = client.list_credentials(resource_group_name=resource_group, machine_name=vm_name, - endpoint_name="default") - time_elapsed = time.time() - t0 - telemetry.add_extension_event('ssh', {'Context.Default.AzureCLI.SSHListCredentialsTime': time_elapsed}) - except Exception as e: - telemetry.set_exception(exception='Call to listCredentials failed', - fault_type=consts.LIST_CREDENTIALS_FAILED_FAULT_TYPE, - summary=f'listCredentials failed with error: {str(e)}.') - raise azclierror.ClientRequestError(f"Request for Azure Relay Information Failed: {str(e)}") - - result_string = json.dumps( - { - "relay": { - "namespaceName": result.namespace_name, - "namespaceNameSuffix": result.namespace_name_suffix, - "hybridConnectionName": result.hybrid_connection_name, - "accessKey": result.access_key, - "expiresOn": result.expires_on - } - }) - result_bytes = result_string.encode("ascii") - enc = base64.b64encode(result_bytes) - base64_result_string = enc.decode("ascii") - return base64_result_string - - def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, config_path, overwrite, - ssh_client_path, ssh_args, delete_credentials, credentials_folder): + ssh_client_path, ssh_args, delete_credentials, credentials_folder, local_user): # If the user provides an IP address the target will be treated as an Azure VM even if it is an # Arc Server. Which just means that the Connectivity Proxy won't be used to establish connection. is_arc_server = False + is_azure_vm = False if ssh_ip: - is_arc_server = False + is_azure_vm = True + vm = None elif resource_type: - if resource_type == "Microsoft.HybridCompute": - is_arc_server = True + if resource_type.lower() == "microsoft.hybridcompute": + arc, arc_error, is_arc_server = _check_if_arc_server(cmd, resource_group_name, vm_name) + if not is_arc_server: + if isinstance(arc_error, ResourceNotFoundError): + raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group " + f"{resource_group_name} was not found. Error:\n" + f"{str(arc_error)}") + raise azclierror.BadRequestError("Unable to determine that the target machine is an Arc Server. " + f"Error:\n{str(arc_error)}") + + elif resource_type.lower() == "microsoft.compute": + vm, vm_error, is_azure_vm = _check_if_azure_vm(cmd, resource_group_name, vm_name) + if not is_azure_vm: + if isinstance(vm_error, ResourceNotFoundError): + raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group " + f"{resource_group_name} was not found. Error:\n" + f"{str(vm_error)}") + raise azclierror.BadRequestError("Unable to determine that the target machine is an Azure VM. " + f"Error:\n{str(vm_error)}") else: - vm_error, is_azure_vm = _check_if_azure_vm(cmd, resource_group_name, vm_name) - arc_error, is_arc_server = _check_if_arc_server(cmd, resource_group_name, vm_name) + vm, vm_error, is_azure_vm = _check_if_azure_vm(cmd, resource_group_name, vm_name) + arc, arc_error, is_arc_server = _check_if_arc_server(cmd, resource_group_name, vm_name) if is_azure_vm and is_arc_server: raise azclierror.BadRequestError(f"{resource_group_name} has Azure VM and Arc Server with the " f"same name: {vm_name}. Please provide a --resource-type.") if not is_azure_vm and not is_arc_server: - from azure.core.exceptions import ResourceNotFoundError if isinstance(arc_error, ResourceNotFoundError) and isinstance(vm_error, ResourceNotFoundError): raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group " - f"{resource_group_name} was not found. Erros:\n" - f"{str(arc_error)}\n{str(vm_error)}") + f"{resource_group_name} was not found.") raise azclierror.BadRequestError("Unable to determine the target machine type as Azure VM or " f"Arc Server. Errors:\n{str(arc_error)}\n{str(vm_error)}") + # Note: We are not able to determine the os of the target if the user only provides an IP address. + os_type = None + if is_azure_vm and vm and vm.storage_profile and vm.storage_profile.os_disk and vm.storage_profile.os_disk.os_type: + os_type = vm.storage_profile.os_disk.os_type + + if is_arc_server and arc and arc.properties and arc.properties and arc.properties.os_name: + os_type = arc.properties.os_name + + # Note 2: This is a temporary check while AAD login is not enabled for Windows. + if os_type and os_type.lower() == 'windows' and not local_user: + raise azclierror.RequiredArgumentMissingError("SSH Login to AAD user is not currently supported for Windows. " + "Please provide --local-user.") + if config_path: op_call = functools.partial(ssh_utils.write_ssh_config, config_path=config_path, overwrite=overwrite, resource_group=resource_group_name, credentials_folder=credentials_folder) @@ -433,27 +381,29 @@ def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, co def _check_if_azure_vm(cmd, resource_group_name, vm_name): from azure.cli.core.commands import client_factory from azure.cli.core import profiles - from azure.core.exceptions import ResourceNotFoundError, HttpResponseError + vm = None try: compute_client = client_factory.get_mgmt_service_client(cmd.cli_ctx, profiles.ResourceType.MGMT_COMPUTE) - compute_client.virtual_machines.get(resource_group_name, vm_name) + vm = compute_client.virtual_machines.get(resource_group_name, vm_name) except ResourceNotFoundError as e: - return e, False + return None, e, False # If user is not authorized to get the VM, it will throw a HttpResponseError except HttpResponseError as e: - return e, False - return None, True + return None, e, False + + return vm, None, True def _check_if_arc_server(cmd, resource_group_name, vm_name): - from azure.core.exceptions import ResourceNotFoundError, HttpResponseError from azext_ssh._client_factory import cf_machine client = cf_machine(cmd.cli_ctx) + arc = None try: - client.get(resource_group_name=resource_group_name, machine_name=vm_name) + arc = client.get(resource_group_name=resource_group_name, machine_name=vm_name) except ResourceNotFoundError as e: - return e, False + return None, e, False # If user is not authorized to get the arc server, it will throw a HttpResponseError except HttpResponseError as e: - return e, False - return None, True + return None, e, False + + return arc, None, True diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index c89bb1aeac5..b3e4bdcac14 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -8,6 +8,7 @@ import stat import multiprocessing as mp import time +import datetime import oschmod from knack import log @@ -15,6 +16,7 @@ from azure.cli.core import telemetry from . import file_utils +from . import connectivity_utils from . import constants as const logger = log.get_logger(__name__) @@ -34,7 +36,7 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil env = os.environ.copy() if is_arc: - env['SSHPROXY_RELAY_INFO'] = relay_info + env['SSHPROXY_RELAY_INFO'] = connectivity_utils.format_relay_info_string(relay_info) if port: pcommand = f"ProxyCommand={proxy_path} -p {port}" else: @@ -84,6 +86,34 @@ def get_ssh_cert_info(cert_file): return subprocess.check_output(command, shell=platform.system() == 'Windows').decode().splitlines() +def _get_ssh_cert_validity(cert_file): + if cert_file: + info = get_ssh_cert_info(cert_file) + for line in info: + if "Valid:" in line: + return line.strip() + return None + + +def get_certificate_start_and_end_times(cert_file): + validity_str = _get_ssh_cert_validity(cert_file) + times = None + if validity_str and "Valid: from " in validity_str and " to " in validity_str: + times = validity_str.replace("Valid: from ", "").split(" to ") + t0 = datetime.datetime.strptime(times[0], '%Y-%m-%dT%X') + t1 = datetime.datetime.strptime(times[1], '%Y-%m-%dT%X') + times = (t0, t1) + return times + + +def get_certificate_lifetime(cert_file): + times = get_certificate_start_and_end_times(cert_file) + lifetime = None + if times: + lifetime = times[1] - times[0] + return lifetime + + def get_ssh_cert_principals(cert_file): info = get_ssh_cert_info(cert_file) principals = [] @@ -100,15 +130,6 @@ def get_ssh_cert_principals(cert_file): return principals -def get_ssh_cert_validity(cert_file): - if cert_file: - info = get_ssh_cert_info(cert_file) - for line in info: - if "Valid:" in line: - return line.strip() - return None - - def write_ssh_config(relay_info, proxy_path, vm_name, ip, username, cert_file, private_key_file, port, is_arc, delete_keys, delete_cert, _, config_path, overwrite, resource_group, credentials_folder): @@ -116,25 +137,24 @@ def write_ssh_config(relay_info, proxy_path, vm_name, ip, username, common_lines = [] common_lines.append("\tUser " + username) if cert_file: - common_lines.append("\tCertificateFile " + cert_file) + common_lines.append("\tCertificateFile \"" + cert_file + "\"") if private_key_file: - common_lines.append("\tIdentityFile " + private_key_file) + common_lines.append("\tIdentityFile \"" + private_key_file + "\"") lines = [""] relay_info_path = None relay_info_filename = None if is_arc: - relay_info_path, relay_info_filename = _prepare_relay_info_file(relay_info, cert_file, - private_key_file, credentials_folder, + relay_info_path, relay_info_filename = _prepare_relay_info_file(relay_info, credentials_folder, vm_name, resource_group) lines.append("Host " + resource_group + "-" + vm_name) lines.append("\tHostName " + vm_name) lines = lines + common_lines if port: - lines.append("\tProxyCommand " + proxy_path + " " + "-r " + relay_info_path + " " + "-p " + port) + lines.append("\tProxyCommand \"" + proxy_path + "\" " + "-r \"" + relay_info_path + "\" " + "-p " + port) else: - lines.append("\tProxyCommand " + proxy_path + " " + "-r " + relay_info_path) + lines.append("\tProxyCommand \"" + proxy_path + "\" " + "-r \"" + relay_info_path + "\"") else: if resource_group and vm_name: lines.append("Host " + resource_group + "-" + vm_name) @@ -276,16 +296,11 @@ def _terminate_cleanup(delete_keys, delete_cert, delete_credentials, cleanup_pro file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) -def _prepare_relay_info_file(relay_info, cert_file, private_key_file, credentials_folder, vm_name, resource_group): - if cert_file: - relay_info_dir = os.path.dirname(cert_file) - elif private_key_file: - relay_info_dir = os.path.dirname(private_key_file) - else: - # create the custom folder - relay_info_dir = credentials_folder - if not os.path.isdir(relay_info_dir): - os.makedirs(relay_info_dir) +def _prepare_relay_info_file(relay_info, credentials_folder, vm_name, resource_group): + # create the custom folder + relay_info_dir = credentials_folder + if not os.path.isdir(relay_info_dir): + os.makedirs(relay_info_dir) if vm_name and resource_group: relay_info_filename = resource_group + "-" + vm_name + "-relay_info" @@ -293,14 +308,25 @@ def _prepare_relay_info_file(relay_info, cert_file, private_key_file, credential relay_info_path = os.path.join(relay_info_dir, relay_info_filename) # Overwrite relay_info if it already exists in that folder. file_utils.delete_file(relay_info_path, f"{relay_info_path} already exists, and couldn't be overwritten.") - file_utils.write_to_file(relay_info_path, 'w', relay_info, + file_utils.write_to_file(relay_info_path, 'w', connectivity_utils.format_relay_info_string(relay_info), f"Couldn't write relay information to file {relay_info_path}.", 'utf-8') oschmod.set_mode(relay_info_path, stat.S_IRUSR) + # Print the expiration of the relay information + expiration = datetime.datetime.fromtimestamp(relay_info.expires_on) + expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p") + print(f"Generated file with Relay Information {relay_info_path} is valid until {expiration}.\n") + return relay_info_path, relay_info_filename def _issue_config_cleanup_warning(delete_cert, delete_keys, is_arc, cert_file, relay_info_filename, relay_info_path): + if delete_cert: + expiration = get_certificate_start_and_end_times(cert_file)[1] + expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p") + print(f"Generated SSH certificate {cert_file} is valid until", + f"{expiration}.\n") + if delete_keys or delete_cert or is_arc: # Warn users to delete credentials once config file is no longer being used. # If user provided keys, only ask them to delete the certificate. @@ -309,8 +335,8 @@ def _issue_config_cleanup_warning(delete_cert, delete_keys, is_arc, cert_file, r path_to_delete = os.path.dirname(cert_file) items_to_delete = f" (id_rsa, id_rsa.pub, id_rsa.pub-aadcert.pub, {relay_info_filename})" elif delete_cert: - path_to_delete = os.path.dirname(cert_file) - items_to_delete = f" (id_rsa.pub-aadcert.pub, {relay_info_filename})" + path_to_delete = f"{cert_file} and {relay_info_path}" + items_to_delete = "" else: path_to_delete = relay_info_path items_to_delete = "" @@ -321,15 +347,8 @@ def _issue_config_cleanup_warning(delete_cert, delete_keys, is_arc, cert_file, r path_to_delete = cert_file items_to_delete = "" - validity_warning = "" - if delete_cert: - validity = get_ssh_cert_validity(cert_file) - if validity: - validity_warning = f" {validity.lower()}" - - logger.warning("%s contains sensitive information%s%s\n" - "Please delete it once you no longer need this config file. ", - path_to_delete, items_to_delete, validity_warning) + print(f"{path_to_delete} contain sensitive information{items_to_delete}. " + "Please delete it once you no longer need this config file.\n") def _get_connection_status(log_file): diff --git a/src/ssh/setup.py b/src/ssh/setup.py index 860e04cb9b8..ebba436252b 100644 --- a/src/ssh/setup.py +++ b/src/ssh/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages -VERSION = "0.2.1" +VERSION = "0.2.2" CLASSIFIERS = [ 'Development Status :: 4 - Beta',