Skip to content

Commit

Permalink
Adding support for IP Connect and partial for Quick connect. (#5800)
Browse files Browse the repository at this point in the history
* Initial changes for ipconnect and quickconnect

* adding endpoints for quickconnect/developer sku

* stlying fixes

* adding constant file with enum for skus

* Documentation and final changes

* Documentation and final changes

* Documentation and final changes
  • Loading branch information
aavalang authored Jan 27, 2023
1 parent fed2786 commit d3bc6dc
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 38 deletions.
6 changes: 6 additions & 0 deletions src/bastion/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Release History
===============

0.2.0
++++++
* Adding support for IP connect through AZ CLI.
* Initial support for connectivity through developerSku.
* Bug fixes.

0.1.0
++++++
* Initial release.
10 changes: 10 additions & 0 deletions src/bastion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ az network bastion show --name MyBastionHost --resource-group MyResourceGroup
```commandline
az network bastion update --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling
```

### RDP to VM/VMSS using Azure Bastion host machine
```commandline
az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id ResourceId
```

### SSH to VM/VMSS using Azure Bastion host machine
```commandline
az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling --target-resource-id ResourceId --auth-type password
```
16 changes: 16 additions & 0 deletions src/bastion/azext_bastion/BastionServiceConstants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

# pylint: disable=import-error,unused-import

from enum import Enum


class BastionSku(Enum):

Basic = "Basic"
Standard = "Standard"
Developer = "Developer"
QuickConnect = "QuickConnect"
11 changes: 10 additions & 1 deletion src/bastion/azext_bastion/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
- name: SSH to virtual machine using Azure Bastion using AAD.
text: |
az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD
- name: SSH to virtual machine using Azure Bastion using AAD.
text: |
az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD
"""

helps['network bastion rdp'] = """
Expand All @@ -33,13 +36,19 @@
- name: RDP to virtual machine using Azure Bastion.
text: |
az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId
- name: RDP to machine using reachable IP address.
text: |
az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1
"""

helps['network bastion tunnel'] = """
type: command
short-summary: Open a tunnel through Azure Bastion to a target virtual machine.
examples:
- name: Open a tunnel through Azure Bastion to a target virtual machine.
- name: Open a tunnel through Azure Bastion to a target virtual machine using resourceId.
text: |
az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --resource-port 22 --port 50022
- name: Open a tunnel through Azure Bastion to a target virtual machine using its IP address.
text: |
az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 --resource-port 22 --port 50022
"""
5 changes: 4 additions & 1 deletion src/bastion/azext_bastion/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from azure.cli.core.commands.parameters import get_resource_name_completion_list, get_three_state_flag
from knack.arguments import CLIArgumentType
from ._validators import (validate_ip_address)


def load_arguments(self, _): # pylint: disable=unused-argument
Expand All @@ -24,8 +25,10 @@ def load_arguments(self, _): # pylint: disable=unused-argument
c.argument("bastion_host_name", bastion_host_name_type, options_list=["--name", "-n"])
c.argument("resource_port", help="Resource port of the target VM to which the bastion will connect.",
options_list=["--resource-port"])
c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.",
c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", required=False,
options_list=["--target-resource-id"])
c.argument("target_ip_address", help="IP address of target Virtual Machine.", required=False,
options_list=["--target-ip-address"], validator=validate_ip_address)

with self.argument_context("network bastion ssh") as c:
c.argument("auth_type", help="Auth type to use for SSH connections.", options_list=["--auth-type"])
Expand Down
28 changes: 28 additions & 0 deletions src/bastion/azext_bastion/_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import ipaddress
from azure.cli.core.azclierror import InvalidArgumentValueError


def validate_ip_address(namespace):
if namespace.target_ip_address is not None:
_validate_ip_address_format(namespace)


def _validate_ip_address_format(namespace):
if namespace.target_ip_address is not None:
input_value = namespace.target_ip_address
if ' ' in input_value:
raise InvalidArgumentValueError("Spaces not allowed: '{}' ".format(input_value))
input_ips = input_value.split(',')
if len(input_ips) > 8:
raise InvalidArgumentValueError('Maximum 8 IP addresses are allowed per rule.')
validated_ips = ''
for ip in input_ips:
# Use ipaddress library to validate ip network format
ip_obj = ipaddress.ip_network(ip)
validated_ips += str(ip_obj) + ','
namespace.target_ip_address = validated_ips[:-1]
105 changes: 73 additions & 32 deletions src/bastion/azext_bastion/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import requests
from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \
UnrecognizedArgumentError, CLIInternalError, ClientRequestError
from azure.cli.core.commands.client_factory import get_subscription_id
from knack.log import get_logger
from msrestazure.tools import is_valid_resource_id

from .BastionServiceConstants import BastionSku
from .aaz.latest.network.bastion import Create as _BastionCreate


Expand Down Expand Up @@ -132,20 +133,27 @@ def _build_args(cert_file, private_key_file):
return private_key + certificate


def ssh_bastion_host(cmd, auth_type, target_resource_id, resource_group_name, bastion_host_name,
def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, resource_group_name, bastion_host_name,
resource_port=None, username=None, ssh_key=None):
import os
from .aaz.latest.network.bastion import Show

_test_extension(SSH_EXTENSION_NAME)
bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": bastion_host_name
})

if not resource_port:
resource_port = 22
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
raise InvalidArgumentValueError(err_msg)

tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port)
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

_validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port)
t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand Down Expand Up @@ -208,32 +216,33 @@ def _get_rdp_path(rdp_command="mstsc"):
return rdp_path


def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_name,
def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name,
resource_port=None, disable_gateway=False, configure=False, enable_mfa=False):
import os
from azure.cli.core._profile import Profile
from ._process_helper import launch_and_wait

if not resource_port:
resource_port = 3389
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
raise InvalidArgumentValueError(err_msg)

from .aaz.latest.network.bastion import Show

bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": bastion_host_name
})

if bastion['sku']['name'] == "Basic" or \
bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True:
if not resource_port:
resource_port = 3389

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
_validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

if platform.system() == "Windows":
if disable_gateway:
tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port)
if disable_gateway or ip_connect:
tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, bastion_endpoint)
if ip_connect:
tunnel_server.set_host_name(target_ip_address)
t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand All @@ -244,9 +253,8 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_
profile = Profile(cli_ctx=cmd.cli_ctx)
access_token = profile.get_raw_token()[0][2].get("accessToken")
logger.debug("Response %s", access_token)
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"

web_address = f"https://{bastion['dnsName']}/api/rdpfile?resourceId={target_resource_id}" \
f"&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "*/*",
Expand All @@ -259,7 +267,6 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_
raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.")

_write_to_file(response)

rdpfilepath = os.getcwd() + "/conn.rdp"
command = [_get_rdp_path()]
if configure:
Expand All @@ -270,24 +277,46 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_
raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")


def _is_ipconnect_request(cmd, bastion, target_ip_address):
if bastion['enableIpConnect'] is True and target_ip_address:
return True

return False


def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address):
if target_ip_address:
if bastion['enableIpConnect'] is not True:
raise InvalidArgumentValueError("Bastion does not have IP Connect feature enabled, please enable and try again")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
elif not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
raise InvalidArgumentValueError(err_msg)


def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id):
if bastion['sku']['name'] == BastionSku.QuickConnect.value or bastion['sku']['name'] == BastionSku.Developer.value:
from .developer_sku_helper import (_get_data_pod)
bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion)
return bastion_endpoint

return bastion['dnsName']


def _write_to_file(response):
with open("conn.rdp", "w", encoding="utf-8") as f:
for line in response.text.splitlines():
if not line.startswith('signscope'):
f.write(line + "\n")


def _get_tunnel(cmd, resource_group_name, name, vm_id, resource_port, port=None):
def _get_tunnel(cmd, bastion, bastion_endpoint, vm_id, resource_port, port=None):
from .tunnel import TunnelServer
from .aaz.latest.network.bastion import Show

bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": name
})
if port is None:
port = 0 # will auto-select a free port from 1024-65535
tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, vm_id, resource_port)
tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, bastion_endpoint, vm_id, resource_port)

return tunnel_server

Expand All @@ -303,12 +332,24 @@ def _tunnel_close_handler(tunnel):
sys.exit()


def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_host_name, resource_port, port,
def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port, port,
timeout=None):
if not is_valid_resource_id(target_resource_id):
raise InvalidArgumentValueError("Please enter a valid VM resource ID.")

tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port, port)
from .aaz.latest.network.bastion import Show
bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": bastion_host_name
})

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

_validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port)
t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand Down
29 changes: 29 additions & 0 deletions src/bastion/azext_bastion/developer_sku_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

# pylint: disable=import-error,unused-import


def _get_data_pod(cmd, resource_port, target_resource_id, bastion):
from azure.cli.core._profile import Profile
from azure.cli.core.util import should_disable_connection_verify
import requests

profile = Profile(cli_ctx=cmd.cli_ctx)
auth_token, _, _ = profile.get_raw_token()
content = {
'resourceId': target_resource_id,
'bastionResourceId': bastion.id,
'vmPort': resource_port,
'azToken': auth_token[1],
'connectionType': 'nativeclient'
}
headers = {'Content-Type': 'application/json'}

web_address = f"https://{bastion['dnsName']}/api/connection"
response = requests.post(web_address, json=content, headers=headers,
verify=(not should_disable_connection_verify()))

return response.content.decode("utf-8")
Loading

0 comments on commit d3bc6dc

Please sign in to comment.