From 7b79f28f2a406bc1d1f504cf9c5e90611efff2b1 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Fri, 15 Oct 2021 13:28:16 -0400 Subject: [PATCH 01/26] Initial changes to clean up keys once connection is established --- src/ssh/azext_ssh/custom.py | 5 ++- src/ssh/azext_ssh/file_utils.py | 16 +++++++ src/ssh/azext_ssh/ssh_utils.py | 79 ++++++++++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 2 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 036fc2aa751..2129ca44643 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -18,7 +18,10 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, port=None, ssh_args=None): - op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args) + delete_key = False + if not private_key_file and not public_key_file: + delete_key = True + op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args, delete_key) _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call) diff --git a/src/ssh/azext_ssh/file_utils.py b/src/ssh/azext_ssh/file_utils.py index 3c3bd795c15..b262fd96b8c 100644 --- a/src/ssh/azext_ssh/file_utils.py +++ b/src/ssh/azext_ssh/file_utils.py @@ -5,6 +5,11 @@ import errno import os +from knack import log + +from azure.cli.core import azclierror + +logger = log.get_logger(__name__) def make_dirs_for_file(file_path): @@ -20,3 +25,14 @@ def mkdir_p(path): pass else: raise + + +def delete_file(file_path, message, warning=False): + if os.path.isfile(file_path): + try: + os.remove(file_path) + except Exception as e: + if warning: + logger.warning(message) + else: + raise azclierror.FileOperationError(message + "Error: " + str(e)) from e diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index e48d1ab51a8..0e9ffd5a57e 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -5,22 +5,61 @@ import os import platform import subprocess +import time +import re +import multiprocessing as mp +from azext_ssh import file_utils from knack import log from azure.cli.core import azclierror logger = log.get_logger(__name__) +CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120 +CLEANUP_TIME_INTERVAL_IN_SECONDS = 10 -def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_file): + +def start_ssh_connection(port, ssh_args, delete_keys, ip, username, cert_file, private_key_file): ssh_arg_list = [] if ssh_args: ssh_arg_list = ssh_args + + log_file = None + if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): + # This means the user either provided his own client log file or that they + # want the client log messages to be printed to the console. + # In these two cases, we should not use the log files to check for connection success. + log_file_dir = os.path.dirname(cert_file) + log_file_name = 'ssh_client_log_' + str(os.getpid()) + log_file = os.path.join(log_file_dir, log_file_name) + ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] + print(f"Log file: {log_file}") + print(f"Certificate: {cert_file}") + command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list + + # Create a new process that will wait until the connection is established and then delete keys. + cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, cert_file, private_key_file, + log_file, True)) + cleanup_process.start() + logger.debug("Running ssh command %s", ' '.join(command)) subprocess.call(command, shell=platform.system() == 'Windows') + if cleanup_process.is_alive(): + print("Terminating cleanup") + cleanup_process.terminate() + while cleanup_process.is_alive(): + print("Waiting for cleanup process to die") + time.sleep(1) + # Make sure all files have been properly removed. + _do_cleanup(delete_keys, cert_file, private_key_file) + if log_file: + file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) + # Delete the temporary folder as well? + os.rmdir(os.path.dirname(cert_file)) + def create_ssh_keyfile(private_key_file): command = [_get_ssh_path("ssh-keygen"), "-f", private_key_file, "-t", "rsa", "-q", "-N", ""] @@ -53,6 +92,11 @@ def get_ssh_cert_principals(cert_file): def write_ssh_config(config_path, resource_group, vm_name, overwrite, ip, username, cert_file, private_key_file): + logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " + "managing/deleting the private key and signed public key once this config file is no " + "longer being used. Please delete the contents of %s once you no longer need this config file.", + os.path.dirname(cert_file)) + lines = [""] if resource_group and vm_name: @@ -117,3 +161,36 @@ def _build_args(cert_file, private_key_file, port): port_arg = ["-p", port] certificate = ["-o", "CertificateFile=" + cert_file] return private_key + certificate + port_arg + + +def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): + # if there is a log file, use it to check for the connection success + print(f"Cleanup launched. Log file: {log_file}") + if log_file: + t0 = time.time() + match = False + while (time.time() - t0) < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS and not match: + time.sleep(CLEANUP_TIME_INTERVAL_IN_SECONDS) + try: + with open(log_file, 'r') as ssh_client_log: + for line in ssh_client_log: + if re.search("debug1: Authentication succeeded", line): + match = True + ssh_client_log.close() + except: + print("Can't open log, waiting two minutes") + t1 = time.time() - t0 + if t1 < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: + time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS - t1) + elif wait: + print("No log. Waiting two minutes") + # if we are not checking the logs, but still want to wait for connection before deleting files + time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) + + print("Deleting") + if delete_keys and private_key: + public_key = private_key + '.pub' + file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) + file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) + if cert_file: + file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) From fd0a29d4886cf0fda32ccd0df3d4de3117e739be Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 26 Oct 2021 17:41:50 -0400 Subject: [PATCH 02/26] Add new parameter to config to decide credentials dir. Establish default cred location for config --- src/ssh/azext_ssh/_params.py | 3 +++ src/ssh/azext_ssh/custom.py | 45 ++++++++++++++++++++++------------ src/ssh/azext_ssh/ssh_utils.py | 33 ++++++++++++------------- 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 9a4d4c06a94..e432dd50581 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -27,6 +27,9 @@ def load_arguments(self, _): help='Will use a private IP if available. By default only public IPs are used.') c.argument('overwrite', action='store_true', options_list=['--overwrite'], help='Overwrites the config file if this flag is set') + c.argument('credentials_folder', options_list=['--credentials-destination-folder'], + help='Folder where credentials will be stored.') + with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 2129ca44643..bde0536b563 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -18,18 +18,24 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, port=None, ssh_args=None): - delete_key = False - if not private_key_file and not public_key_file: - delete_key = True - op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args, delete_key) + credentials_folder = None + is_config = False + op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args) _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, - public_key_file, private_key_file, use_private_ip, op_call) + public_key_file, private_key_file, use_private_ip, credentials_folder, is_config, op_call) def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, - public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False): + public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, credentials_folder=None): op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call) + is_config = True + # Default credential location + if not credentials_folder: + config_folder = os.path.dirname(config_path) + folder_name = resource_group_name + "-" + vm_name + credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) + + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, is_config, op_call) def ssh_cert(cmd, cert_path=None, public_key_file=None): @@ -38,9 +44,8 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): print(cert_file + "\n") -def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call): +def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, is_config, op_call): _assert_args(resource_group, vm_name, ssh_ip) - public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file, private_key_file) ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) if not ssh_ip: @@ -49,8 +54,10 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") + # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys + public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, is_config) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) - op_call(ssh_ip, username, cert_file, private_key_file) + op_call(ssh_ip, username, cert_file, private_key_file, delete_keys) def _get_and_write_certificate(cmd, public_key_file, cert_file): @@ -123,12 +130,20 @@ def _assert_args(resource_group, vm_name, ssh_ip): "--ip cannot be used with --resource-group or --vm-name/--name") -def _check_or_create_public_private_files(public_key_file, private_key_file): +def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, is_config): + delete_keys = False # If nothing is passed in create a temporary directory with a ephemeral keypair if not public_key_file and not private_key_file: - temp_dir = tempfile.mkdtemp(prefix="aadsshcert") - public_key_file = os.path.join(temp_dir, "id_rsa.pub") - private_key_file = os.path.join(temp_dir, "id_rsa") + # We only want to delete the keys after the connection if the user hasn't providede their own keys + delete_keys = True + if not is_config: + # az ssh vm: Create keys on temp folder and delete folder once connection is established. + credentials_folder = tempfile.mkdtemp(prefix="aadsshcert") + else: + if not os.path.isdir(credentials_folder): + os.makedirs(credentials_folder) + public_key_file = os.path.join(credentials_folder, "id_rsa.pub") + private_key_file = os.path.join(credentials_folder, "id_rsa") ssh_utils.create_ssh_keyfile(private_key_file) if not public_key_file: @@ -146,7 +161,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file): if not os.path.isfile(private_key_file): raise azclierror.FileOperationError(f"Private key file {private_key_file} not found") - return public_key_file, private_key_file + return public_key_file, private_key_file, delete_keys def _write_cert_file(certificate_contents, cert_file): diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 0e9ffd5a57e..b738cb256fe 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -19,22 +19,20 @@ CLEANUP_TIME_INTERVAL_IN_SECONDS = 10 -def start_ssh_connection(port, ssh_args, delete_keys, ip, username, cert_file, private_key_file): +def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_file, delete_keys): ssh_arg_list = [] if ssh_args: ssh_arg_list = ssh_args log_file = None if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - # This means the user either provided his own client log file or that they - # want the client log messages to be printed to the console. - # In these two cases, we should not use the log files to check for connection success. + # If the user either provided his own client log file (-E) or + # want the client log messages to be printed to the console (-vvv/-vv/-v), + # we should not use the log files to check for connection success. log_file_dir = os.path.dirname(cert_file) log_file_name = 'ssh_client_log_' + str(os.getpid()) log_file = os.path.join(log_file_dir, log_file_name) ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] - print(f"Log file: {log_file}") - print(f"Certificate: {cert_file}") command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list @@ -48,10 +46,8 @@ def start_ssh_connection(port, ssh_args, delete_keys, ip, username, cert_file, p subprocess.call(command, shell=platform.system() == 'Windows') if cleanup_process.is_alive(): - print("Terminating cleanup") cleanup_process.terminate() while cleanup_process.is_alive(): - print("Waiting for cleanup process to die") time.sleep(1) # Make sure all files have been properly removed. _do_cleanup(delete_keys, cert_file, private_key_file) @@ -90,12 +86,18 @@ def get_ssh_cert_principals(cert_file): def write_ssh_config(config_path, resource_group, vm_name, overwrite, - ip, username, cert_file, private_key_file): + ip, username, cert_file, private_key_file, delete_keys): - logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " - "managing/deleting the private key and signed public key once this config file is no " - "longer being used. Please delete the contents of %s once you no longer need this config file.", - os.path.dirname(cert_file)) + if delete_keys: + logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " + "managing/deleting the private key and signed public key once this config file is no " + "longer being used. Please delete the contents of %s once you no longer need this config file.", + os.path.dirname(cert_file)) + else: + # Delete keys is false when user hasn't provided their own key pair. Only request deletion of certificate in that case. + logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " + "managing/deleting the signed public key once this config file is no longer being used. " + "Please delete %s once you no longer need this config file.", cert_file) lines = [""] @@ -165,7 +167,6 @@ def _build_args(cert_file, private_key_file, port): def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): # if there is a log file, use it to check for the connection success - print(f"Cleanup launched. Log file: {log_file}") if log_file: t0 = time.time() match = False @@ -178,16 +179,14 @@ def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): match = True ssh_client_log.close() except: - print("Can't open log, waiting two minutes") + # Can't open log, wait for two minutes t1 = time.time() - t0 if t1 < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS - t1) elif wait: - print("No log. Waiting two minutes") # if we are not checking the logs, but still want to wait for connection before deleting files time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) - print("Deleting") if delete_keys and private_key: public_key = private_key + '.pub' file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) From 9abf7ba3740c70d1d27b151f4450b222a9aad261 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 27 Oct 2021 16:24:13 -0400 Subject: [PATCH 03/26] Not try to delete folder when user provided keys --- src/ssh/azext_ssh/custom.py | 2 +- src/ssh/azext_ssh/ssh_utils.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index bde0536b563..fa08a7f0d4c 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -134,7 +134,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre delete_keys = False # If nothing is passed in create a temporary directory with a ephemeral keypair if not public_key_file and not private_key_file: - # We only want to delete the keys after the connection if the user hasn't providede their own keys + # We only want to delete the keys if the user hasn't providede their own keys delete_keys = True if not is_config: # az ssh vm: Create keys on temp folder and delete folder once connection is established. diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index b738cb256fe..990d55004d0 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -26,14 +26,16 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi log_file = None if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - # If the user either provided his own client log file (-E) or - # want the client log messages to be printed to the console (-vvv/-vv/-v), + # If the user either provides his own client log file (-E) or + # wants the client log messages to be printed to the console (-vvv/-vv/-v), # we should not use the log files to check for connection success. log_file_dir = os.path.dirname(cert_file) log_file_name = 'ssh_client_log_' + str(os.getpid()) log_file = os.path.join(log_file_dir, log_file_name) ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] + print(cert_file) + command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list @@ -54,7 +56,8 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi if log_file: file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) # Delete the temporary folder as well? - os.rmdir(os.path.dirname(cert_file)) + if delete_keys: + os.rmdir(os.path.dirname(cert_file)) def create_ssh_keyfile(private_key_file): From 0f0f87d78db43ec9c47a134c358d4070e8583a04 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 27 Oct 2021 16:32:05 -0400 Subject: [PATCH 04/26] Deleting debug print statements --- src/ssh/azext_ssh/ssh_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 990d55004d0..e8781a73a4b 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -33,8 +33,6 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi log_file_name = 'ssh_client_log_' + str(os.getpid()) log_file = os.path.join(log_file_dir, log_file_name) ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] - - print(cert_file) command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list @@ -97,7 +95,7 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, "longer being used. Please delete the contents of %s once you no longer need this config file.", os.path.dirname(cert_file)) else: - # Delete keys is false when user hasn't provided their own key pair. Only request deletion of certificate in that case. + # Delete keys is false when user provided their own key pair. Only request deletion of certificate in that case. logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " "managing/deleting the signed public key once this config file is no longer being used. " "Please delete %s once you no longer need this config file.", cert_file) From 67bee4ef8f4cc7471756e47e237ee9cd0aaa079e Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Thu, 28 Oct 2021 15:17:27 -0400 Subject: [PATCH 05/26] Addressing review comments --- src/ssh/azext_ssh/_params.py | 4 +-- src/ssh/azext_ssh/custom.py | 24 ++++++++++-------- src/ssh/azext_ssh/file_utils.py | 11 ++++++++ src/ssh/azext_ssh/ssh_utils.py | 45 +++++++++++++++++---------------- 4 files changed, 49 insertions(+), 35 deletions(-) diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index e432dd50581..eb8cb0b2593 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -27,8 +27,8 @@ def load_arguments(self, _): help='Will use a private IP if available. By default only public IPs are used.') c.argument('overwrite', action='store_true', options_list=['--overwrite'], help='Overwrites the config file if this flag is set') - c.argument('credentials_folder', options_list=['--credentials-destination-folder'], - help='Folder where credentials will be stored.') + c.argument('credentials_folder', options_list=['--keys-destination-folder'], + help='Folder where new generated keys will be stored.') with self.argument_context('ssh cert') as c: diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index fa08a7f0d4c..0fd48577c98 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -19,33 +19,35 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, port=None, ssh_args=None): credentials_folder = None - is_config = False op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args) _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, - public_key_file, private_key_file, use_private_ip, credentials_folder, is_config, op_call) + public_key_file, private_key_file, use_private_ip, credentials_folder, op_call) def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, credentials_folder=None): + # If user provides their own key pair, certificate will be written in the same folder as public key. + if (public_key_file or private_key_file) and credentials_folder: + raise azclierror.InvalidArgumentValueError("If providing --public-key-file/-p or --private-key-file/-i, --keys-destination-folder should not be used") op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) - is_config = True # Default credential location if not credentials_folder: config_folder = os.path.dirname(config_path) folder_name = resource_group_name + "-" + vm_name credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, is_config, op_call) + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call) def ssh_cert(cmd, cert_path=None, public_key_file=None): - public_key_file, _ = _check_or_create_public_private_files(public_key_file, None) + public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, None) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) print(cert_file + "\n") -def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, is_config, op_call): +def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call): _assert_args(resource_group, vm_name, ssh_ip) + # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) if not ssh_ip: @@ -54,8 +56,7 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") - # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys - public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, is_config) + public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) op_call(ssh_ip, username, cert_file, private_key_file, delete_keys) @@ -130,14 +131,15 @@ def _assert_args(resource_group, vm_name, ssh_ip): "--ip cannot be used with --resource-group or --vm-name/--name") -def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, is_config): +def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder): delete_keys = False # If nothing is passed in create a temporary directory with a ephemeral keypair if not public_key_file and not private_key_file: # We only want to delete the keys if the user hasn't providede their own keys delete_keys = True - if not is_config: - # az ssh vm: Create keys on temp folder and delete folder once connection is established. + if not credentials_folder: + # az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails. + # az ssh cert: If user didn't provide public key, save it to temp folder credentials_folder = tempfile.mkdtemp(prefix="aadsshcert") else: if not os.path.isdir(credentials_folder): diff --git a/src/ssh/azext_ssh/file_utils.py b/src/ssh/azext_ssh/file_utils.py index b262fd96b8c..0cbf7be373f 100644 --- a/src/ssh/azext_ssh/file_utils.py +++ b/src/ssh/azext_ssh/file_utils.py @@ -36,3 +36,14 @@ def delete_file(file_path, message, warning=False): logger.warning(message) else: raise azclierror.FileOperationError(message + "Error: " + str(e)) from e + + +def delete_folder(dir_path, message, warning=False): + if os.path.isdir(dir_path): + try: + os.rmdir(dir_path) + except Exception as e: + if warning: + logger.warning(message) + else: + raise azclierror.FileOperationError(message + "Error: " + str(e)) from e \ No newline at end of file diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index e8781a73a4b..9c69cfcaee4 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -17,9 +17,11 @@ CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120 CLEANUP_TIME_INTERVAL_IN_SECONDS = 10 +CLEANUP_AWAIT_TERMINATION_IN_SECONDS = 30 def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_file, delete_keys): + ssh_arg_list = [] if ssh_args: ssh_arg_list = ssh_args @@ -47,15 +49,18 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi if cleanup_process.is_alive(): cleanup_process.terminate() - while cleanup_process.is_alive(): + # wait for process to terminate + t0 = time.time() + while cleanup_process.is_alive() and (time.time() - t0) < CLEANUP_AWAIT_TERMINATION_IN_SECONDS: time.sleep(1) + # Make sure all files have been properly removed. _do_cleanup(delete_keys, cert_file, private_key_file) if log_file: file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) - # Delete the temporary folder as well? if delete_keys: - os.rmdir(os.path.dirname(cert_file)) + temp_dir = os.path.dirname(cert_file) + file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) def create_ssh_keyfile(private_key_file): @@ -89,16 +94,13 @@ def get_ssh_cert_principals(cert_file): def write_ssh_config(config_path, resource_group, vm_name, overwrite, ip, username, cert_file, private_key_file, delete_keys): - if delete_keys: - logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " - "managing/deleting the private key and signed public key once this config file is no " - "longer being used. Please delete the contents of %s once you no longer need this config file.", - os.path.dirname(cert_file)) - else: - # Delete keys is false when user provided their own key pair. Only request deletion of certificate in that case. - logger.warning("Sensitive information for authentication is being stored on disk. You are responsible for " - "managing/deleting the signed public key once this config file is no longer being used. " - "Please delete %s once you no longer need this config file.", cert_file) + # Warn users to delete credentials once config file is no longer being used. + # If user provided keys, only ask them to delete the certificate. + path_to_delete = os.path.dirname(cert_file) + if not delete_keys: + path_to_delete = cert_file + logger.warning("%s contains sensitive information, please delete it once you no longer need this config file.", + path_to_delete) lines = [""] @@ -169,25 +171,24 @@ def _build_args(cert_file, private_key_file, port): def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): # if there is a log file, use it to check for the connection success if log_file: + time.sleep(500) t0 = time.time() match = False while (time.time() - t0) < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS and not match: time.sleep(CLEANUP_TIME_INTERVAL_IN_SECONDS) try: with open(log_file, 'r') as ssh_client_log: - for line in ssh_client_log: - if re.search("debug1: Authentication succeeded", line): - match = True - ssh_client_log.close() + match = "debug1: Authentication succeeded" in ssh_client_log.read() + ssh_client_log.close() except: - # Can't open log, wait for two minutes - t1 = time.time() - t0 - if t1 < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: - time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS - t1) + # If there is an exception, wait for a little bit and try again + time.sleep(CLEANUP_TIME_INTERVAL_IN_SECONDS) + elif wait: # if we are not checking the logs, but still want to wait for connection before deleting files time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) - + + # TO DO: Once arc changes are merged, delete relay information as well if delete_keys and private_key: public_key = private_key + '.pub' file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) From c4076fc155cc92d8eebd7ae0543bbb5c8164de58 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Thu, 28 Oct 2021 16:03:10 -0400 Subject: [PATCH 06/26] Added certificate validity to config warning --- src/ssh/azext_ssh/ssh_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 9c69cfcaee4..220d2422364 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -91,6 +91,13 @@ def get_ssh_cert_principals(cert_file): return principals +def get_ssh_cert_validity(cert_file): + info = get_ssh_cert_info(cert_file) + for line in info: + if "Valid:" in line: + return line.strip() + + def write_ssh_config(config_path, resource_group, vm_name, overwrite, ip, username, cert_file, private_key_file, delete_keys): @@ -99,8 +106,8 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, path_to_delete = os.path.dirname(cert_file) if not delete_keys: path_to_delete = cert_file - logger.warning("%s contains sensitive information, please delete it once you no longer need this config file.", - path_to_delete) + logger.warning("%s contains sensitive information, please delete it once you no longer need this config file. " + "The signed public key %s is %s", path_to_delete, cert_file, get_ssh_cert_validity(cert_file)) lines = [""] From 1939dabf5020cb78e6995b82bbf68afdc6afb512 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Fri, 29 Oct 2021 13:07:00 -0400 Subject: [PATCH 07/26] update history and version --- src/ssh/HISTORY.md | 6 ++++++ src/ssh/azext_ssh/custom.py | 2 +- src/ssh/setup.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 01034ca9e55..4849ac34bb2 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,5 +1,11 @@ Release History =============== +0.1.9 +----- +* Delete all keys and certificates created during execution of az ssh vm. +* Add --keys-destination-folder to az ssh config +* By default, save keys created during az ssh config in a directory in the same location as --file + 0.1.8 ----- * Rollback from version 0.1.7 to 0.1.6 to remove preview features. diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 0fd48577c98..22218afad5f 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -133,7 +133,7 @@ def _assert_args(resource_group, vm_name, ssh_ip): def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder): delete_keys = False - # If nothing is passed in create a temporary directory with a ephemeral keypair + # If nothing is passed in create a directory with a ephemeral keypair if not public_key_file and not private_key_file: # We only want to delete the keys if the user hasn't providede their own keys delete_keys = True diff --git a/src/ssh/setup.py b/src/ssh/setup.py index 22037fe5f3f..0e337c8b51a 100644 --- a/src/ssh/setup.py +++ b/src/ssh/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages -VERSION = "0.1.8" +VERSION = "0.1.9" CLASSIFIERS = [ 'Development Status :: 4 - Beta', From bf324b87800627d5ef775bca24011a219dc09571 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 2 Nov 2021 13:30:46 -0400 Subject: [PATCH 08/26] changes to az ssh cert --- src/ssh/azext_ssh/_params.py | 3 ++- src/ssh/azext_ssh/custom.py | 14 +++++++++++++- src/ssh/azext_ssh/ssh_utils.py | 6 +++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index eb8cb0b2593..3ab296aba0d 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -34,4 +34,5 @@ def load_arguments(self, _): with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appened') - c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path') + c.argument('public_key_file', options_list=['--public-key-file', '-p'], + help='The RSA public key file path. If not provided, new key pair is created in the same directpry as --file') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 22218afad5f..d03fddbcd67 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -9,8 +9,11 @@ import json import tempfile +from knack import log from azure.cli.core import azclierror +logger = log.get_logger(__name__) + from . import ip_utils from . import rsa_parser from . import ssh_utils @@ -40,7 +43,16 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= def ssh_cert(cmd, cert_path=None, public_key_file=None): - public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, None) + if not cert_path and not public_key_file: + raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.") + # If user doesn't provide a public key, save key to the same folder as --file + keys_folder = None + if not public_key_file: + keys_folder = os.path.dirname(cert_path) + logger.warning("No public key provided. A new key pair will be created in the same directory as the certificate (%s). " + "Please delete keys once the certificate is no longer being used", + keys_folder) + public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) print(cert_file + "\n") diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 220d2422364..8dbd4cfed31 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -96,6 +96,7 @@ def get_ssh_cert_validity(cert_file): for line in info: if "Valid:" in line: return line.strip() + return None def write_ssh_config(config_path, resource_group, vm_name, overwrite, @@ -106,8 +107,11 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, path_to_delete = os.path.dirname(cert_file) if not delete_keys: path_to_delete = cert_file + validity = get_ssh_cert_validity(cert_file) + if validity: + validity_warning = f"The signed public key {cert_file} is {validity.lower()}" logger.warning("%s contains sensitive information, please delete it once you no longer need this config file. " - "The signed public key %s is %s", path_to_delete, cert_file, get_ssh_cert_validity(cert_file)) + "%s", path_to_delete, validity_warning) lines = [""] From 7453864fe5a3463ae7db628bdc00b350dc88689b Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 2 Nov 2021 13:35:21 -0400 Subject: [PATCH 09/26] Update History with ssh cert changes --- src/ssh/HISTORY.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 4849ac34bb2..7dc6c28ee0f 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -5,6 +5,8 @@ Release History * Delete all keys and certificates created during execution of az ssh vm. * Add --keys-destination-folder to az ssh config * By default, save keys created during az ssh config in a directory in the same location as --file +* Users no longer allowed to run az ssh cert with no parameters. +* When public key not provided to az ssh cert, new key pair is saved in the same folder as --file. 0.1.8 ----- From 0cc8db7c351480d22a899dd2615c79c369d6e00a Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 2 Nov 2021 13:55:46 -0400 Subject: [PATCH 10/26] Fixing some comments --- src/ssh/HISTORY.md | 10 +++++----- src/ssh/azext_ssh/custom.py | 3 ++- src/ssh/azext_ssh/ssh_utils.py | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 7dc6c28ee0f..e38799ccf53 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -2,11 +2,11 @@ Release History =============== 0.1.9 ----- -* Delete all keys and certificates created during execution of az ssh vm. -* Add --keys-destination-folder to az ssh config -* By default, save keys created during az ssh config in a directory in the same location as --file -* Users no longer allowed to run az ssh cert with no parameters. -* When public key not provided to az ssh cert, new key pair is saved in the same folder as --file. +* Delete all keys and certificates created during execution of ssh vm. +* Add --keys-destination-folder to ssh config +* By default, save keys created during ssh config in a directory in the same location as --file +* Users no longer allowed to run ssh cert with no parameters. +* When public key not provided to ssh cert, new key pair is saved in the same folder as --file. 0.1.8 ----- diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index d03fddbcd67..1c42ed1932b 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -151,9 +151,10 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre delete_keys = True if not credentials_folder: # az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails. - # az ssh cert: If user didn't provide public key, save it to temp folder credentials_folder = tempfile.mkdtemp(prefix="aadsshcert") else: + # az ssh config: Keys saved to the same folder as --file or to --keys-destination-folder. + # az ssh cert: Keys saved to the same folder as --file. if not os.path.isdir(credentials_folder): os.makedirs(credentials_folder) public_key_file = os.path.join(credentials_folder, "id_rsa.pub") diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 8dbd4cfed31..a5a59ab6b32 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -108,6 +108,7 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, if not delete_keys: path_to_delete = cert_file validity = get_ssh_cert_validity(cert_file) + validity_warning = "" if validity: validity_warning = f"The signed public key {cert_file} is {validity.lower()}" logger.warning("%s contains sensitive information, please delete it once you no longer need this config file. " From ca388d4f069dbb3f4e856f64f7ab33dbfebaf5d9 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 3 Nov 2021 12:59:52 -0400 Subject: [PATCH 11/26] Address review comments --- src/ssh/HISTORY.md | 6 +++--- src/ssh/azext_ssh/_params.py | 2 +- src/ssh/azext_ssh/custom.py | 17 +++++++++++------ src/ssh/azext_ssh/ssh_utils.py | 6 +++--- src/ssh/setup.py | 2 +- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index e38799ccf53..387866484cd 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,12 +1,12 @@ Release History =============== -0.1.9 +1.0.0 ----- * Delete all keys and certificates created during execution of ssh vm. * Add --keys-destination-folder to ssh config -* By default, save keys created during ssh config in a directory in the same location as --file +* Keys generated during ssh config are saved in az_ssh_config folder in the same directory as --file. * Users no longer allowed to run ssh cert with no parameters. -* When public key not provided to ssh cert, new key pair is saved in the same folder as --file. +* When --public-key-file/-f is not provided to ssh cert, generated public and private keys are saved in the same folder as --file. 0.1.8 ----- diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 3ab296aba0d..301bc023033 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -35,4 +35,4 @@ def load_arguments(self, _): c.argument('cert_path', options_list=['--file', '-f'], help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appened') c.argument('public_key_file', options_list=['--public-key-file', '-p'], - help='The RSA public key file path. If not provided, new key pair is created in the same directpry as --file') + help='The RSA public key file path. If not provided, generated key pair is stored in the same directory as --file') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 1c42ed1932b..b2c44822646 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -21,6 +21,7 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, port=None, ssh_args=None): + _assert_args(resource_group_name, vm_name, ssh_ip) credentials_folder = None op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args) _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, @@ -29,14 +30,19 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, credentials_folder=None): + _assert_args(resource_group_name, vm_name, ssh_ip) # If user provides their own key pair, certificate will be written in the same folder as public key. if (public_key_file or private_key_file) and credentials_folder: - raise azclierror.InvalidArgumentValueError("If providing --public-key-file/-p or --private-key-file/-i, --keys-destination-folder should not be used") + raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with --public-key-file/-p or --private-key-file/-i.") op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) # Default credential location if not credentials_folder: config_folder = os.path.dirname(config_path) - folder_name = resource_group_name + "-" + vm_name + if not os.path.isdir(config_folder): + raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} does not exist.") + folder_name = ssh_ip + if resource_group_name and vm_name: + folder_name = resource_group_name + "-" + vm_name credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call) @@ -45,12 +51,11 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= def ssh_cert(cmd, cert_path=None, public_key_file=None): if not cert_path and not public_key_file: raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.") - # If user doesn't provide a public key, save key to the same folder as --file + # If user doesn't provide a public key, save generated key pair to the same folder as --file keys_folder = None if not public_key_file: keys_folder = os.path.dirname(cert_path) - logger.warning("No public key provided. A new key pair will be created in the same directory as the certificate (%s). " - "Please delete keys once the certificate is no longer being used", + logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate is no longer being used.", keys_folder) public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) @@ -58,7 +63,6 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call): - _assert_args(resource_group, vm_name, ssh_ip) # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) @@ -148,6 +152,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre # If nothing is passed in create a directory with a ephemeral keypair if not public_key_file and not private_key_file: # We only want to delete the keys if the user hasn't providede their own keys + # Only ssh vm deletes generated keys. delete_keys = True if not credentials_folder: # az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails. diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index a5a59ab6b32..c4818ab5318 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -110,9 +110,9 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, validity = get_ssh_cert_validity(cert_file) validity_warning = "" if validity: - validity_warning = f"The signed public key {cert_file} is {validity.lower()}" - logger.warning("%s contains sensitive information, please delete it once you no longer need this config file. " - "%s", path_to_delete, validity_warning) + validity_warning = f" {validity.lower()}" + logger.warning("%s contains sensitive information%s, please delete it once you no longer need this config file. ", + path_to_delete, validity_warning) lines = [""] diff --git a/src/ssh/setup.py b/src/ssh/setup.py index 0e337c8b51a..cac0b4129a5 100644 --- a/src/ssh/setup.py +++ b/src/ssh/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages -VERSION = "0.1.9" +VERSION = "1.0.0" CLASSIFIERS = [ 'Development Status :: 4 - Beta', From 1bf432f15e9c0a2484618498bed46f43f01cd15c Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 3 Nov 2021 16:15:31 -0400 Subject: [PATCH 12/26] Fix unit tests --- src/ssh/azext_ssh/_params.py | 6 +- src/ssh/azext_ssh/custom.py | 33 ++++--- src/ssh/azext_ssh/file_utils.py | 2 +- src/ssh/azext_ssh/ssh_utils.py | 7 +- src/ssh/azext_ssh/tests/latest/test_custom.py | 90 ++++++++++++------- .../azext_ssh/tests/latest/test_ssh_utils.py | 41 ++++++--- 6 files changed, 116 insertions(+), 63 deletions(-) diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 301bc023033..5c0d1b8ec86 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -30,9 +30,9 @@ def load_arguments(self, _): c.argument('credentials_folder', options_list=['--keys-destination-folder'], help='Folder where new generated keys will be stored.') - with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appened') - c.argument('public_key_file', options_list=['--public-key-file', '-p'], - help='The RSA public key file path. If not provided, generated key pair is stored in the same directory as --file') + c.argument('public_key_file', options_list=['--public-key-file', '-p'], + help='The RSA public key file path. If not provided, ' + 'generated key pair is stored in the same directory as --file') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index b2c44822646..2e1f2f9f6d8 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -12,12 +12,12 @@ from knack import log from azure.cli.core import azclierror -logger = log.get_logger(__name__) - from . import ip_utils from . import rsa_parser from . import ssh_utils +logger = log.get_logger(__name__) + def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, port=None, ssh_args=None): @@ -29,23 +29,29 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, - public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, credentials_folder=None): + public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, + credentials_folder=None): _assert_args(resource_group_name, vm_name, ssh_ip) # If user provides their own key pair, certificate will be written in the same folder as public key. if (public_key_file or private_key_file) and credentials_folder: - raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with --public-key-file/-p or --private-key-file/-i.") + raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with " + "--public-key-file/-p or --private-key-file/-i.") op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) # Default credential location + if credentials_folder and not os.path.isdir(credentials_folder): + raise azclierror.InvalidArgumentValueError(f"--keys-destination-folder {credentials_folder} doesn't exist") if not credentials_folder: config_folder = os.path.dirname(config_path) if not os.path.isdir(config_folder): - raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} does not exist.") + raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} " + "does not exist.") folder_name = ssh_ip if resource_group_name and vm_name: folder_name = resource_group_name + "-" + vm_name credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call) + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, + credentials_folder, op_call) def ssh_cert(cmd, cert_path=None, public_key_file=None): @@ -55,14 +61,15 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): keys_folder = None if not public_key_file: keys_folder = os.path.dirname(cert_path) - logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate is no longer being used.", - keys_folder) + logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate " + "is no longer being used.", keys_folder) public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) print(cert_file + "\n") -def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call): +def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, + credentials_folder, op_call): # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) @@ -72,7 +79,9 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") - public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder) + public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, + private_key_file, + credentials_folder) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) op_call(ssh_ip, username, cert_file, private_key_file, delete_keys) @@ -152,14 +161,14 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre # If nothing is passed in create a directory with a ephemeral keypair if not public_key_file and not private_key_file: # We only want to delete the keys if the user hasn't providede their own keys - # Only ssh vm deletes generated keys. + # Only ssh vm deletes generated keys. delete_keys = True if not credentials_folder: # az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails. credentials_folder = tempfile.mkdtemp(prefix="aadsshcert") else: # az ssh config: Keys saved to the same folder as --file or to --keys-destination-folder. - # az ssh cert: Keys saved to the same folder as --file. + # az ssh cert: Keys saved to the same folder as --file. if not os.path.isdir(credentials_folder): os.makedirs(credentials_folder) public_key_file = os.path.join(credentials_folder, "id_rsa.pub") diff --git a/src/ssh/azext_ssh/file_utils.py b/src/ssh/azext_ssh/file_utils.py index 0cbf7be373f..b31927b9268 100644 --- a/src/ssh/azext_ssh/file_utils.py +++ b/src/ssh/azext_ssh/file_utils.py @@ -46,4 +46,4 @@ def delete_folder(dir_path, message, warning=False): if warning: logger.warning(message) else: - raise azclierror.FileOperationError(message + "Error: " + str(e)) from e \ No newline at end of file + raise azclierror.FileOperationError(message + "Error: " + str(e)) from e diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index c4818ab5318..86624556729 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -6,7 +6,6 @@ import platform import subprocess import time -import re import multiprocessing as mp from azext_ssh import file_utils @@ -28,14 +27,14 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi log_file = None if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - # If the user either provides his own client log file (-E) or + # If the user either provides his own client log file (-E) or # wants the client log messages to be printed to the console (-vvv/-vv/-v), # we should not use the log files to check for connection success. log_file_dir = os.path.dirname(cert_file) log_file_name = 'ssh_client_log_' + str(os.getpid()) log_file = os.path.join(log_file_dir, log_file_name) ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] - + command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list @@ -199,7 +198,7 @@ def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): elif wait: # if we are not checking the logs, but still want to wait for connection before deleting files time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) - + # TO DO: Once arc changes are merged, delete relay information as well if delete_keys and private_key: public_key = private_key + '.pub' diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 63482a70e28..0be930fa335 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -8,51 +8,59 @@ from unittest import mock import unittest + from azext_ssh import custom class SshCustomCommandTest(unittest.TestCase): @mock.patch('azext_ssh.custom._do_ssh_op') - @mock.patch('azext_ssh.custom.ssh_utils') - def test_ssh_vm(self, mock_ssh_utils, mock_do_op): + @mock.patch('azext_ssh.custom._assert_args') + def test_ssh_vm(self, mock_assert, mock_do_op): cmd = mock.Mock() custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False) + mock_assert.assert_called_once_with("rg", "vm", "ip") mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", False, mock.ANY) - + cmd, "rg", "vm", "ip", "public", "private", False, None, mock.ANY) + @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('azext_ssh.ssh_utils.write_ssh_config') - def test_ssh_config(self, mock_ssh_utils, mock_do_op): + @mock.patch('azext_ssh.custom._assert_args') + @mock.patch('os.path.isdir') + @mock.patch('os.path.dirname') + @mock.patch('os.path.join') + def test_ssh_config(self, mock_join, mock_dirname, mock_isdir, mock_assert, mock_ssh_utils, mock_do_op): cmd = mock.Mock() + mock_dirname.return_value = "configdir" + mock_isdir.return_value = True + mock_join.side_effect = ['az_ssh_config/rg-vm', 'path/to/az_ssh_config/rg-vm'] - def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, op_call): - op_call(ssh_ip, "username", "cert_file", private_key_file) + def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call): + op_call(ssh_ip, "username", "cert_file", private_key_file, False) mock_do_op.side_effect = do_op_side_effect - custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False) - - mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "ip", "username", "cert_file", "private") + custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False, None) + mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "ip", "username", "cert_file", "private", False) + mock_assert.assert_called_once_with("rg", "vm", "ip") mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", False, mock.ANY) - + cmd, "rg", "vm", "ip", "public", "private", False, 'path/to/az_ssh_config/rg-vm', mock.ANY) + @mock.patch('azext_ssh.ssh_utils.get_ssh_cert_principals') @mock.patch('os.path.join') - @mock.patch('azext_ssh.custom._assert_args') @mock.patch('azext_ssh.custom._check_or_create_public_private_files') @mock.patch('azext_ssh.ip_utils.get_ssh_ip') @mock.patch('azext_ssh.custom._get_modulus_exponent') @mock.patch('azure.cli.core._profile.Profile') @mock.patch('azext_ssh.custom._write_cert_file') def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock_ip, - mock_check_files, mock_assert, mock_join, mock_principal): + mock_check_files, mock_join, mock_principal): cmd = mock.Mock() cmd.cli_ctx = mock.Mock() cmd.cli_ctx.cloud = mock.Mock() cmd.cli_ctx.cloud.name = "azurecloud" mock_op = mock.Mock() - mock_check_files.return_value = "public", "private" + mock_check_files.return_value = "public", "private", False mock_principal.return_value = ["username"] mock_get_mod_exp.return_value = "modulus", "exponent" profile = mock_ssh_creds.return_value @@ -60,33 +68,29 @@ def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock profile.get_msal_token.return_value = "username", "certificate" mock_join.return_value = "public-aadcert.pub" - custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, mock_op) + custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, "cred/folder", mock_op) - mock_assert.assert_called_once_with(None, None, "1.2.3.4") - mock_check_files.assert_called_once_with("publicfile", "privatefile") + mock_check_files.assert_called_once_with("publicfile", "privatefile", "cred/folder") mock_ip.assert_not_called() mock_get_mod_exp.assert_called_once_with("public") mock_write_cert.assert_called_once_with("certificate", "public-aadcert.pub") mock_op.assert_called_once_with( - "1.2.3.4", "username", "public-aadcert.pub", "private") - - @mock.patch('azext_ssh.custom._assert_args') + "1.2.3.4", "username", "public-aadcert.pub", "private", False) + @mock.patch('azext_ssh.custom._check_or_create_public_private_files') @mock.patch('azext_ssh.ip_utils.get_ssh_ip') @mock.patch('azext_ssh.custom._get_modulus_exponent') - def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_files, mock_assert): + def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_files): cmd = mock.Mock() mock_op = mock.Mock() - mock_check_files.return_value = "public", "private" mock_get_mod_exp.return_value = "modulus", "exponent" mock_ip.return_value = None self.assertRaises( azclierror.ResourceNotFoundError, custom._do_ssh_op, cmd, "rg", "vm", None, - "publicfile", "privatefile", False, mock_op) + "publicfile", "privatefile", False, "cred/folder", mock_op) - mock_assert.assert_called_once_with("rg", "vm", None) - mock_check_files.assert_called_once_with("publicfile", "privatefile") + mock_check_files.assert_not_called() mock_ip.assert_called_once_with(cmd, "rg", "vm", False) def test_assert_args_no_ip_or_vm(self): @@ -99,7 +103,7 @@ def test_assert_args_vm_rg_mismatch(self): def test_assert_args_ip_with_vm_or_rg(self): self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, "vm", "ip") self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", "vm", "ip") - + @mock.patch('azext_ssh.ssh_utils.create_ssh_keyfile') @mock.patch('tempfile.mkdtemp') @mock.patch('os.path.isfile') @@ -109,10 +113,11 @@ def test_check_or_create_public_private_files_defaults(self, mock_join, mock_isf mock_temp.return_value = "/tmp/aadtemp" mock_join.side_effect = ['/tmp/aadtemp/id_rsa.pub', '/tmp/aadtemp/id_rsa'] - public, private = custom._check_or_create_public_private_files(None, None) + public, private, delete_key = custom._check_or_create_public_private_files(None, None, None) self.assertEqual('/tmp/aadtemp/id_rsa.pub', public) self.assertEqual('/tmp/aadtemp/id_rsa', private) + self.assertEqual(True, delete_key) mock_join.assert_has_calls([ mock.call("/tmp/aadtemp", "id_rsa.pub"), mock.call("/tmp/aadtemp", "id_rsa") @@ -125,12 +130,37 @@ def test_check_or_create_public_private_files_defaults(self, mock_join, mock_isf mock.call('/tmp/aadtemp/id_rsa') ]) + @mock.patch('azext_ssh.ssh_utils.create_ssh_keyfile') + @mock.patch('os.path.isdir') + @mock.patch('os.path.isfile') + @mock.patch('os.path.join') + def test_check_or_create_public_private_files_defaults_with_cred_folder(self,mock_join, mock_isfile, mock_isdir, mock_create): + mock_isfile.return_value = True + mock_isdir.return_value = True + mock_join.side_effect = ['/cred/folder/id_rsa.pub', '/cred/folder/id_rsa'] + public, private, delete_key = custom._check_or_create_public_private_files(None, None, '/cred/folder') + self.assertEqual('/cred/folder/id_rsa.pub', public) + self.assertEqual('/cred/folder/id_rsa', private) + self.assertEqual(True, delete_key) + mock_join.assert_has_calls([ + mock.call("/cred/folder", "id_rsa.pub"), + mock.call("/cred/folder", "id_rsa") + ]) + mock_isfile.assert_has_calls([ + mock.call('/cred/folder/id_rsa.pub'), + mock.call('/cred/folder/id_rsa') + ]) + mock_create.assert_has_calls([ + mock.call('/cred/folder/id_rsa') + ]) + + @mock.patch('os.path.isfile') @mock.patch('os.path.join') def test_check_or_create_public_private_files_no_public(self, mock_join, mock_isfile): mock_isfile.side_effect = [False] self.assertRaises( - azclierror.FileOperationError, custom._check_or_create_public_private_files, "public", None) + azclierror.FileOperationError, custom._check_or_create_public_private_files, "public", None, None) mock_isfile.assert_called_once_with("public") @@ -140,7 +170,7 @@ def test_check_or_create_public_private_files_no_private(self, mock_join, mock_i mock_isfile.side_effect = [True, False] self.assertRaises( - azclierror.FileOperationError, custom._check_or_create_public_private_files, "public", "private") + azclierror.FileOperationError, custom._check_or_create_public_private_files, "public", "private", None) mock_join.assert_not_called() mock_isfile.assert_has_calls([ diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index 5171dd662fc..1be7eb4481d 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -10,19 +10,22 @@ from azext_ssh import ssh_utils - class SSHUtilsTests(unittest.TestCase): + + @mock.patch('os.path.join') @mock.patch.object(ssh_utils, '_get_ssh_path') @mock.patch.object(ssh_utils, '_get_host') @mock.patch.object(ssh_utils, '_build_args') @mock.patch('subprocess.call') - def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path): + def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path, mock_join): mock_path.return_value = "ssh" mock_host.return_value = "user@ip" mock_build.return_value = ['-i', 'file', '-o', 'option'] - expected_command = ["ssh", "user@ip", "-i", "file", "-o", "option"] + mock_join.return_value = "/log/file/path" - ssh_utils.start_ssh_connection(None, None, "ip", "user", "cert", "private") + expected_command = ["ssh", "user@ip", "-i", "file", "-o", "option", "-E", "/log/file/path", "-v"] + + ssh_utils.start_ssh_connection(None, None, "ip", "user", "cert", "private", True) mock_path.assert_called_once_with() mock_host.assert_called_once_with("user", "ip") @@ -36,15 +39,18 @@ def test_start_ssh_connection_with_args(self, mock_call, mock_host, mock_path): mock_path.return_value = "ssh" mock_host.return_value = "user@ip" - expected_command = ["ssh", "user@ip", "-i", "private", "-o", "CertificateFile=cert", "-p", "2222", "--thing"] + expected_command = ["ssh", "user@ip", "-i", "private", "-o", "CertificateFile=cert", "-p", "2222", "--thing", "-vv"] - ssh_utils.start_ssh_connection("2222", ["--thing"], "ip", "user", "cert", "private") + ssh_utils.start_ssh_connection("2222", ["--thing", "-vv"], "ip", "user", "cert", "private", True) mock_path.assert_called_once_with() mock_host.assert_called_once_with("user", "ip") mock_call.assert_called_once_with(expected_command, shell=platform.system() == 'Windows') - def test_write_ssh_config_ip_and_vm(self): + + @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') + def test_write_ssh_config_ip_and_vm(self, mock_validity): + mock_validity.return_value = None expected_lines = [ "", "Host rg-vm", @@ -62,13 +68,14 @@ def test_write_ssh_config_ip_and_vm(self): mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", True, "1.2.3.4", "username", "cert", "privatekey" + "path/to/file", "rg", "vm", True, "1.2.3.4", "username", "cert", "privatekey", False ) - + mock_validity.assert_called_once_with("cert") mock_open.assert_called_once_with("path/to/file", "w") mock_file.write.assert_called_once_with('\n'.join(expected_lines)) - def test_write_ssh_config_append(self): + @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') + def test_write_ssh_config_append(self, mock_validity): expected_lines = [ "", "Host rg-vm", @@ -82,17 +89,22 @@ def test_write_ssh_config_append(self): "\tIdentityFile privatekey" ] + mock_validity.return_value = None + with mock.patch('builtins.open') as mock_open: mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", False, "1.2.3.4", "username", "cert", "privatekey" + "path/to/file", "rg", "vm", False, "1.2.3.4", "username", "cert", "privatekey", False ) + mock_validity.assert_called_once_with("cert") + mock_open.assert_called_once_with("path/to/file", "a") mock_file.write.assert_called_once_with('\n'.join(expected_lines)) - def test_write_ssh_config_ip_only(self): + @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') + def test_write_ssh_config_ip_only(self, mock_validity): expected_lines = [ "", "Host 1.2.3.4", @@ -100,14 +112,17 @@ def test_write_ssh_config_ip_only(self): "\tCertificateFile cert", "\tIdentityFile privatekey" ] + mock_validity.return_value = None with mock.patch('builtins.open') as mock_open: mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", None, None, True, "1.2.3.4", "username", "cert", "privatekey" + "path/to/file", None, None, True, "1.2.3.4", "username", "cert", "privatekey", False ) + mock_validity.assert_called_once_with("cert") + mock_open.assert_called_once_with("path/to/file", "w") mock_file.write.assert_called_once_with('\n'.join(expected_lines)) From 9a267c74a3599011ff998f0fa3ba008e18e5264c Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 3 Nov 2021 17:33:26 -0400 Subject: [PATCH 13/26] A few adjustments after tests --- src/ssh/azext_ssh/custom.py | 2 ++ src/ssh/azext_ssh/ssh_utils.py | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 2e1f2f9f6d8..160839077a9 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -57,6 +57,8 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= def ssh_cert(cmd, cert_path=None, public_key_file=None): if not cert_path and not public_key_file: raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.") + if cert_path and not os.path.isdir(os.path.dirname(cert_path)): + raise azclierror.InvalidArgumentValueError(f"{os.path.dirname(cert_path)} folder doesn't exist") # If user doesn't provide a public key, save generated key pair to the same folder as --file keys_folder = None if not public_key_file: diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 86624556729..84fa9fe487f 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -104,14 +104,16 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, # Warn users to delete credentials once config file is no longer being used. # If user provided keys, only ask them to delete the certificate. path_to_delete = os.path.dirname(cert_file) + items_to_delete = " (id_rsa, id_rsa.pub, id_rsa.pub-aadcert.pub)" if not delete_keys: path_to_delete = cert_file + items_to_delete = "" validity = get_ssh_cert_validity(cert_file) validity_warning = "" if validity: validity_warning = f" {validity.lower()}" - logger.warning("%s contains sensitive information%s, please delete it once you no longer need this config file. ", - path_to_delete, validity_warning) + logger.warning("%s contains sensitive information%s%s, please delete it once you no longer need this config file. ", + path_to_delete, items_to_delete, validity_warning) lines = [""] @@ -181,8 +183,8 @@ def _build_args(cert_file, private_key_file, port): def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): # if there is a log file, use it to check for the connection success + print(os.getpid()) if log_file: - time.sleep(500) t0 = time.time() match = False while (time.time() - t0) < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS and not match: From eafabb8c2f023969cfdecc6211b362e8034f30ce Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 3 Nov 2021 17:34:35 -0400 Subject: [PATCH 14/26] Remove debug print statements --- src/ssh/azext_ssh/ssh_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 84fa9fe487f..a8d40d1f17b 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -183,7 +183,6 @@ def _build_args(cert_file, private_key_file, port): def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): # if there is a log file, use it to check for the connection success - print(os.getpid()) if log_file: t0 = time.time() match = False From 24068f1929d383c082427c99b6cddd763f532437 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 3 Nov 2021 18:44:12 -0400 Subject: [PATCH 15/26] Fix typos --- src/ssh/azext_ssh/custom.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 160839077a9..070ef37370a 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -160,9 +160,9 @@ def _assert_args(resource_group, vm_name, ssh_ip): def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder): delete_keys = False - # If nothing is passed in create a directory with a ephemeral keypair + # If nothing is passed, then create a directory with a ephemeral keypair if not public_key_file and not private_key_file: - # We only want to delete the keys if the user hasn't providede their own keys + # We only want to delete the keys if the user hasn't provided their own keys # Only ssh vm deletes generated keys. delete_keys = True if not credentials_folder: From 23790c627a236fdae852cd6d2c075c0337fc02ad Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 3 Nov 2021 19:03:20 -0400 Subject: [PATCH 16/26] Addessing comments --- src/ssh/azext_ssh/custom.py | 2 -- src/ssh/azext_ssh/ssh_utils.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 070ef37370a..4da5be24651 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -38,8 +38,6 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= "--public-key-file/-p or --private-key-file/-i.") op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) # Default credential location - if credentials_folder and not os.path.isdir(credentials_folder): - raise azclierror.InvalidArgumentValueError(f"--keys-destination-folder {credentials_folder} doesn't exist") if not credentials_folder: config_folder = os.path.dirname(config_path) if not os.path.isdir(config_folder): diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index a8d40d1f17b..a90319b4392 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -112,7 +112,7 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, validity_warning = "" if validity: validity_warning = f" {validity.lower()}" - logger.warning("%s contains sensitive information%s%s, please delete it once you no longer need this config file. ", + logger.warning("%s contains sensitive information%s%s\nPlease delete it once you no longer need this config file. ", path_to_delete, items_to_delete, validity_warning) lines = [""] From 910f1acf767e38ebe3b64ea21772d0fcd717ad62 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Thu, 4 Nov 2021 12:31:10 -0400 Subject: [PATCH 17/26] Added --keys-dest-folder as an alternative to --keys-destination folder to avoid linter error --- src/ssh/azext_ssh/_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 5c0d1b8ec86..d8f3f00fb2b 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -27,7 +27,7 @@ def load_arguments(self, _): help='Will use a private IP if available. By default only public IPs are used.') c.argument('overwrite', action='store_true', options_list=['--overwrite'], help='Overwrites the config file if this flag is set') - c.argument('credentials_folder', options_list=['--keys-destination-folder'], + c.argument('credentials_folder', options_list=['--keys-destination-folder', '--keys-dest-folder'], help='Folder where new generated keys will be stored.') with self.argument_context('ssh cert') as c: From ec68a95c6fbac0526fcebf16b44f7edc715d95a5 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Thu, 4 Nov 2021 12:57:03 -0400 Subject: [PATCH 18/26] Dummy commit just to trigger CI --- src/ssh/azext_ssh/_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index d8f3f00fb2b..c32cf32d70e 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -35,4 +35,4 @@ def load_arguments(self, _): help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appened') c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path. If not provided, ' - 'generated key pair is stored in the same directory as --file') + 'generated key pair is stored in the same directory as --file.') From d8e444f7316c05ede0d13977b2f29cca094a67fe Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 8 Nov 2021 14:04:36 -0500 Subject: [PATCH 19/26] Initial changes to allow loging into local users --- src/ssh/azext_ssh/_help.py | 12 +++- src/ssh/azext_ssh/_params.py | 8 +++ src/ssh/azext_ssh/custom.py | 40 ++++++++---- src/ssh/azext_ssh/ssh_utils.py | 108 ++++++++++++++++++--------------- 4 files changed, 105 insertions(+), 63 deletions(-) diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index b049a3a3c7f..75a00fb14c3 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -12,7 +12,8 @@ helps['ssh vm'] = """ type: command - short-summary: SSH into Azure VMs using an ssh certificate + short-summary: SSH into Azure VMs + long-summary: Users can login using AAD issued certificates or using local user credentials. We recommend login using AAD issued certificates as azure automatically rotate SSH CA keys. To SSH as a local user in the target machine, you must provide the local user name using the --local-user argument. examples: - name: Give a resource group and VM to SSH to text: | @@ -27,6 +28,15 @@ - name: Using additional ssh arguments text: | az ssh vm --ip 1.2.3.4 -- -A -o ForwardX11=yes + - name: Give a local user name to SSH using local user credentials on the target machine using certificate based authentication. + text: | + az ssh vm --local-user username --ip 1.2.3.4 --certificate-file cert.pub --private-key key + - name: Give a local user name to SSH using local user credentials on the target machine using key based authentication. + text: | + az ssh vm --local-user username --resource-group myResourceGroup --vm-name myVM --private-key-file key + - name: Give a local user name to SSH using local user credentials on the target machine using password based authentication. + text: | + az ssh vm --local-user username --ip 1.2.3.4 """ helps['ssh config'] = """ diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index c32cf32d70e..608465601de 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -14,6 +14,10 @@ def load_arguments(self, _): c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path') c.argument('use_private_ip', options_list=['--prefer-private-ip'], help='Will prefer private IP. Requires connectivity to the private IP.') + c.argument('local_user', options_list=['--local-user'], + help='The username for a local user') + c.argument('cert_file', options_list=['--certificate-file', '-c'], + help='Path to a certificate file used for authentication when using local user credentials.') c.argument('port', options_list=['--port'], help='SSH port') c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') @@ -25,10 +29,14 @@ def load_arguments(self, _): c.argument('private_key_file', options_list=['--private-key-file', '-i'], help='The RSA private key file path') c.argument('use_private_ip', options_list=['--prefer-private-ip'], help='Will use a private IP if available. By default only public IPs are used.') + c.argument('local_user', options_list=['--local-user'], + help='The username for a local user') c.argument('overwrite', action='store_true', options_list=['--overwrite'], help='Overwrites the config file if this flag is set') c.argument('credentials_folder', options_list=['--keys-destination-folder', '--keys-dest-folder'], help='Folder where new generated keys will be stored.') + c.argument('cert_file', options_list=['--certificate-file', '-c'], help='Path to certificate file') + c.argument('port', options_list=['--port'], help='SSH port') with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 4da5be24651..1d6ac363408 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -20,23 +20,24 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, - private_key_file=None, use_private_ip=False, port=None, ssh_args=None): - _assert_args(resource_group_name, vm_name, ssh_ip) + private_key_file=None, use_private_ip=False, local_user=None, cert_file=None, port=None, + ssh_args=None): + _assert_args(resource_group_name, vm_name, ssh_ip, cert_file, local_user) credentials_folder = None op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, - public_key_file, private_key_file, use_private_ip, credentials_folder, op_call) + _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, + local_user, cert_file, credentials_folder, op_call) def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, - credentials_folder=None): - _assert_args(resource_group_name, vm_name, ssh_ip) + local_user=None, cert_file=None, port=None, credentials_folder=None): + _assert_args(resource_group_name, vm_name, ssh_ip, cert_file, local_user) # If user provides their own key pair, certificate will be written in the same folder as public key. if (public_key_file or private_key_file) and credentials_folder: raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with " "--public-key-file/-p or --private-key-file/-i.") - op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite) + op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite, port) # Default credential location if not credentials_folder: config_folder = os.path.dirname(config_path) @@ -49,7 +50,7 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, - credentials_folder, op_call) + local_user, cert_file, credentials_folder, op_call) def ssh_cert(cmd, cert_path=None, public_key_file=None): @@ -69,7 +70,7 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, - credentials_folder, op_call): + username, cert_file, credentials_folder, op_call): # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) @@ -79,11 +80,17 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") - public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, + # If user ptovides local user, no credentials should be deleted. + delete_keys = False + delete_cert = False + # If user provides a local user, use the provided credentials for authentication + if not username: + delete_cert = True + public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder) - cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) - op_call(ssh_ip, username, cert_file, private_key_file, delete_keys) + cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) + op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert) def _get_and_write_certificate(cmd, public_key_file, cert_file): @@ -142,7 +149,7 @@ def _prepare_jwk_data(public_key_file): return data -def _assert_args(resource_group, vm_name, ssh_ip): +def _assert_args(resource_group, vm_name, ssh_ip, cert_file, username): if not (resource_group or vm_name or ssh_ip): raise azclierror.RequiredArgumentMissingError( "The VM must be specified by --ip or --resource-group and --vm-name/--name") @@ -154,6 +161,13 @@ def _assert_args(resource_group, vm_name, ssh_ip): if ssh_ip and (vm_name or resource_group): raise azclierror.MutuallyExclusiveArgumentError( "--ip cannot be used with --resource-group or --vm-name/--name") + + if cert_file and not username: + raise azclierror.MutuallyExclusiveArgumentError( + "To authenticate with a certificate you need to provide a --local-user") + + if cert_file and not os.path.isfile(cert_file): + raise azclierror.FileOperationError(f"Certificate file {cert_file} not found") def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder): diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index a90319b4392..e60be3ab2e5 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -19,47 +19,48 @@ CLEANUP_AWAIT_TERMINATION_IN_SECONDS = 30 -def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_file, delete_keys): +def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_file, delete_keys, delete_cert): ssh_arg_list = [] if ssh_args: ssh_arg_list = ssh_args - + log_file = None - if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - # If the user either provides his own client log file (-E) or - # wants the client log messages to be printed to the console (-vvv/-vv/-v), - # we should not use the log files to check for connection success. - log_file_dir = os.path.dirname(cert_file) - log_file_name = 'ssh_client_log_' + str(os.getpid()) - log_file = os.path.join(log_file_dir, log_file_name) - ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] + if delete_keys or delete_cert: + if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): + # If the user either provides his own client log file (-E) or + # wants the client log messages to be printed to the console (-vvv/-vv/-v), + # we should not use the log files to check for connection success. + log_file_dir = os.path.dirname(cert_file) + log_file_name = 'ssh_client_log_' + str(os.getpid()) + log_file = os.path.join(log_file_dir, log_file_name) + ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] + # Create a new process that will wait until the connection is established and then delete keys. + cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, delete_cert, cert_file, private_key_file, + log_file, True)) + cleanup_process.start() command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list - # Create a new process that will wait until the connection is established and then delete keys. - cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, cert_file, private_key_file, - log_file, True)) - cleanup_process.start() - logger.debug("Running ssh command %s", ' '.join(command)) subprocess.call(command, shell=platform.system() == 'Windows') - if cleanup_process.is_alive(): - cleanup_process.terminate() - # wait for process to terminate - t0 = time.time() - while cleanup_process.is_alive() and (time.time() - t0) < CLEANUP_AWAIT_TERMINATION_IN_SECONDS: - time.sleep(1) + if delete_keys or delete_cert: + if cleanup_process.is_alive(): + cleanup_process.terminate() + # wait for process to terminate + t0 = time.time() + while cleanup_process.is_alive() and (time.time() - t0) < CLEANUP_AWAIT_TERMINATION_IN_SECONDS: + time.sleep(1) - # Make sure all files have been properly removed. - _do_cleanup(delete_keys, cert_file, private_key_file) - if log_file: - file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) - if delete_keys: - temp_dir = os.path.dirname(cert_file) - file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) + # Make sure all files have been properly removed. + _do_cleanup(delete_keys, cert_file, private_key_file) + if log_file: + file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) + if delete_keys: + temp_dir = os.path.dirname(cert_file) + file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) def create_ssh_keyfile(private_key_file): @@ -98,22 +99,23 @@ def get_ssh_cert_validity(cert_file): return None -def write_ssh_config(config_path, resource_group, vm_name, overwrite, - ip, username, cert_file, private_key_file, delete_keys): - - # Warn users to delete credentials once config file is no longer being used. - # If user provided keys, only ask them to delete the certificate. - path_to_delete = os.path.dirname(cert_file) - items_to_delete = " (id_rsa, id_rsa.pub, id_rsa.pub-aadcert.pub)" - if not delete_keys: - path_to_delete = cert_file - items_to_delete = "" - validity = get_ssh_cert_validity(cert_file) - validity_warning = "" - if validity: - validity_warning = f" {validity.lower()}" - logger.warning("%s contains sensitive information%s%s\nPlease delete it once you no longer need this config file. ", - path_to_delete, items_to_delete, validity_warning) +def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, + ip, username, cert_file, private_key_file, delete_keys, delete_cert): + + if delete_keys or delete_cert: + # Warn users to delete credentials once config file is no longer being used. + # If user provided keys, only ask them to delete the certificate. + path_to_delete = os.path.dirname(cert_file) + items_to_delete = " (id_rsa, id_rsa.pub, id_rsa.pub-aadcert.pub)" + if not delete_keys: + path_to_delete = cert_file + items_to_delete = "" + validity = get_ssh_cert_validity(cert_file) + validity_warning = "" + if validity: + validity_warning = f" {validity.lower()}" + logger.warning("%s contains sensitive information%s%s\nPlease delete it once you no longer need this config file. ", + path_to_delete, items_to_delete, validity_warning) lines = [""] @@ -121,9 +123,12 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, lines.append("Host " + resource_group + "-" + vm_name) lines.append("\tUser " + username) lines.append("\tHostName " + ip) - lines.append("\tCertificateFile " + cert_file) + if cert_file: + lines.append("\tCertificateFile " + cert_file) if private_key_file: lines.append("\tIdentityFile " + private_key_file) + if port: + lines.append("\tPort " + port) # default to all hosts for config if not ip: @@ -131,9 +136,12 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, lines.append("Host " + ip) lines.append("\tUser " + username) - lines.append("\tCertificateFile " + cert_file) + if cert_file: + lines.append("\tCertificateFile " + cert_file) if private_key_file: lines.append("\tIdentityFile " + private_key_file) + if port: + lines.append("\tPort " + port) if overwrite: mode = 'w' @@ -173,15 +181,17 @@ def _get_host(username, ip): def _build_args(cert_file, private_key_file, port): private_key = [] port_arg = [] + certificate = [] if private_key_file: private_key = ["-i", private_key_file] if port: port_arg = ["-p", port] - certificate = ["-o", "CertificateFile=" + cert_file] + if cert_file: + certificate = ["-o", "CertificateFile=" + cert_file] return private_key + certificate + port_arg -def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): +def _do_cleanup(delete_keys, delete_cert, cert_file, private_key, log_file=None, wait=False): # if there is a log file, use it to check for the connection success if log_file: t0 = time.time() @@ -205,5 +215,5 @@ def _do_cleanup(delete_keys, cert_file, private_key, log_file=None, wait=False): public_key = private_key + '.pub' file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) - if cert_file: + if delete_cert and cert_file: file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) From 2ddd5d0dce090a63090bd602452f35e0bb0c0d43 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 8 Nov 2021 14:42:42 -0500 Subject: [PATCH 20/26] Fix bugs --- src/ssh/azext_ssh/ssh_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index e60be3ab2e5..b1b47a34517 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -25,6 +25,8 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi if ssh_args: ssh_arg_list = ssh_args + print(cert_file) + log_file = None if delete_keys or delete_cert: if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): @@ -55,7 +57,7 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi time.sleep(1) # Make sure all files have been properly removed. - _do_cleanup(delete_keys, cert_file, private_key_file) + _do_cleanup(delete_keys, delete_cert, cert_file, private_key_file) if log_file: file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) if delete_keys: From 57f56227855983b9e0c3d2af423e629d2efcbdd6 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 8 Nov 2021 17:26:22 -0500 Subject: [PATCH 21/26] Address review comments --- src/ssh/azext_ssh/_help.py | 13 +++++++++++-- src/ssh/azext_ssh/custom.py | 3 ++- src/ssh/azext_ssh/ssh_utils.py | 2 -- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index 75a00fb14c3..7b7b524751a 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -12,28 +12,34 @@ helps['ssh vm'] = """ type: command - short-summary: SSH into Azure VMs - long-summary: Users can login using AAD issued certificates or using local user credentials. We recommend login using AAD issued certificates as azure automatically rotate SSH CA keys. To SSH as a local user in the target machine, you must provide the local user name using the --local-user argument. + short-summary: SSH into Azure VMs using an ssh certificate + long-summary: Users can login using AAD issued certificates or using local user credentials. We recommend login using AAD issued certificates. To SSH as a local user in the target machine, you must provide the local user name using the --local-user argument. examples: - name: Give a resource group and VM to SSH to text: | az ssh vm --resource-group myResourceGroup --vm-name myVm + - name: Give the public IP (or hostname) of a VM to SSH to text: | az ssh vm --ip 1.2.3.4 az ssh vm --hostname example.com + - name: Using a custom private key file text: | az ssh vm --ip 1.2.3.4 --private-key-file key --public-key-file key.pub + - name: Using additional ssh arguments text: | az ssh vm --ip 1.2.3.4 -- -A -o ForwardX11=yes + - name: Give a local user name to SSH using local user credentials on the target machine using certificate based authentication. text: | az ssh vm --local-user username --ip 1.2.3.4 --certificate-file cert.pub --private-key key + - name: Give a local user name to SSH using local user credentials on the target machine using key based authentication. text: | az ssh vm --local-user username --resource-group myResourceGroup --vm-name myVM --private-key-file key + - name: Give a local user name to SSH using local user credentials on the target machine using password based authentication. text: | az ssh vm --local-user username --ip 1.2.3.4 @@ -47,16 +53,19 @@ - name: Give a resource group and VM for which to create a config, and save in a local file text: | az ssh config --resource-group myResourceGroup --vm-name myVm --file ./sshconfig + - name: Give the public IP (or hostname) of a VM for which to create a config and then ssh text: | az ssh config --ip 1.2.3.4 --file ./sshconfig ssh -F ./sshconfig 1.2.3.4 + - name: Create a generic config for use with any host text: | #Bash az ssh config --ip \\* --file ./sshconfig #PowerShell az ssh config --ip * --file ./sshconfig + - name: Examples with other software text: | #Bash diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 1d6ac363408..bd77c82d0f7 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -80,7 +80,7 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") - # If user ptovides local user, no credentials should be deleted. + # If user provides local user, no credentials should be deleted. delete_keys = False delete_cert = False # If user provides a local user, use the provided credentials for authentication @@ -90,6 +90,7 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke private_key_file, credentials_folder) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) + op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index b1b47a34517..d1eaf816db1 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -24,8 +24,6 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi ssh_arg_list = [] if ssh_args: ssh_arg_list = ssh_args - - print(cert_file) log_file = None if delete_keys or delete_cert: From 3a9c11a3eba9f85c95ebce0173b475de04656430 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 8 Nov 2021 18:45:19 -0500 Subject: [PATCH 22/26] Fix unit tests and style tests --- src/ssh/azext_ssh/_help.py | 18 +++--- src/ssh/azext_ssh/_params.py | 2 +- src/ssh/azext_ssh/custom.py | 11 ++-- src/ssh/azext_ssh/ssh_utils.py | 7 ++- src/ssh/azext_ssh/tests/latest/test_custom.py | 58 +++++++++++++------ .../azext_ssh/tests/latest/test_ssh_utils.py | 20 ++++--- 6 files changed, 71 insertions(+), 45 deletions(-) diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index 7b7b524751a..7eb37c94e86 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -18,28 +18,28 @@ - name: Give a resource group and VM to SSH to text: | az ssh vm --resource-group myResourceGroup --vm-name myVm - + - name: Give the public IP (or hostname) of a VM to SSH to text: | az ssh vm --ip 1.2.3.4 az ssh vm --hostname example.com - + - name: Using a custom private key file text: | az ssh vm --ip 1.2.3.4 --private-key-file key --public-key-file key.pub - + - name: Using additional ssh arguments text: | az ssh vm --ip 1.2.3.4 -- -A -o ForwardX11=yes - + - name: Give a local user name to SSH using local user credentials on the target machine using certificate based authentication. text: | az ssh vm --local-user username --ip 1.2.3.4 --certificate-file cert.pub --private-key key - + - name: Give a local user name to SSH using local user credentials on the target machine using key based authentication. text: | az ssh vm --local-user username --resource-group myResourceGroup --vm-name myVM --private-key-file key - + - name: Give a local user name to SSH using local user credentials on the target machine using password based authentication. text: | az ssh vm --local-user username --ip 1.2.3.4 @@ -53,19 +53,19 @@ - name: Give a resource group and VM for which to create a config, and save in a local file text: | az ssh config --resource-group myResourceGroup --vm-name myVm --file ./sshconfig - + - name: Give the public IP (or hostname) of a VM for which to create a config and then ssh text: | az ssh config --ip 1.2.3.4 --file ./sshconfig ssh -F ./sshconfig 1.2.3.4 - + - name: Create a generic config for use with any host text: | #Bash az ssh config --ip \\* --file ./sshconfig #PowerShell az ssh config --ip * --file ./sshconfig - + - name: Examples with other software text: | #Bash diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 608465601de..2a1a5b33ce9 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -17,7 +17,7 @@ def load_arguments(self, _): c.argument('local_user', options_list=['--local-user'], help='The username for a local user') c.argument('cert_file', options_list=['--certificate-file', '-c'], - help='Path to a certificate file used for authentication when using local user credentials.') + help='Path to a certificate file used for authentication when using local user credentials.') c.argument('port', options_list=['--port'], help='SSH port') c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index bd77c82d0f7..1fd295d2c3a 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -37,6 +37,7 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= if (public_key_file or private_key_file) and credentials_folder: raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with " "--public-key-file/-p or --private-key-file/-i.") + op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite, port) # Default credential location if not credentials_folder: @@ -70,7 +71,7 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, - username, cert_file, credentials_folder, op_call): + username, cert_file, credentials_folder, op_call): # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) @@ -87,10 +88,10 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke if not username: delete_cert = True public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, - private_key_file, - credentials_folder) + private_key_file, + credentials_folder) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) - + op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert) @@ -162,7 +163,7 @@ def _assert_args(resource_group, vm_name, ssh_ip, cert_file, username): if ssh_ip and (vm_name or resource_group): raise azclierror.MutuallyExclusiveArgumentError( "--ip cannot be used with --resource-group or --vm-name/--name") - + if cert_file and not username: raise azclierror.MutuallyExclusiveArgumentError( "To authenticate with a certificate you need to provide a --local-user") diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index d1eaf816db1..22491a86e97 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -37,7 +37,7 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] # Create a new process that will wait until the connection is established and then delete keys. cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, delete_cert, cert_file, private_key_file, - log_file, True)) + log_file, True)) cleanup_process.start() command = [_get_ssh_path(), _get_host(username, ip)] @@ -114,8 +114,9 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, validity_warning = "" if validity: validity_warning = f" {validity.lower()}" - logger.warning("%s contains sensitive information%s%s\nPlease delete it once you no longer need this config file. ", - path_to_delete, items_to_delete, validity_warning) + logger.warning("%s contains sensitive information%s%s\n" + "Please delete it once you no longer need this config file. ", + path_to_delete, items_to_delete, validity_warning) lines = [""] diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 0be930fa335..7952b63f3c1 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -17,11 +17,11 @@ class SshCustomCommandTest(unittest.TestCase): @mock.patch('azext_ssh.custom._assert_args') def test_ssh_vm(self, mock_assert, mock_do_op): cmd = mock.Mock() - custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False) + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", None) - mock_assert.assert_called_once_with("rg", "vm", "ip") + mock_assert.assert_called_once_with("rg", "vm", "ip", "cert", "username") mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", False, None, mock.ANY) + cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", None, mock.ANY) @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('azext_ssh.ssh_utils.write_ssh_config') @@ -35,16 +35,16 @@ def test_ssh_config(self, mock_join, mock_dirname, mock_isdir, mock_assert, mock mock_isdir.return_value = True mock_join.side_effect = ['az_ssh_config/rg-vm', 'path/to/az_ssh_config/rg-vm'] - def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, credentials_folder, op_call): - op_call(ssh_ip, "username", "cert_file", private_key_file, False) + def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, local_user, cert_file, credentials_folder, op_call): + op_call(ssh_ip, "username", "cert", private_key_file, False, False) mock_do_op.side_effect = do_op_side_effect - custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False, None) + custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False, "username", "cert", "port", None) - mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "ip", "username", "cert_file", "private", False) - mock_assert.assert_called_once_with("rg", "vm", "ip") + mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "port", "ip", "username", "cert", "private", False, False) + mock_assert.assert_called_once_with("rg", "vm", "ip", "cert", "username") mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", False, 'path/to/az_ssh_config/rg-vm', mock.ANY) + cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", 'path/to/az_ssh_config/rg-vm', mock.ANY) @mock.patch('azext_ssh.ssh_utils.get_ssh_cert_principals') @mock.patch('os.path.join') @@ -53,7 +53,7 @@ def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, pri @mock.patch('azext_ssh.custom._get_modulus_exponent') @mock.patch('azure.cli.core._profile.Profile') @mock.patch('azext_ssh.custom._write_cert_file') - def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock_ip, + def test_do_ssh_op_aad_user(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock_ip, mock_check_files, mock_join, mock_principal): cmd = mock.Mock() cmd.cli_ctx = mock.Mock() @@ -68,14 +68,28 @@ def test_do_ssh_op(self, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock profile.get_msal_token.return_value = "username", "certificate" mock_join.return_value = "public-aadcert.pub" - custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, "cred/folder", mock_op) + custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, None, None, "cred/folder", mock_op) mock_check_files.assert_called_once_with("publicfile", "privatefile", "cred/folder") mock_ip.assert_not_called() mock_get_mod_exp.assert_called_once_with("public") mock_write_cert.assert_called_once_with("certificate", "public-aadcert.pub") mock_op.assert_called_once_with( - "1.2.3.4", "username", "public-aadcert.pub", "private", False) + "1.2.3.4", "username", "public-aadcert.pub", "private", False, True) + + @mock.patch('azext_ssh.custom._check_or_create_public_private_files') + @mock.patch('azext_ssh.ip_utils.get_ssh_ip') + def test_do_ssh_op_local_user(self, mock_ip, mock_check_files): + cmd = mock.Mock() + mock_op = mock.Mock() + mock_ip.return_value = "1.2.3.4" + + custom._do_ssh_op(cmd, "vm", "rg", None, "publicfile", "privatefile", False, "username", "cert", "cred/folder", mock_op) + + mock_check_files.assert_not_called() + mock_ip.assert_called_once_with(cmd, "vm", "rg", False) + mock_op.assert_called_once_with( + "1.2.3.4", "username", "cert", "privatefile", False, False) @mock.patch('azext_ssh.custom._check_or_create_public_private_files') @mock.patch('azext_ssh.ip_utils.get_ssh_ip') @@ -88,21 +102,29 @@ def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_file self.assertRaises( azclierror.ResourceNotFoundError, custom._do_ssh_op, cmd, "rg", "vm", None, - "publicfile", "privatefile", False, "cred/folder", mock_op) + "publicfile", "privatefile", False, None, None, "cred/folder", mock_op) mock_check_files.assert_not_called() mock_ip.assert_called_once_with(cmd, "rg", "vm", False) def test_assert_args_no_ip_or_vm(self): - self.assertRaises(azclierror.RequiredArgumentMissingError, custom._assert_args, None, None, None) + self.assertRaises(azclierror.RequiredArgumentMissingError, custom._assert_args, None, None, None, None, None) def test_assert_args_vm_rg_mismatch(self): - self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", None, None) - self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, "vm", None) + self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", None, None, None, None) + self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, "vm", None, None, None) def test_assert_args_ip_with_vm_or_rg(self): - self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, "vm", "ip") - self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", "vm", "ip") + self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, "vm", "ip", None, None) + self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", "vm", "ip", None, None) + + def test_assert_args_cert_with_no_user(self): + self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, None, "ip", "certificate", None) + + @mock.patch('os.path.isfile') + def test_assert_args_invalid_cert_filepath(self, mock_is_file): + mock_is_file.return_value = False + self.assertRaises(azclierror.FileOperationError, custom._assert_args, 'rg', 'vm', None, 'cert_path', 'username') @mock.patch('azext_ssh.ssh_utils.create_ssh_keyfile') @mock.patch('tempfile.mkdtemp') diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index 1be7eb4481d..9e20e610317 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -25,11 +25,11 @@ def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path, expected_command = ["ssh", "user@ip", "-i", "file", "-o", "option", "-E", "/log/file/path", "-v"] - ssh_utils.start_ssh_connection(None, None, "ip", "user", "cert", "private", True) + ssh_utils.start_ssh_connection("port", None, "ip", "user", "cert", "private", True, True) mock_path.assert_called_once_with() mock_host.assert_called_once_with("user", "ip") - mock_build.assert_called_once_with("cert", "private", None) + mock_build.assert_called_once_with("cert", "private", "port") mock_call.assert_called_once_with(expected_command, shell=platform.system() == 'Windows') @mock.patch.object(ssh_utils, '_get_ssh_path') @@ -41,7 +41,7 @@ def test_start_ssh_connection_with_args(self, mock_call, mock_host, mock_path): expected_command = ["ssh", "user@ip", "-i", "private", "-o", "CertificateFile=cert", "-p", "2222", "--thing", "-vv"] - ssh_utils.start_ssh_connection("2222", ["--thing", "-vv"], "ip", "user", "cert", "private", True) + ssh_utils.start_ssh_connection("2222", ["--thing", "-vv"], "ip", "user", "cert", "private", True, True) mock_path.assert_called_once_with() mock_host.assert_called_once_with("user", "ip") @@ -58,17 +58,19 @@ def test_write_ssh_config_ip_and_vm(self, mock_validity): "\tHostName 1.2.3.4", "\tCertificateFile cert", "\tIdentityFile privatekey", + "\tPort port", "Host 1.2.3.4", "\tUser username", "\tCertificateFile cert", - "\tIdentityFile privatekey" + "\tIdentityFile privatekey", + "\tPort port" ] with mock.patch('builtins.open') as mock_open: mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", True, "1.2.3.4", "username", "cert", "privatekey", False + "path/to/file", "rg", "vm", True, "port", "1.2.3.4", "username", "cert", "privatekey", True, False ) mock_validity.assert_called_once_with("cert") mock_open.assert_called_once_with("path/to/file", "w") @@ -95,14 +97,14 @@ def test_write_ssh_config_append(self, mock_validity): mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", False, "1.2.3.4", "username", "cert", "privatekey", False + "path/to/file", "rg", "vm", False, None, "1.2.3.4", "username", "cert", "privatekey", True, True ) mock_validity.assert_called_once_with("cert") mock_open.assert_called_once_with("path/to/file", "a") mock_file.write.assert_called_once_with('\n'.join(expected_lines)) - + @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') def test_write_ssh_config_ip_only(self, mock_validity): expected_lines = [ @@ -118,10 +120,10 @@ def test_write_ssh_config_ip_only(self, mock_validity): mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", None, None, True, "1.2.3.4", "username", "cert", "privatekey", False + "path/to/file", None, None, True, None, "1.2.3.4", "username", "cert", "privatekey", False, False ) - mock_validity.assert_called_once_with("cert") + mock_validity.assert_not_called() mock_open.assert_called_once_with("path/to/file", "w") mock_file.write.assert_called_once_with('\n'.join(expected_lines)) From 1d2ec22819fad26c6ecc74262a00af16ef574bbd Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 8 Nov 2021 18:46:37 -0500 Subject: [PATCH 23/26] Update HISTORY --- src/ssh/HISTORY.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 387866484cd..84449613487 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -7,6 +7,7 @@ Release History * Keys generated during ssh config are saved in az_ssh_config folder in the same directory as --file. * Users no longer allowed to run ssh cert with no parameters. * When --public-key-file/-f is not provided to ssh cert, generated public and private keys are saved in the same folder as --file. +* Add support to connect to local users on local machines using key based, cert based, or password based authentication. 0.1.8 ----- From abe3e014fd48f9696b67438823f99828e1bfb78c Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 10 Nov 2021 12:49:33 -0500 Subject: [PATCH 24/26] Fix typo --- src/ssh/azext_ssh/ssh_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index a90319b4392..dd3652758f6 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -56,7 +56,7 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi # Make sure all files have been properly removed. _do_cleanup(delete_keys, cert_file, private_key_file) if log_file: - file_utils.delete_file(log_file, f"Couldn't delete temporary log file {cert_file}. ", True) + file_utils.delete_file(log_file, f"Couldn't delete temporary log file {log_file}. ", True) if delete_keys: temp_dir = os.path.dirname(cert_file) file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) From 120d1956b9fd40d42e13c4d939c76ceb3b7743e3 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 16 Nov 2021 20:26:20 -0500 Subject: [PATCH 25/26] prepend log args --- src/ssh/azext_ssh/ssh_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index c7e37a84b3a..b1ec587f4e2 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -34,7 +34,7 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi log_file_dir = os.path.dirname(cert_file) log_file_name = 'ssh_client_log_' + str(os.getpid()) log_file = os.path.join(log_file_dir, log_file_name) - ssh_arg_list = ssh_arg_list + ['-E', log_file, '-v'] + ssh_arg_list = ['-E', log_file, '-v'] + ssh_arg_list # Create a new process that will wait until the connection is established and then delete keys. cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, delete_cert, cert_file, private_key_file, log_file, True)) @@ -43,6 +43,7 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list + print(command) logger.debug("Running ssh command %s", ' '.join(command)) subprocess.call(command, shell=platform.system() == 'Windows') From 1485f4beea83c0fc5e9bfbb9da6d8b217c48ea01 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 16 Nov 2021 20:29:09 -0500 Subject: [PATCH 26/26] remove print statement --- src/ssh/azext_ssh/ssh_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index b1ec587f4e2..bf27741aa7a 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -43,7 +43,6 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi command = [_get_ssh_path(), _get_host(username, ip)] command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list - print(command) logger.debug("Running ssh command %s", ' '.join(command)) subprocess.call(command, shell=platform.system() == 'Windows')