Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SSH] Revert changes to _check_or_create_public_private_files #16

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None):
keys_folder = os.path.dirname(cert_path)
logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate "
"is no longer being used.", keys_folder)
public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder)
public_key_file, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder)
cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path)
print(cert_file + "\n")

Expand All @@ -87,9 +87,11 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke
# If user provides a local user, use the provided credentials for authentication
if not username:
delete_cert = True
public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file,
private_key_file,
credentials_folder)
if not public_key_file and not private_key_file:
delete_keys = True
public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file,
private_key_file,
credentials_folder)
cert_file, username = _get_and_write_certificate(cmd, public_key_file, None)

op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert)
Expand Down Expand Up @@ -172,13 +174,9 @@ def _assert_args(resource_group, vm_name, ssh_ip, cert_file, username):
raise azclierror.FileOperationError(f"Certificate file {cert_file} not found")


def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder):
delete_keys = False
def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder=None):
# If nothing is passed, then create a directory with a ephemeral keypair
if not public_key_file and not private_key_file:
# We only want to delete the keys if the user hasn't provided their own keys
# Only ssh vm deletes generated keys.
delete_keys = True
if not credentials_folder:
# az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails.
credentials_folder = tempfile.mkdtemp(prefix="aadsshcert")
Expand Down Expand Up @@ -206,7 +204,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre
if not os.path.isfile(private_key_file):
raise azclierror.FileOperationError(f"Private key file {private_key_file} not found")

return public_key_file, private_key_file, delete_keys
return public_key_file, private_key_file


def _write_cert_file(certificate_contents, cert_file):
Expand Down